test_hugging_face_hub.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import importlib
  2. import os
  3. import unittest
  4. from unittest.mock import patch, MagicMock
  5. from embedchain.config import BaseLlmConfig
  6. from embedchain.llm.hugging_face_hub import HuggingFaceHubLlm
  7. class TestHuggingFaceHubLlm(unittest.TestCase):
  8. def setUp(self):
  9. os.environ["HUGGINGFACEHUB_ACCESS_TOKEN"] = "test_access_token"
  10. self.config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)
  11. def test_init_raises_value_error_without_api_key(self):
  12. os.environ.pop("HUGGINGFACEHUB_ACCESS_TOKEN")
  13. with self.assertRaises(ValueError):
  14. HuggingFaceHubLlm()
  15. def test_get_llm_model_answer_raises_value_error_for_system_prompt(self):
  16. llm = HuggingFaceHubLlm(self.config)
  17. llm.config.system_prompt = "system_prompt"
  18. with self.assertRaises(ValueError):
  19. llm.get_llm_model_answer("prompt")
  20. def test_top_p_value_within_range(self):
  21. config = BaseLlmConfig(top_p=1.0)
  22. with self.assertRaises(ValueError):
  23. HuggingFaceHubLlm._get_answer("test_prompt", config)
  24. def test_dependency_is_imported(self):
  25. importlib_installed = True
  26. try:
  27. importlib.import_module("huggingface_hub")
  28. except ImportError:
  29. importlib_installed = False
  30. self.assertTrue(importlib_installed)
  31. @patch("embedchain.llm.hugging_face_hub.HuggingFaceHubLlm._get_answer")
  32. def test_get_llm_model_answer(self, mock_get_answer):
  33. mock_get_answer.return_value = "Test answer"
  34. llm = HuggingFaceHubLlm(self.config)
  35. answer = llm.get_llm_model_answer("Test query")
  36. self.assertEqual(answer, "Test answer")
  37. mock_get_answer.assert_called_once()
  38. @patch("embedchain.llm.hugging_face_hub.HuggingFaceHub")
  39. def test_hugging_face_mock(self, mock_hugging_face_hub):
  40. mock_llm_instance = MagicMock()
  41. mock_llm_instance.return_value = "Test answer"
  42. mock_hugging_face_hub.return_value = mock_llm_instance
  43. llm = HuggingFaceHubLlm(self.config)
  44. answer = llm.get_llm_model_answer("Test query")
  45. self.assertEqual(answer, "Test answer")
  46. mock_hugging_face_hub.assert_called_once_with(
  47. huggingfacehub_api_token="test_access_token",
  48. repo_id="google/flan-t5-xxl",
  49. model_kwargs={"temperature": 0.7, "max_new_tokens": 50, "top_p": 0.8},
  50. )
  51. mock_llm_instance.assert_called_once_with("Test query")