12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- import os
- from unittest.mock import MagicMock, patch
- import pytest
- from embedchain import App
- from embedchain.config import AppConfig, BaseLlmConfig
- from embedchain.llm.openai import OpenAILlm
@pytest.fixture
def app():
    """Provide an ``App`` with metrics collection turned off.

    A placeholder ``OPENAI_API_KEY`` is written into ``os.environ`` before the
    app is constructed. The variable is not restored afterwards, so it stays
    set for the rest of the test process.
    """
    os.environ["OPENAI_API_KEY"] = "test_api_key"
    return App(config=AppConfig(collect_metrics=False))
# Collection.add is swapped for the MagicMock class while the test runs
# (presumably so nothing is written to a real Chroma collection).
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query(app):
    """``App.query`` returns the LLM answer and forwards the query text.

    Retrieval and the LLM call are both mocked; the test checks that each is
    invoked exactly once and that retrieval receives the query under the
    ``input_query`` keyword.
    """
    with patch.object(app, "_retrieve_from_database", return_value=["Test context"]) as mock_retrieve:
        with patch.object(app.llm, "get_llm_model_answer", return_value="Test answer") as mock_answer:
            answer = app.query(input_query="Test query")

    # The mocks keep their recorded calls after the patches are undone.
    assert answer == "Test answer"
    mock_retrieve.assert_called_once()
    retrieve_kwargs = mock_retrieve.call_args.kwargs
    assert retrieve_kwargs.get("input_query") == "Test query"
    mock_answer.assert_called_once()
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_query_config_app_passing(mock_get_answer):
    """An LLM config handed to ``OpenAILlm`` is still visible through the ``App``.

    ``OpenAILlm._get_answer`` is patched to return a canned string. The test
    then checks that ``app.llm`` keeps the configured system prompt and that
    ``get_llm_model_answer`` returns the canned answer.

    Args:
        mock_get_answer: Mock injected by ``@patch`` in place of
            ``OpenAILlm._get_answer``.
    """
    # Single assignment: the earlier throwaway ``MagicMock()`` return value was
    # overwritten immediately and never observed.
    mock_get_answer.return_value = "Test answer"

    config = AppConfig(collect_metrics=False)
    chat_config = BaseLlmConfig(system_prompt="Test system prompt")
    llm = OpenAILlm(config=chat_config)
    app = App(config=config, llm=llm)

    answer = app.llm.get_llm_model_answer("Test query")

    assert app.llm.config.system_prompt == "Test system prompt"
    assert answer == "Test answer"
# Collection.add is swapped for the MagicMock class while the test runs
# (presumably so nothing is written to a real Chroma collection).
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_params(app):
    """A ``where`` filter passed to ``App.query`` reaches retrieval unchanged.

    Retrieval and the LLM call are mocked; the test inspects the keyword
    arguments retrieval was called with.
    """
    expected_where = {"attribute": "value"}

    with patch.object(app, "_retrieve_from_database", return_value=["Test context"]) as mock_retrieve:
        with patch.object(app.llm, "get_llm_model_answer", return_value="Test answer") as mock_answer:
            answer = app.query("Test query", where={"attribute": "value"})

    assert answer == "Test answer"
    retrieve_kwargs = mock_retrieve.call_args.kwargs
    assert retrieve_kwargs.get("input_query") == "Test query"
    assert retrieve_kwargs.get("where") == expected_where
    mock_answer.assert_called_once()
# Collection.add is swapped for the MagicMock class while the test runs
# (presumably so nothing is written to a real Chroma collection).
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_query_config(app):
    """A ``where`` filter set on the LLM config is forwarded to ``app.db.query``.

    The LLM call and the database query are mocked. The filter received by
    the database must contain the configured ``attribute`` key and also an
    ``app_id`` key.
    """
    llm_config = BaseLlmConfig(where={"attribute": "value"})

    with patch.object(app.llm, "get_llm_model_answer", return_value="Test answer") as mock_answer:
        with patch.object(app.db, "query", return_value=["Test context"]) as mock_database_query:
            answer = app.query("Test query", llm_config)

    assert answer == "Test answer"
    db_kwargs = mock_database_query.call_args.kwargs
    assert db_kwargs.get("input_query") == "Test query"
    where_filter = db_kwargs.get("where")
    assert "app_id" in where_filter
    assert "attribute" in where_filter
    mock_answer.assert_called_once()
|