# test_jina.py — unit tests for embedchain's JinaLlm wrapper.
  1. import os
  2. import pytest
  3. from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
  4. from embedchain.config import BaseLlmConfig
  5. from embedchain.llm.jina import JinaLlm
  6. @pytest.fixture
  7. def config():
  8. os.environ["JINACHAT_API_KEY"] = "test_api_key"
  9. config = BaseLlmConfig(temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt")
  10. yield config
  11. os.environ.pop("JINACHAT_API_KEY")
  12. def test_init_raises_value_error_without_api_key(mocker):
  13. mocker.patch.dict(os.environ, clear=True)
  14. with pytest.raises(ValueError):
  15. JinaLlm()
  16. def test_get_llm_model_answer(config, mocker):
  17. mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
  18. llm = JinaLlm(config)
  19. answer = llm.get_llm_model_answer("Test query")
  20. assert answer == "Test answer"
  21. mocked_get_answer.assert_called_once_with("Test query", config)
  22. def test_get_llm_model_answer_with_system_prompt(config, mocker):
  23. config.system_prompt = "Custom system prompt"
  24. mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
  25. llm = JinaLlm(config)
  26. answer = llm.get_llm_model_answer("Test query")
  27. assert answer == "Test answer"
  28. mocked_get_answer.assert_called_once_with("Test query", config)
  29. def test_get_llm_model_answer_empty_prompt(config, mocker):
  30. mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
  31. llm = JinaLlm(config)
  32. answer = llm.get_llm_model_answer("")
  33. assert answer == "Test answer"
  34. mocked_get_answer.assert_called_once_with("", config)
  35. def test_get_llm_model_answer_with_streaming(config, mocker):
  36. config.stream = True
  37. mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
  38. llm = JinaLlm(config)
  39. llm.get_llm_model_answer("Test query")
  40. mocked_jinachat.assert_called_once()
  41. callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
  42. assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
  43. def test_get_llm_model_answer_without_system_prompt(config, mocker):
  44. config.system_prompt = None
  45. mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
  46. llm = JinaLlm(config)
  47. llm.get_llm_model_answer("Test query")
  48. mocked_jinachat.assert_called_once_with(
  49. temperature=config.temperature,
  50. max_tokens=config.max_tokens,
  51. model_kwargs={"top_p": config.top_p},
  52. )