瀏覽代碼

Improve and add more tests (#807)

Sidharth Mohanty 1 年之前
父節點
當前提交
e8a2846449

+ 4 - 9
embedchain/llm/openai.py

@@ -13,13 +13,9 @@ class OpenAILlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config=config)
 
-    def get_llm_model_answer(self, prompt):
+    def get_llm_model_answer(self, prompt) -> str:
         response = OpenAILlm._get_answer(prompt, self.config)
-
-        if self.config.stream:
-            return response
-        else:
-            return response.content
+        return response
 
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
         messages = []
@@ -35,10 +31,9 @@ class OpenAILlm(BaseLlm):
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
-            from langchain.callbacks.streaming_stdout import \
-                StreamingStdOutCallbackHandler
+            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
             chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
         else:
             chat = ChatOpenAI(**kwargs)
-        return chat(messages)
+        return chat(messages).content

+ 1 - 0
pyproject.toml

@@ -133,6 +133,7 @@ isort = "^5.12.0"
 pytest-cov = "^4.1.0"
 responses = "^0.23.3"
 mock = "^5.1.0"
+pytest-asyncio = "^0.21.1"
 
 [tool.poetry.extras]
 streamlit = ["streamlit"]

+ 13 - 0
tests/apps/test_apps.py

@@ -104,6 +104,19 @@ class TestConfigForAppComponents:
 
         assert isinstance(embedder_config, BaseEmbedderConfig)
 
+    def test_components_raises_type_error_if_not_proper_instances(self):
+        wrong_llm = "wrong_llm"
+        with pytest.raises(TypeError):
+            App(llm=wrong_llm)
+
+        wrong_db = "wrong_db"
+        with pytest.raises(TypeError):
+            App(db=wrong_db)
+
+        wrong_embedder = "wrong_embedder"
+        with pytest.raises(TypeError):
+            App(embedder=wrong_embedder)
+
 
 class TestAppFromConfig:
     def load_config_data(self, yaml_path):

+ 13 - 1
tests/llm/test_base_llm.py

@@ -1,5 +1,5 @@
 import pytest
-
+from string import Template
 from embedchain.llm.base import BaseLlm, BaseLlmConfig
 
 
@@ -14,6 +14,18 @@ def test_is_get_llm_model_answer_not_implemented(base_llm):
         base_llm.get_llm_model_answer()
 
 
+def test_is_stream_bool():
+    with pytest.raises(ValueError):
+        config = BaseLlmConfig(stream="test value")
+        BaseLlm(config=config)
+
+
+def test_template_string_gets_converted_to_Template_instance():
+    config = BaseLlmConfig(template="test value $query $context")
+    llm = BaseLlm(config=config)
+    assert isinstance(llm.config.template, Template)
+
+
 def test_is_get_llm_model_answer_implemented():
     class TestLlm(BaseLlm):
         def get_llm_model_answer(self):

+ 44 - 22
tests/llm/test_cohere.py

@@ -1,33 +1,55 @@
 import os
-import unittest
-from unittest.mock import patch
+import pytest
 
 from embedchain.config import BaseLlmConfig
 from embedchain.llm.cohere import CohereLlm
 
 
-class TestCohereLlm(unittest.TestCase):
-    def setUp(self):
-        os.environ["COHERE_API_KEY"] = "test_api_key"
-        self.config = BaseLlmConfig(model="gptd-instruct-tft", max_tokens=50, temperature=0.7, top_p=0.8)
+@pytest.fixture
+def cohere_llm_config():
+    os.environ["COHERE_API_KEY"] = "test_api_key"
+    config = BaseLlmConfig(model="gptd-instruct-tft", max_tokens=50, temperature=0.7, top_p=0.8)
+    yield config
+    os.environ.pop("COHERE_API_KEY")
 
-    def test_init_raises_value_error_without_api_key(self):
-        os.environ.pop("COHERE_API_KEY")
-        with self.assertRaises(ValueError):
-            CohereLlm()
 
