# test_query.py: tests for App.query and for passing an LLM config to App.

  1. import os
  2. from unittest.mock import MagicMock, patch
  3. import pytest
  4. from embedchain import App
  5. from embedchain.config import AppConfig, BaseLlmConfig
  6. from embedchain.llm.openai import OpenAILlm
  7. @pytest.fixture
  8. def app():
  9. os.environ["OPENAI_API_KEY"] = "test_api_key"
  10. app = App(config=AppConfig(collect_metrics=False))
  11. return app
  12. @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
  13. def test_query(app):
  14. with patch.object(app, "_retrieve_from_database") as mock_retrieve:
  15. mock_retrieve.return_value = ["Test context"]
  16. with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
  17. mock_answer.return_value = "Test answer"
  18. answer = app.query(input_query="Test query")
  19. assert answer == "Test answer"
  20. mock_retrieve.assert_called_once()
  21. _, kwargs = mock_retrieve.call_args
  22. input_query_arg = kwargs.get("input_query")
  23. assert input_query_arg == "Test query"
  24. mock_answer.assert_called_once()
  25. @patch("embedchain.llm.openai.OpenAILlm._get_answer")
  26. def test_query_config_app_passing(mock_get_answer):
  27. mock_get_answer.return_value = MagicMock()
  28. mock_get_answer.return_value = "Test answer"
  29. config = AppConfig(collect_metrics=False)
  30. chat_config = BaseLlmConfig(system_prompt="Test system prompt")
  31. llm = OpenAILlm(config=chat_config)
  32. app = App(config=config, llm=llm)
  33. answer = app.llm.get_llm_model_answer("Test query")
  34. assert app.llm.config.system_prompt == "Test system prompt"
  35. assert answer == "Test answer"
  36. @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
  37. def test_query_with_where_in_params(app):
  38. with patch.object(app, "_retrieve_from_database") as mock_retrieve:
  39. mock_retrieve.return_value = ["Test context"]
  40. with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
  41. mock_answer.return_value = "Test answer"
  42. answer = app.query("Test query", where={"attribute": "value"})
  43. assert answer == "Test answer"
  44. _, kwargs = mock_retrieve.call_args
  45. assert kwargs.get("input_query") == "Test query"
  46. assert kwargs.get("where") == {"attribute": "value"}
  47. mock_answer.assert_called_once()
  48. @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
  49. def test_query_with_where_in_query_config(app):
  50. with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
  51. mock_answer.return_value = "Test answer"
  52. with patch.object(app.db, "query") as mock_database_query:
  53. mock_database_query.return_value = ["Test context"]
  54. llm_config = BaseLlmConfig(where={"attribute": "value"})
  55. answer = app.query("Test query", llm_config)
  56. assert answer == "Test answer"
  57. _, kwargs = mock_database_query.call_args
  58. assert kwargs.get("input_query") == "Test query"
  59. where = kwargs.get("where")
  60. assert "app_id" in where
  61. assert "attribute" in where
  62. mock_answer.assert_called_once()