# Tests for embedchain.llm.base.BaseLlm / BaseLlmConfig
  1. from string import Template
  2. import pytest
  3. from embedchain.llm.base import BaseLlm, BaseLlmConfig
  4. @pytest.fixture
  5. def base_llm():
  6. config = BaseLlmConfig()
  7. return BaseLlm(config=config)
  8. def test_is_get_llm_model_answer_not_implemented(base_llm):
  9. with pytest.raises(NotImplementedError):
  10. base_llm.get_llm_model_answer()
  11. def test_is_stream_bool():
  12. with pytest.raises(ValueError):
  13. config = BaseLlmConfig(stream="test value")
  14. BaseLlm(config=config)
  15. def test_template_string_gets_converted_to_Template_instance():
  16. config = BaseLlmConfig(template="test value $query $context")
  17. llm = BaseLlm(config=config)
  18. assert isinstance(llm.config.prompt, Template)
  19. def test_is_get_llm_model_answer_implemented():
  20. class TestLlm(BaseLlm):
  21. def get_llm_model_answer(self):
  22. return "Implemented"
  23. config = BaseLlmConfig()
  24. llm = TestLlm(config=config)
  25. assert llm.get_llm_model_answer() == "Implemented"
  26. def test_stream_response(base_llm):
  27. answer = ["Chunk1", "Chunk2", "Chunk3"]
  28. result = list(base_llm._stream_response(answer))
  29. assert result == answer
  30. def test_append_search_and_context(base_llm):
  31. context = "Context"
  32. web_search_result = "Web Search Result"
  33. result = base_llm._append_search_and_context(context, web_search_result)
  34. expected_result = "Context\nWeb Search Result: Web Search Result"
  35. assert result == expected_result
  36. def test_access_search_and_get_results(base_llm, mocker):
  37. base_llm.access_search_and_get_results = mocker.patch.object(
  38. base_llm, "access_search_and_get_results", return_value="Search Results"
  39. )
  40. input_query = "Test query"
  41. result = base_llm.access_search_and_get_results(input_query)
  42. assert result == "Search Results"