123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- import os
- from unittest.mock import MagicMock, patch
- import pytest
- from embedchain import App
- from embedchain.config import AppConfig, BaseLlmConfig
@pytest.fixture
def app():
    """Provide an ``App`` instance with metrics collection turned off.

    A dummy ``OPENAI_API_KEY`` is written into ``os.environ`` before the app
    is built. Nothing here restores the previous value, so the dummy key
    remains set for the rest of the process after this fixture runs.
    """
    os.environ.update(OPENAI_API_KEY="test_api_key")
    return App(config=AppConfig(collect_metrics=False))
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query(app):
    """``App.query`` passes the query to retrieval and returns the LLM's answer.

    Both ``_retrieve_from_database`` and ``get_llm_model_answer`` are replaced
    with mocks, so only the wiring inside ``query`` is exercised.
    """
    retrieve_patch = patch.object(app, "_retrieve_from_database")
    answer_patch = patch.object(app.llm, "get_llm_model_answer")
    with retrieve_patch as fake_retrieve, answer_patch as fake_answer:
        fake_retrieve.return_value = ["Test context"]
        fake_answer.return_value = "Test answer"
        result = app.query(input_query="Test query")
        assert result == "Test answer"

    fake_retrieve.assert_called_once()
    _positional, keyword_args = fake_retrieve.call_args
    assert keyword_args.get("input_query") == "Test query"
    fake_answer.assert_called_once()
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_query_config_app_passing(mock_get_answer):
    """A ``system_prompt`` given via ``llm_config`` ends up on ``app.llm.config``.

    ``OpenAILlm._get_answer`` is patched to return a fixed string, and the
    test checks that ``get_llm_model_answer`` hands that string back.
    """
    mock_get_answer.return_value = "Test answer"

    # Provide the dummy key for this test only (patch.dict restores os.environ
    # on exit) instead of relying on another test's fixture having leaked it.
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test_api_key"}):
        config = AppConfig(collect_metrics=False)
        chat_config = BaseLlmConfig(system_prompt="Test system prompt")
        app = App(config=config, llm_config=chat_config)
        answer = app.llm.get_llm_model_answer("Test query")

    assert app.llm.config.system_prompt == "Test system prompt"
    assert answer == "Test answer"
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_app_passing(mock_get_answer):
    """A ``system_prompt`` passed directly to ``App`` ends up on ``app.llm.config``.

    The ``llm_config`` here is a default ``BaseLlmConfig``, so the prompt can
    only come from the ``system_prompt`` keyword. ``OpenAILlm._get_answer`` is
    patched to return a fixed string.
    """
    mock_get_answer.return_value = "Test answer"

    # Provide the dummy key for this test only (patch.dict restores os.environ
    # on exit) instead of relying on another test's fixture having leaked it.
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test_api_key"}):
        config = AppConfig(collect_metrics=False)
        chat_config = BaseLlmConfig()
        app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
        answer = app.llm.get_llm_model_answer("Test query")

    assert app.llm.config.system_prompt == "Test system prompt"
    assert answer == "Test answer"
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_params(app):
    """A ``where`` filter passed to ``App.query`` is forwarded to retrieval.

    Retrieval and the LLM call are mocked; the test inspects the keyword
    arguments ``_retrieve_from_database`` was called with.
    """
    with patch.object(app, "_retrieve_from_database") as mock_retrieve:
        mock_retrieve.return_value = ["Test context"]
        with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
            mock_answer.return_value = "Test answer"
            answer = app.query("Test query", where={"attribute": "value"})
            assert answer == "Test answer"

    # Assert the call happened before unpacking call_args: if retrieval never
    # ran, call_args is None and unpacking would raise an opaque TypeError.
    mock_retrieve.assert_called_once()
    _, kwargs = mock_retrieve.call_args
    assert kwargs.get("input_query") == "Test query"
    assert kwargs.get("where") == {"attribute": "value"}
    mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_query_config(app):
    """A ``where`` filter set on ``BaseLlmConfig`` reaches the database query.

    The LLM answer and ``app.db.query`` are mocked; the test inspects the
    keyword arguments the database query was called with.
    """
    with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
        mock_answer.return_value = "Test answer"
        with patch.object(app.db, "query") as mock_database_query:
            mock_database_query.return_value = ["Test context"]
            llm_config = BaseLlmConfig(where={"attribute": "value"})
            answer = app.query("Test query", llm_config)
            assert answer == "Test answer"

    # Assert the call happened before unpacking call_args: if the database was
    # never queried, call_args is None and unpacking would raise a TypeError.
    mock_database_query.assert_called_once()
    _, kwargs = mock_database_query.call_args
    assert kwargs.get("input_query") == "Test query"
    assert kwargs.get("where") == {"attribute": "value"}
    mock_answer.assert_called_once()
|