-    def test_get_llm_model_answer_raises_value_error_for_system_prompt(self):
-        llm = CohereLlm(self.config)
-        llm.config.system_prompt = "system_prompt"
-        with self.assertRaises(ValueError):
-            llm.get_llm_model_answer("prompt")
+def test_init_raises_value_error_without_api_key(mocker):
+    mocker.patch.dict(os.environ, clear=True)
+    with pytest.raises(ValueError):
+        CohereLlm()
 
-    @patch("embedchain.llm.cohere.CohereLlm._get_answer")
-    def test_get_llm_model_answer(self, mock_get_answer):
-        mock_get_answer.return_value = "Test answer"
 
-        llm = CohereLlm(self.config)
-        answer = llm.get_llm_model_answer("Test query")
+def test_get_llm_model_answer_raises_value_error_for_system_prompt(cohere_llm_config):
+    llm = CohereLlm(cohere_llm_config)
+    llm.config.system_prompt = "system_prompt"
+    with pytest.raises(ValueError):
+        llm.get_llm_model_answer("prompt")
 
-        self.assertEqual(answer, "Test answer")
-        mock_get_answer.assert_called_once()
+
+def test_get_llm_model_answer(cohere_llm_config, mocker):
+    mocker.patch("embedchain.llm.cohere.CohereLlm._get_answer", return_value="Test answer")
+
+    llm = CohereLlm(cohere_llm_config)
+    answer = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+
+
+def test_get_answer_mocked_cohere(cohere_llm_config, mocker):
+    mocked_cohere = mocker.patch("embedchain.llm.cohere.Cohere")
+    mock_instance = mocked_cohere.return_value
+    mock_instance.return_value = "Mocked answer"
+
+    llm = CohereLlm(cohere_llm_config)
+    prompt = "Test query"
+    answer = llm.get_llm_model_answer(prompt)
+
+    assert answer == "Mocked answer"
+    mocked_cohere.assert_called_once_with(
+        cohere_api_key="test_api_key",
+        model="gptd-instruct-tft",
+        max_tokens=50,
+        temperature=0.7,
+        p=0.8,
+    )
+    mock_instance.assert_called_once_with(prompt)

+ 55 - 58
tests/llm/test_huggingface.py

@@ -1,64 +1,61 @@
 import importlib
 import os
-import unittest
-from unittest.mock import MagicMock, patch
-
+import pytest
 from embedchain.config import BaseLlmConfig
 from embedchain.llm.huggingface import HuggingFaceLlm
 
 
