# test_base_llm.py — unit tests for embedchain.llm.base.BaseLlm
import pytest
from embedchain.llm.base import BaseLlm, BaseLlmConfig
  3. @pytest.fixture
  4. def base_llm():
  5. config = BaseLlmConfig()
  6. return BaseLlm(config=config)
  7. def test_is_get_llm_model_answer_not_implemented(base_llm):
  8. with pytest.raises(NotImplementedError):
  9. base_llm.get_llm_model_answer()
  10. def test_is_get_llm_model_answer_implemented():
  11. class TestLlm(BaseLlm):
  12. def get_llm_model_answer(self):
  13. return "Implemented"
  14. config = BaseLlmConfig()
  15. llm = TestLlm(config=config)
  16. assert llm.get_llm_model_answer() == "Implemented"
  17. def test_stream_query_response(base_llm):
  18. answer = ["Chunk1", "Chunk2", "Chunk3"]
  19. result = list(base_llm._stream_query_response(answer))
  20. assert result == answer
  21. def test_stream_chat_response(base_llm):
  22. answer = ["Chunk1", "Chunk2", "Chunk3"]
  23. result = list(base_llm._stream_chat_response(answer))
  24. assert result == answer
  25. def test_append_search_and_context(base_llm):
  26. context = "Context"
  27. web_search_result = "Web Search Result"
  28. result = base_llm._append_search_and_context(context, web_search_result)
  29. expected_result = "Context\nWeb Search Result: Web Search Result"
  30. assert result == expected_result
  31. def test_access_search_and_get_results(base_llm, mocker):
  32. base_llm.access_search_and_get_results = mocker.patch.object(
  33. base_llm, "access_search_and_get_results", return_value="Search Results"
  34. )
  35. input_query = "Test query"
  36. result = base_llm.access_search_and_get_results(input_query)
  37. assert result == "Search Results"