test_query.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import os
  2. from unittest.mock import MagicMock, patch
  3. import pytest
  4. from embedchain import App
  5. from embedchain.config import AppConfig, BaseLlmConfig
  6. @pytest.fixture
  7. def app():
  8. os.environ["OPENAI_API_KEY"] = "test_api_key"
  9. app = App(config=AppConfig(collect_metrics=False))
  10. return app
  11. @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
  12. def test_query(app):
  13. with patch.object(app, "_retrieve_from_database") as mock_retrieve:
  14. mock_retrieve.return_value = ["Test context"]
  15. with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
  16. mock_answer.return_value = "Test answer"
  17. answer = app.query(input_query="Test query")
  18. assert answer == "Test answer"
  19. mock_retrieve.assert_called_once()
  20. _, kwargs = mock_retrieve.call_args
  21. input_query_arg = kwargs.get("input_query")
  22. assert input_query_arg == "Test query"
  23. mock_answer.assert_called_once()
  24. @patch("embedchain.llm.openai.OpenAILlm._get_answer")
  25. def test_query_config_app_passing(mock_get_answer):
  26. mock_get_answer.return_value = MagicMock()
  27. mock_get_answer.return_value = "Test answer"
  28. config = AppConfig(collect_metrics=False)
  29. chat_config = BaseLlmConfig(system_prompt="Test system prompt")
  30. app = App(config=config, llm_config=chat_config)
  31. answer = app.llm.get_llm_model_answer("Test query")
  32. assert app.llm.config.system_prompt == "Test system prompt"
  33. assert answer == "Test answer"
  34. @patch("embedchain.llm.openai.OpenAILlm._get_answer")
  35. def test_app_passing(mock_get_answer):
  36. mock_get_answer.return_value = MagicMock()
  37. mock_get_answer.return_value = "Test answer"
  38. config = AppConfig(collect_metrics=False)
  39. chat_config = BaseLlmConfig()
  40. app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
  41. answer = app.llm.get_llm_model_answer("Test query")
  42. assert app.llm.config.system_prompt == "Test system prompt"
  43. assert answer == "Test answer"
  44. @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
  45. def test_query_with_where_in_params(app):
  46. with patch.object(app, "_retrieve_from_database") as mock_retrieve:
  47. mock_retrieve.return_value = ["Test context"]
  48. with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
  49. mock_answer.return_value = "Test answer"
  50. answer = app.query("Test query", where={"attribute": "value"})
  51. assert answer == "Test answer"
  52. _, kwargs = mock_retrieve.call_args
  53. assert kwargs.get("input_query") == "Test query"
  54. assert kwargs.get("where") == {"attribute": "value"}
  55. mock_answer.assert_called_once()
  56. @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
  57. def test_query_with_where_in_query_config(app):
  58. with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
  59. mock_answer.return_value = "Test answer"
  60. with patch.object(app.db, "query") as mock_database_query:
  61. mock_database_query.return_value = ["Test context"]
  62. llm_config = BaseLlmConfig(where={"attribute": "value"})
  63. answer = app.query("Test query", llm_config)
  64. assert answer == "Test answer"
  65. _, kwargs = mock_database_query.call_args
  66. assert kwargs.get("input_query") == "Test query"
  67. assert kwargs.get("where") == {"attribute": "value"}
  68. mock_answer.assert_called_once()