-class TestHuggingFaceLlm(unittest.TestCase):
-    def setUp(self):
-        os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
-        self.config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)
-
-    def test_init_raises_value_error_without_api_key(self):
-        os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
-        with self.assertRaises(ValueError):
-            HuggingFaceLlm()
-
-    def test_get_llm_model_answer_raises_value_error_for_system_prompt(self):
-        llm = HuggingFaceLlm(self.config)
-        llm.config.system_prompt = "system_prompt"
-        with self.assertRaises(ValueError):
-            llm.get_llm_model_answer("prompt")
-
-    def test_top_p_value_within_range(self):
-        config = BaseLlmConfig(top_p=1.0)
-        with self.assertRaises(ValueError):
-            HuggingFaceLlm._get_answer("test_prompt", config)
-
-    def test_dependency_is_imported(self):
-        importlib_installed = True
-        try:
-            importlib.import_module("huggingface_hub")
-        except ImportError:
-            importlib_installed = False
-        self.assertTrue(importlib_installed)
-
-    @patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer")
-    def test_get_llm_model_answer(self, mock_get_answer):
-        mock_get_answer.return_value = "Test answer"
-
-        llm = HuggingFaceLlm(self.config)
-        answer = llm.get_llm_model_answer("Test query")
-
-        self.assertEqual(answer, "Test answer")
-        mock_get_answer.assert_called_once()
-
-    @patch("embedchain.llm.huggingface.HuggingFaceHub")
-    def test_hugging_face_mock(self, mock_huggingface):
-        mock_llm_instance = MagicMock()
-        mock_llm_instance.return_value = "Test answer"
-        mock_huggingface.return_value = mock_llm_instance
-
-        llm = HuggingFaceLlm(self.config)
-        answer = llm.get_llm_model_answer("Test query")
-
-        self.assertEqual(answer, "Test answer")
-        mock_huggingface.assert_called_once_with(
-            huggingfacehub_api_token="test_access_token",
-            repo_id="google/flan-t5-xxl",
-            model_kwargs={"temperature": 0.7, "max_new_tokens": 50, "top_p": 0.8},
-        )
-        mock_llm_instance.assert_called_once_with("Test query")
+@pytest.fixture
+def huggingface_llm_config():
+    os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
+    config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)
+    yield config
+    os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
+
+
+def test_init_raises_value_error_without_api_key(mocker):
+    mocker.patch.dict(os.environ, clear=True)
+    with pytest.raises(ValueError):
+        HuggingFaceLlm()
+
+
+def test_get_llm_model_answer_raises_value_error_for_system_prompt(huggingface_llm_config):
+    llm = HuggingFaceLlm(huggingface_llm_config)
+    llm.config.system_prompt = "system_prompt"
+    with pytest.raises(ValueError):
+        llm.get_llm_model_answer("prompt")
+
+
+def test_top_p_value_within_range():
+    config = BaseLlmConfig(top_p=1.0)
+    with pytest.raises(ValueError):
+        HuggingFaceLlm._get_answer("test_prompt", config)
+
+
+def test_dependency_is_imported():
+    importlib_installed = True
+    try:
+        importlib.import_module("huggingface_hub")
+    except ImportError:
+        importlib_installed = False
+    assert importlib_installed
+
+
+def test_get_llm_model_answer(huggingface_llm_config, mocker):
+    mocker.patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer", return_value="Test answer")
+
+    llm = HuggingFaceLlm(huggingface_llm_config)
+    answer = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+
+
+def test_hugging_face_mock(huggingface_llm_config, mocker):
+    mock_llm_instance = mocker.Mock(return_value="Test answer")
+    mocker.patch("embedchain.llm.huggingface.HuggingFaceHub", return_value=mock_llm_instance)
+
+    llm = HuggingFaceLlm(huggingface_llm_config)
+    answer = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    mock_llm_instance.assert_called_once_with("Test query")

+ 64 - 28
tests/llm/test_jina.py

@@ -1,40 +1,76 @@
 import os
