test_jina.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. import os
  2. import unittest
  3. from unittest.mock import patch
  4. from embedchain.config import BaseLlmConfig
  5. from embedchain.llm.jina import JinaLlm
  6. class TestJinaLlm(unittest.TestCase):
  7. def setUp(self):
  8. os.environ["JINACHAT_API_KEY"] = "test_api_key"
  9. self.config = BaseLlmConfig(
  10. temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt"
  11. )
  12. def test_init_raises_value_error_without_api_key(self):
  13. os.environ.pop("JINACHAT_API_KEY")
  14. with self.assertRaises(ValueError):
  15. JinaLlm()
  16. @patch("embedchain.llm.jina.JinaLlm._get_answer")
  17. def test_get_llm_model_answer(self, mock_get_answer):
  18. mock_get_answer.return_value = "Test answer"
  19. llm = JinaLlm(self.config)
  20. answer = llm.get_llm_model_answer("Test query")
  21. self.assertEqual(answer, "Test answer")
  22. mock_get_answer.assert_called_once()
  23. @patch("embedchain.llm.jina.JinaLlm._get_answer")
  24. def test_get_llm_model_answer_with_system_prompt(self, mock_get_answer):
  25. self.config.system_prompt = "Custom system prompt"
  26. mock_get_answer.return_value = "Test answer"
  27. llm = JinaLlm(self.config)
  28. answer = llm.get_llm_model_answer("Test query")
  29. self.assertEqual(answer, "Test answer")
  30. mock_get_answer.assert_called_once()