-import unittest
-from unittest.mock import patch
-
+import pytest
 from embedchain.config import BaseLlmConfig
 from embedchain.llm.jina import JinaLlm
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
+
+@pytest.fixture
+def config():
+    os.environ["JINACHAT_API_KEY"] = "test_api_key"
+    config = BaseLlmConfig(temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt")
+    yield config
+    os.environ.pop("JINACHAT_API_KEY")
+
+
+def test_init_raises_value_error_without_api_key(mocker):
+    mocker.patch.dict(os.environ, clear=True)
+    with pytest.raises(ValueError):
+        JinaLlm()
+
+
+def test_get_llm_model_answer(config, mocker):
+    mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
+
+    llm = JinaLlm(config)
+    answer = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("Test query", config)
+
+
+def test_get_llm_model_answer_with_system_prompt(config, mocker):
+    config.system_prompt = "Custom system prompt"
+    mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
+
+    llm = JinaLlm(config)
+    answer = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("Test query", config)
+
+
+def test_get_llm_model_answer_empty_prompt(config, mocker):
+    mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
+
+    llm = JinaLlm(config)
+    answer = llm.get_llm_model_answer("")
 
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("", config)
 
-class TestJinaLlm(unittest.TestCase):
-    def setUp(self):
-        os.environ["JINACHAT_API_KEY"] = "test_api_key"
-        self.config = BaseLlmConfig(
-            temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt"
-        )
 
-    def test_init_raises_value_error_without_api_key(self):
-        os.environ.pop("JINACHAT_API_KEY")
-        with self.assertRaises(ValueError):
-            JinaLlm()
+def test_get_llm_model_answer_with_streaming(config, mocker):
+    config.stream = True
+    mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
 
-    @patch("embedchain.llm.jina.JinaLlm._get_answer")
-    def test_get_llm_model_answer(self, mock_get_answer):
-        mock_get_answer.return_value = "Test answer"
+    llm = JinaLlm(config)
+    llm.get_llm_model_answer("Test query")
 
-        llm = JinaLlm(self.config)
-        answer = llm.get_llm_model_answer("Test query")
+    mocked_jinachat.assert_called_once()
+    callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
+    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
 
-        self.assertEqual(answer, "Test answer")
-        mock_get_answer.assert_called_once()
 
-    @patch("embedchain.llm.jina.JinaLlm._get_answer")
-    def test_get_llm_model_answer_with_system_prompt(self, mock_get_answer):
-        self.config.system_prompt = "Custom system prompt"
-        mock_get_answer.return_value = "Test answer"
+def test_get_llm_model_answer_without_system_prompt(config, mocker):
+    config.system_prompt = None
+    mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
 
-        llm = JinaLlm(self.config)
-        answer = llm.get_llm_model_answer("Test query")
+    llm = JinaLlm(config)
+    llm.get_llm_model_answer("Test query")
 
-        self.assertEqual(answer, "Test answer")
-        mock_get_answer.assert_called_once()
+    mocked_jinachat.assert_called_once_with(
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        model_kwargs={"top_p": config.top_p},
+    )

+ 47 - 0
tests/llm/test_llama2.py

@@ -0,0 +1,47 @@
+import os
+import pytest
+from embedchain.llm.llama2 import Llama2Llm
+
+
+@pytest.fixture
+def llama2_llm(monkeypatch):
+    monkeypatch.setenv("REPLICATE_API_TOKEN", "test_api_token")  # restored by pytest at teardown, so the token cannot leak into other tests
+    llm = Llama2Llm()
+    return llm
+
+
+def test_init_raises_value_error_without_api_key(mocker):
+    mocker.patch.dict(os.environ, clear=True)
+    with pytest.raises(ValueError):
+        Llama2Llm()
+
+
+def test_get_llm_model_answer_raises_value_error_for_system_prompt(llama2_llm):
+    llama2_llm.config.system_prompt = "system_prompt"
+    with pytest.raises(ValueError):
+        llama2_llm.get_llm_model_answer("prompt")
+
+
+def test_get_llm_model_answer(llama2_llm, mocker):
+    mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate")
+    mocked_replicate_instance = mocker.MagicMock()
+    mocked_replicate.return_value = mocked_replicate_instance
+    mocked_replicate_instance.return_value = "Test answer"
+
+    llama2_llm.config.model = "test_model"
+    llama2_llm.config.max_tokens = 50
+    llama2_llm.config.temperature = 0.7
+    llama2_llm.config.top_p = 0.8
+
+    answer = llama2_llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    mocked_replicate.assert_called_once_with(
+        model="test_model",
+        input={
+            "temperature": 0.7,
+            "max_length": 50,
+            "top_p": 0.8,
+        },
+    )
+    mocked_replicate_instance.assert_called_once_with("Test query")

+ 73 - 0
tests/llm/test_openai.py

@@ -0,0 +1,73 @@
+import os
+import pytest
+from embedchain.config import BaseLlmConfig
+from embedchain.llm.openai import OpenAILlm
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
+
+@pytest.fixture
+def config():
+    os.environ["OPENAI_API_KEY"] = "test_api_key"
+    config = BaseLlmConfig(
+        temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt", model="gpt-3.5-turbo"
+    )
+    yield config
+    os.environ.pop("OPENAI_API_KEY")
+
+
+def test_get_llm_model_answer(config, mocker):
+    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
+
+    llm = OpenAILlm(config)
+    answer = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("Test query", config)
+
+
+def test_get_llm_model_answer_with_system_prompt(config, mocker):
+    config.system_prompt = "Custom system prompt"
+    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
+
+    llm = OpenAILlm(config)
+    answer = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("Test query", config)
+
+
+def test_get_llm_model_answer_empty_prompt(config, mocker):
+    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
+
+    llm = OpenAILlm(config)
+    answer = llm.get_llm_model_answer("")
+
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("", config)
+
+
+def test_get_llm_model_answer_with_streaming(config, mocker):
+    config.stream = True
+    mocked_chat_openai = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+
+    llm = OpenAILlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_chat_openai.assert_called_once()
+    handler_lists = [call[1]["callbacks"] for call in mocked_chat_openai.call_args_list]  # call[1] is the kwargs dict of each constructor call
+    assert any(isinstance(handlers[0], StreamingStdOutCallbackHandler) for handlers in handler_lists)
+
+
+def test_get_llm_model_answer_without_system_prompt(config, mocker):
+    config.system_prompt = None
+    mocked_chat_openai = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+
+    llm = OpenAILlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_chat_openai.assert_called_once_with(
+        model=config.model,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        model_kwargs={"top_p": config.top_p},
+    )

+ 78 - 128
tests/llm/test_query.py

@@ -1,135 +1,85 @@
 import os
-import unittest
+import pytest
 from unittest.mock import MagicMock, patch
-
 from embedchain import App
 from embedchain.config import AppConfig, BaseLlmConfig
 
 
-class TestApp(unittest.TestCase):
-    os.environ["OPENAI_API_KEY"] = "test_key"
-
-    def setUp(self):
-        self.app = App(config=AppConfig(collect_metrics=False))
-
-    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
-    def test_query(self):
-        """
-        This test checks the functionality of the 'query' method in the App class.
-        It simulates a scenario where the 'retrieve_from_database' method returns a context list and
-        'get_llm_model_answer' returns an expected answer string.
-
-        The 'query' method is expected to call 'retrieve_from_database' and 'get_llm_model_answer' methods
-        appropriately and return the right answer.
-
-        Key assumptions tested:
-        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
-            BaseLlmConfig.
-        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
-        - 'query' method returns the value it received from 'get_llm_model_answer'.
-
-        The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
-        'get_llm_model_answer' methods.
-        """
-        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
-            mock_retrieve.return_value = ["Test context"]
-            with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
-                mock_answer.return_value = "Test answer"
-                _answer = self.app.query(input_query="Test query")
-
-        # Ensure retrieve_from_database was called
-        mock_retrieve.assert_called_once()
-
-        # Check the call arguments
-        args, kwargs = mock_retrieve.call_args
-        input_query_arg = kwargs.get("input_query")
-        self.assertEqual(input_query_arg, "Test query")
-        mock_answer.assert_called_once()
-
-    @patch("embedchain.llm.openai.OpenAILlm._get_answer")
-    def test_query_config_app_passing(self, mock_get_answer):
-        mock_get_answer.return_value = MagicMock()
-        mock_get_answer.return_value.content = "Test answer"
-
-        config = AppConfig(collect_metrics=False)
-        chat_config = BaseLlmConfig(system_prompt="Test system prompt")
-        app = App(config=config, llm_config=chat_config)
-        answer = app.llm.get_llm_model_answer("Test query")
-
-        self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
-        self.assertEqual(answer, "Test answer")
-
-    @patch("embedchain.llm.openai.OpenAILlm._get_answer")
-    def test_app_passing(self, mock_get_answer):
-        mock_get_answer.return_value = MagicMock()
-        mock_get_answer.return_value.content = "Test answer"
-        config = AppConfig(collect_metrics=False)
-        chat_config = BaseLlmConfig()
-        app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
-        answer = app.llm.get_llm_model_answer("Test query")
-        self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
-        self.assertEqual(answer, "Test answer")
-
-    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
-    def test_query_with_where_in_params(self):
-        """
-        This test checks the functionality of the 'query' method in the App class.
-        It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
-        a where filter and 'get_llm_model_answer' returns an expected answer string.
-
-        The 'query' method is expected to call 'retrieve_from_database' with the where filter  and
-        'get_llm_model_answer' methods appropriately and return the right answer.
-
-        Key assumptions tested:
-        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
-            BaseLlmConfig.
-        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
-        - 'query' method returns the value it received from 'get_llm_model_answer'.
-
-        The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
-        'get_llm_model_answer' methods.
-        """
-        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
-            mock_retrieve.return_value = ["Test context"]
-            with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
-                mock_answer.return_value = "Test answer"
-                answer = self.app.query("Test query", where={"attribute": "value"})
-
-        self.assertEqual(answer, "Test answer")
-        _args, kwargs = mock_retrieve.call_args
-        self.assertEqual(kwargs.get("input_query"), "Test query")
-        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
-        mock_answer.assert_called_once()
-
-    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
-    def test_query_with_where_in_query_config(self):
-        """
-        This test checks the functionality of the 'query' method in the App class.
-        It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
-        a where filter and 'get_llm_model_answer' returns an expected answer string.
-
-        The 'query' method is expected to call 'retrieve_from_database' with the where filter  and
-        'get_llm_model_answer' methods appropriately and return the right answer.
-
-        Key assumptions tested:
-        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
-            BaseLlmConfig.
-        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
-        - 'query' method returns the value it received from 'get_llm_model_answer'.
-
-        The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
-        'get_llm_model_answer' methods.
-        """
-
-        with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
+@pytest.fixture
+def app():
+    os.environ["OPENAI_API_KEY"] = "test_api_key"
+    app = App(config=AppConfig(collect_metrics=False))
+    return app
+
+
+@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+def test_query(app):
+    with patch.object(app, "retrieve_from_database") as mock_retrieve:
+        mock_retrieve.return_value = ["Test context"]
+        with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
+            mock_answer.return_value = "Test answer"
+            answer = app.query(input_query="Test query")
+            assert answer == "Test answer"
+
+    mock_retrieve.assert_called_once()
+    _, kwargs = mock_retrieve.call_args
+    input_query_arg = kwargs.get("input_query")
+    assert input_query_arg == "Test query"
+    mock_answer.assert_called_once()
+
+
+@patch("embedchain.llm.openai.OpenAILlm._get_answer")
+def test_query_config_app_passing(mock_get_answer):
+    mock_get_answer.return_value = MagicMock()  # NOTE(review): dead store, overwritten on the next line
+    mock_get_answer.return_value = "Test answer"  # the patched _get_answer hands back this plain string
+
+    config = AppConfig(collect_metrics=False)
+    chat_config = BaseLlmConfig(system_prompt="Test system prompt")
+    app = App(config=config, llm_config=chat_config)
+    answer = app.llm.get_llm_model_answer("Test query")
+
+    assert app.llm.config.system_prompt == "Test system prompt"
+    assert answer == "Test answer"
+
+
+@patch("embedchain.llm.openai.OpenAILlm._get_answer")
+def test_app_passing(mock_get_answer):
+    mock_get_answer.return_value = MagicMock()  # NOTE(review): dead store, overwritten on the next line
+    mock_get_answer.return_value = "Test answer"  # the patched _get_answer hands back this plain string
+    config = AppConfig(collect_metrics=False)
+    chat_config = BaseLlmConfig()
+    app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
+    answer = app.llm.get_llm_model_answer("Test query")
+    assert app.llm.config.system_prompt == "Test system prompt"
+    assert answer == "Test answer"
+
+
+@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+def test_query_with_where_in_params(app):
+    with patch.object(app, "retrieve_from_database") as mock_retrieve:
+        mock_retrieve.return_value = ["Test context"]
+        with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
             mock_answer.return_value = "Test answer"
-            with patch.object(self.app.db, "query") as mock_database_query:
-                mock_database_query.return_value = ["Test context"]
-                llm_config = BaseLlmConfig(where={"attribute": "value"})
-                answer = self.app.query("Test query", llm_config)
-
-        self.assertEqual(answer, "Test answer")
-        _args, kwargs = mock_database_query.call_args
-        self.assertEqual(kwargs.get("input_query"), "Test query")
-        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
-        mock_answer.assert_called_once()
+            answer = app.query("Test query", where={"attribute": "value"})
+
+    assert answer == "Test answer"
+    _, kwargs = mock_retrieve.call_args
+    assert kwargs.get("input_query") == "Test query"
+    assert kwargs.get("where") == {"attribute": "value"}
+    mock_answer.assert_called_once()
+
+
+@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+def test_query_with_where_in_query_config(app):
+    with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
+        mock_answer.return_value = "Test answer"
+        with patch.object(app.db, "query") as mock_database_query:
+            mock_database_query.return_value = ["Test context"]
+            llm_config = BaseLlmConfig(where={"attribute": "value"})
+            answer = app.query("Test query", llm_config)
+
+    assert answer == "Test answer"
+    _, kwargs = mock_database_query.call_args
+    assert kwargs.get("input_query") == "Test query"
+    assert kwargs.get("where") == {"attribute": "value"}
+    mock_answer.assert_called_once()

+ 23 - 26
tests/models/test_data_type.py

@@ -1,32 +1,29 @@
-import unittest
+from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
 
-from embedchain.models.data_type import (DataType, DirectDataType,
-                                         IndirectDataType, SpecialDataType)
 
+def test_subclass_types_in_data_type():
+    """Test that all data type category subclasses are contained in the composite data type"""
+    # Check if DirectDataType values are in DataType
+    for data_type in DirectDataType:
+        assert data_type.value in DataType._value2member_map_
 
-class TestDataTypeEnums(unittest.TestCase):
-    def test_subclass_types_in_data_type(self):
-        """Test that all data type category subclasses are contained in the composite data type"""
-        # Check if DirectDataType values are in DataType
-        for data_type in DirectDataType:
-            self.assertIn(data_type.value, DataType._value2member_map_)
+    # Check if IndirectDataType values are in DataType
+    for data_type in IndirectDataType:
+        assert data_type.value in DataType._value2member_map_
 
-        # Check if IndirectDataType values are in DataType
-        for data_type in IndirectDataType:
-            self.assertIn(data_type.value, DataType._value2member_map_)
+    # Check if SpecialDataType values are in DataType
+    for data_type in SpecialDataType:
+        assert data_type.value in DataType._value2member_map_
 
-        # Check if SpecialDataType values are in DataType
-        for data_type in SpecialDataType:
-            self.assertIn(data_type.value, DataType._value2member_map_)
 
-    def test_data_type_in_subclasses(self):
-        """Test that all data types in the composite data type are categorized in a subclass"""
-        for data_type in DataType:
-            if data_type.value in DirectDataType._value2member_map_:
-                self.assertIn(data_type.value, DirectDataType._value2member_map_)
-            elif data_type.value in IndirectDataType._value2member_map_:
-                self.assertIn(data_type.value, IndirectDataType._value2member_map_)
-            elif data_type.value in SpecialDataType._value2member_map_:
-                self.assertIn(data_type.value, SpecialDataType._value2member_map_)
-            else:
-                self.fail(f"{data_type.value} not found in any subclass enums")
+def test_data_type_in_subclasses():
+    """Test that all data types in the composite data type are categorized in a subclass"""
+    for data_type in DataType:
+        if data_type.value in DirectDataType._value2member_map_:
+            assert data_type.value in DirectDataType._value2member_map_
+        elif data_type.value in IndirectDataType._value2member_map_:
+            assert data_type.value in IndirectDataType._value2member_map_
+        elif data_type.value in SpecialDataType._value2member_map_:
+            assert data_type.value in SpecialDataType._value2member_map_
+        else:
+            raise AssertionError(f"{data_type.value} not found in any subclass enums")  # explicit raise: not stripped under python -O