Browse Source

Improve tests (#800)

Sidharth Mohanty 1 year ago
parent
commit
5ec12212e4

+ 6 - 8
embedchain/apps/person_app.py

@@ -2,10 +2,8 @@ from string import Template
 
 from embedchain.apps.app import App
 from embedchain.apps.open_source_app import OpenSourceApp
-from embedchain.config import BaseLlmConfig
-from embedchain.config.apps.base_app_config import BaseAppConfig
-from embedchain.config.llm.base import (DEFAULT_PROMPT,
-                                        DEFAULT_PROMPT_WITH_HISTORY)
+from embedchain.config import BaseLlmConfig, AppConfig
+from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
 from embedchain.helper.json_serializable import register_deserializable
 
 
@@ -16,16 +14,16 @@ class EmbedChainPersonApp:
     This bot behaves and speaks like a person.
 
     :param person: name of the person, better if its a well known person.
-    :param config: BaseAppConfig instance to load as configuration.
+    :param config: AppConfig instance to load as configuration.
     """
 
-    def __init__(self, person: str, config: BaseAppConfig = None):
+    def __init__(self, person: str, config: AppConfig = None):
         """Initialize a new person app
 
         :param person: Name of the person that's imitated.
         :type person: str
         :param config: Configuration class instance, defaults to None
-        :type config: BaseAppConfig, optional
+        :type config: AppConfig, optional
         """
         self.person = person
         self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style."  # noqa:E501
@@ -70,7 +68,7 @@ class PersonApp(EmbedChainPersonApp, App):
     """
 
     def query(self, input_query, config: BaseLlmConfig = None, dry_run=False):
-        config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None)
+        config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
         return super().query(input_query, config, dry_run, where=None)
 
     def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):

+ 58 - 60
tests/apps/test_apps.py

@@ -1,108 +1,106 @@
 import os
-import unittest
-
+import pytest
 import yaml
 
 from embedchain import App, CustomApp, Llama2App, OpenSourceApp
-from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
-                               BaseLlmConfig, ChromaDbConfig)
+from embedchain.config import AddConfig, AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChromaDbConfig
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.llm.base import BaseLlm
 from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
 from embedchain.vectordb.chroma import ChromaDB
 
 
-class TestApps(unittest.TestCase):
+@pytest.fixture
+def app():
+    os.environ["OPENAI_API_KEY"] = "test_api_key"
+    return App()
+
+
+@pytest.fixture
+def custom_app():
+    os.environ["OPENAI_API_KEY"] = "test_api_key"
+    return CustomApp()
+
+
+@pytest.fixture
+def opensource_app():
+    os.environ["OPENAI_API_KEY"] = "test_api_key"
+    return OpenSourceApp()
+
+
+@pytest.fixture
+def llama2_app():
     os.environ["OPENAI_API_KEY"] = "test_api_key"
+    os.environ["REPLICATE_API_TOKEN"] = "-"
+    return Llama2App()
 
-    def test_app(self):
-        app = App()
-        self.assertIsInstance(app.llm, BaseLlm)
-        self.assertIsInstance(app.db, BaseVectorDB)
-        self.assertIsInstance(app.embedder, BaseEmbedder)
 
-        wrong_llm = "wrong_llm"
-        with self.assertRaises(TypeError):
-            App(llm=wrong_llm)
+def test_app(app):
+    assert isinstance(app.llm, BaseLlm)
+    assert isinstance(app.db, BaseVectorDB)
+    assert isinstance(app.embedder, BaseEmbedder)
 
-        wrong_db = "wrong_db"
-        with self.assertRaises(TypeError):
-            App(db=wrong_db)
 
-        wrong_embedder = "wrong_embedder"
-        with self.assertRaises(TypeError):
-            App(embedder=wrong_embedder)
+def test_custom_app(custom_app):
+    assert isinstance(custom_app.llm, BaseLlm)
+    assert isinstance(custom_app.db, BaseVectorDB)
+    assert isinstance(custom_app.embedder, BaseEmbedder)
 
-    def test_custom_app(self):
-        app = CustomApp()
-        self.assertIsInstance(app.llm, BaseLlm)
-        self.assertIsInstance(app.db, BaseVectorDB)
-        self.assertIsInstance(app.embedder, BaseEmbedder)
 
-    def test_opensource_app(self):
-        app = OpenSourceApp()
-        self.assertIsInstance(app.llm, BaseLlm)
-        self.assertIsInstance(app.db, BaseVectorDB)
-        self.assertIsInstance(app.embedder, BaseEmbedder)
+def test_opensource_app(opensource_app):
+    assert isinstance(opensource_app.llm, BaseLlm)
+    assert isinstance(opensource_app.db, BaseVectorDB)
+    assert isinstance(opensource_app.embedder, BaseEmbedder)
 
-    def test_llama2_app(self):
-        os.environ["REPLICATE_API_TOKEN"] = "-"
-        app = Llama2App()
-        self.assertIsInstance(app.llm, BaseLlm)
-        self.assertIsInstance(app.db, BaseVectorDB)
-        self.assertIsInstance(app.embedder, BaseEmbedder)
 
+def test_llama2_app(llama2_app):
+    assert isinstance(llama2_app.llm, BaseLlm)
+    assert isinstance(llama2_app.db, BaseVectorDB)
+    assert isinstance(llama2_app.embedder, BaseEmbedder)
 
-class TestConfigForAppComponents(unittest.TestCase):
-    collection_name = "my-test-collection"
 
+class TestConfigForAppComponents:
     def test_constructor_config(self):
-        """
-        Test that app can be configured through the app constructor.
-        """
-        app = App(db_config=ChromaDbConfig(collection_name=self.collection_name))
-        self.assertEqual(app.db.config.collection_name, self.collection_name)
+        collection_name = "my-test-collection"
+        app = App(db_config=ChromaDbConfig(collection_name=collection_name))
+        assert app.db.config.collection_name == collection_name
 
     def test_component_config(self):
-        """
-        Test that app can also be configured by passing a configured component to the app
-        """
-        database = ChromaDB(config=ChromaDbConfig(collection_name=self.collection_name))
+        collection_name = "my-test-collection"
+        database = ChromaDB(config=ChromaDbConfig(collection_name=collection_name))
         app = App(db=database)
-        self.assertEqual(app.db.config.collection_name, self.collection_name)
+        assert app.db.config.collection_name == collection_name
 
     def test_different_configs_are_proper_instances(self):
-        config = AppConfig()
-        wrong_app_config = AddConfig()
-
-        with self.assertRaises(TypeError):
-            App(config=wrong_app_config)
+        app_config = AppConfig()
+        wrong_config = AddConfig()
+        with pytest.raises(TypeError):
+            App(config=wrong_config)
 
-        self.assertIsInstance(config, AppConfig)
+        assert isinstance(app_config, AppConfig)
 
         llm_config = BaseLlmConfig()
         wrong_llm_config = "wrong_llm_config"
 
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             App(llm_config=wrong_llm_config)
 
-        self.assertIsInstance(llm_config, BaseLlmConfig)
+        assert isinstance(llm_config, BaseLlmConfig)
 
         db_config = BaseVectorDbConfig()
         wrong_db_config = "wrong_db_config"
 
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             App(db_config=wrong_db_config)
 
-        self.assertIsInstance(db_config, BaseVectorDbConfig)
+        assert isinstance(db_config, BaseVectorDbConfig)
 
         embedder_config = BaseEmbedderConfig()
         wrong_embedder_config = "wrong_embedder_config"
-
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             App(embedder_config=wrong_embedder_config)
 
-        self.assertIsInstance(embedder_config, BaseEmbedderConfig)
+        assert isinstance(embedder_config, BaseEmbedderConfig)
 
 
 class TestAppFromConfig:

+ 80 - 0
tests/apps/test_person_app.py

@@ -0,0 +1,80 @@
+import pytest
+from embedchain.apps.app import App
+from embedchain.apps.person_app import PersonApp, PersonOpenSourceApp
+from embedchain.config import BaseLlmConfig, AppConfig
+from embedchain.config.llm.base import DEFAULT_PROMPT
+
+
+@pytest.fixture
+def person_app():
+    config = AppConfig()
+    return PersonApp("John Doe", config)
+
+
+@pytest.fixture
+def opensource_person_app():
+    config = AppConfig()
+    return PersonOpenSourceApp("John Doe", config)
+
+
+def test_person_app_initialization(person_app):
+    assert person_app.person == "John Doe"
+    assert f"You are {person_app.person}" in person_app.person_prompt
+    assert isinstance(person_app.config, AppConfig)
+
+
+def test_person_app_add_person_template_to_config_with_invalid_template():
+    app = PersonApp("John Doe")
+    default_prompt = "Input Prompt"
+    with pytest.raises(ValueError):
+        # as prompt doesn't contain $context and $query
+        app.add_person_template_to_config(default_prompt)
+
+
+def test_person_app_add_person_template_to_config_with_valid_template():
+    app = PersonApp("John Doe")
+    config = app.add_person_template_to_config(DEFAULT_PROMPT)
+    assert (
+        config.template.template
+        == f"You are John Doe. Whatever you say, you will always say in John Doe style. {DEFAULT_PROMPT}"
+    )
+
+
+def test_person_app_query(mocker, person_app):
+    input_query = "Hello, how are you?"
+    config = BaseLlmConfig()
+
+    mocker.patch.object(App, "query", return_value="Mocked response")
+
+    result = person_app.query(input_query, config)
+    assert result == "Mocked response"
+
+
+def test_person_app_chat(mocker, person_app):
+    input_query = "Hello, how are you?"
+    config = BaseLlmConfig()
+
+    mocker.patch.object(App, "chat", return_value="Mocked chat response")
+
+    result = person_app.chat(input_query, config)
+    assert result == "Mocked chat response"
+
+
+def test_opensource_person_app_query(mocker, opensource_person_app):
+    input_query = "Hello, how are you?"
+    config = BaseLlmConfig()
+
+    mocker.patch.object(App, "query", return_value="Mocked response")
+
+    result = opensource_person_app.query(input_query, config)
+    assert result == "Mocked response"
+
+
+def test_opensource_person_app_chat(mocker, opensource_person_app):
+    input_query = "Hello, how are you?"
+    config = BaseLlmConfig()
+
+    mocker.patch.object(App, "chat", return_value="Mocked chat response")
+
+    result = opensource_person_app.chat(input_query, config)
+    assert result == "Mocked chat response"

+ 50 - 0
tests/bots/test_poe.py

@@ -0,0 +1,50 @@
+import argparse
+import pytest
+
+from embedchain.bots.poe import PoeBot, start_command
+from fastapi_poe.types import QueryRequest, ProtocolMessage
+
+
+@pytest.fixture
+def poe_bot(mocker):
+    bot = PoeBot()
+    mocker.patch("fastapi_poe.run")
+    return bot
+
+
+@pytest.mark.asyncio
+async def test_poe_bot_get_response(poe_bot, mocker):
+    query = QueryRequest(
+        version="test",
+        type="query",
+        query=[ProtocolMessage(role="system", content="Test content")],
+        user_id="test_user_id",
+        conversation_id="test_conversation_id",
+        message_id="test_message_id",
+    )
+
+    mocker.patch.object(poe_bot.app.llm, "set_history")
+
+    response_generator = poe_bot.get_response(query)
+
+    await response_generator.__anext__()
+    poe_bot.app.llm.set_history.assert_called_once()
+
+
+def test_poe_bot_handle_message(poe_bot, mocker):
+    mocker.patch.object(poe_bot, "ask_bot", return_value="Answer from the bot")
+
+    response_ask = poe_bot.handle_message("What is the answer?")
+    assert response_ask == "Answer from the bot"
+
+    # TODO: This test will fail because the add_data method is commented out.
+    # mocker.patch.object(poe_bot, 'add_data', return_value="Added data from: some_data")
+    # response_add = poe_bot.handle_message("/add some_data")
+    # assert response_add == "Added data from: some_data"
+
+
+def test_start_command(mocker):
+    mocker.patch("argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(api_key="test_api_key"))
+    mocker.patch("embedchain.bots.poe.run")
+
+    start_command()

+ 34 - 55
tests/embedchain/test_add.py

@@ -1,70 +1,49 @@
 import os
-import unittest
-from unittest.mock import MagicMock, patch
-
+import pytest
 from embedchain import App
 from embedchain.config import AddConfig, AppConfig, ChunkerConfig
 from embedchain.models.data_type import DataType
 
+os.environ["OPENAI_API_KEY"] = "test_key"
+
+
+@pytest.fixture
+def app(mocker):
+    mocker.patch("chromadb.api.models.Collection.Collection.add")
+    return App(config=AppConfig(collect_metrics=False))
+
 
-class TestApp(unittest.TestCase):
-    os.environ["OPENAI_API_KEY"] = "test_key"
+def test_add(app):
+    app.add("https://example.com", metadata={"meta": "meta-data"})
+    assert app.user_asks == [["https://example.com", "web_page", {"meta": "meta-data"}]]
 
-    def setUp(self):
-        self.app = App(config=AppConfig(collect_metrics=False))
 
-    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
-    def test_add(self):
-        """
-        This test checks the functionality of the 'add' method in the App class.
-        It begins by simulating the addition of a web page with a specific URL to the application instance.
-        The 'add' method is expected to append the input type and URL to the 'user_asks' attribute of the App instance.
-        By asserting that 'user_asks' is updated correctly after the 'add' method is called, we can confirm that the
-        method is working as intended.
-        The Collection.add method from the chromadb library is mocked during this test to isolate the behavior of the
-        'add' method.
-        """
-        self.app.add("https://example.com", metadata={"meta": "meta-data"})
-        self.assertEqual(self.app.user_asks, [["https://example.com", "web_page", {"meta": "meta-data"}]])
+def test_add_sitemap(app):
+    app.add("https://www.google.com/sitemap.xml", metadata={"meta": "meta-data"})
+    assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"meta": "meta-data"}]]
 
-    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
-    def test_add_sitemap(self):
-        """
-        In addition to the test_add function, this test checks that sitemaps can be added with the correct data type.
-        """
-        self.app.add("https://www.google.com/sitemap.xml", metadata={"meta": "meta-data"})
-        self.assertEqual(self.app.user_asks, [["https://www.google.com/sitemap.xml", "sitemap", {"meta": "meta-data"}]])
 
-    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
-    def test_add_forced_type(self):
-        """
-        Test that you can also force a data_type with `add`.
-        """
-        data_type = "text"
-        self.app.add("https://example.com", data_type=data_type, metadata={"meta": "meta-data"})
-        self.assertEqual(self.app.user_asks, [["https://example.com", data_type, {"meta": "meta-data"}]])
+def test_add_forced_type(app):
+    data_type = "text"
+    app.add("https://example.com", data_type=data_type, metadata={"meta": "meta-data"})
+    assert app.user_asks == [["https://example.com", data_type, {"meta": "meta-data"}]]
 
-    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
-    def test_dry_run(self):
-        """
-        Test that if dry_run == True then data chunks are returned.
-        """
 
-        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0)
-        # We can't test with lorem ipsum because chunks are deduped, so would be recurring characters.
-        text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""
+def test_dry_run(app):
+    chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0)
+    text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""
 
-        result = self.app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)
+    result = app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)
 
-        chunks = result["chunks"]
-        metadata = result["metadata"]
-        count = result["count"]
-        data_type = result["type"]
+    chunks = result["chunks"]
+    metadata = result["metadata"]
+    count = result["count"]
+    data_type = result["type"]
 
-        self.assertEqual(len(chunks), len(text))
-        self.assertEqual(count, len(text))
-        self.assertEqual(data_type, DataType.TEXT)
-        for item in metadata:
-            self.assertIsInstance(item, dict)
-            self.assertIn(item["url"], "local")
-            self.assertIn(item["data_type"], "text")
+    assert len(chunks) == len(text)
+    assert count == len(text)
+    assert data_type == DataType.TEXT
+    for item in metadata:
+        assert isinstance(item, dict)
+        assert "local" in item["url"]
+        assert "text" in item["data_type"]

+ 52 - 0
tests/llm/test_base_llm.py

@@ -0,0 +1,52 @@
+import pytest
+from embedchain.llm.base import BaseLlm, BaseLlmConfig
+
+
+@pytest.fixture
+def base_llm():
+    config = BaseLlmConfig()
+    return BaseLlm(config=config)
+
+
+def test_is_get_llm_model_answer_not_implemented(base_llm):
+    with pytest.raises(NotImplementedError):
+        base_llm.get_llm_model_answer()
+
+
+def test_is_get_llm_model_answer_implemented():
+    class TestLlm(BaseLlm):
+        def get_llm_model_answer(self):
+            return "Implemented"
+
+    config = BaseLlmConfig()
+    llm = TestLlm(config=config)
+    assert llm.get_llm_model_answer() == "Implemented"
+
+
+def test_stream_query_response(base_llm):
+    answer = ["Chunk1", "Chunk2", "Chunk3"]
+    result = list(base_llm._stream_query_response(answer))
+    assert result == answer
+
+
+def test_stream_chat_response(base_llm):
+    answer = ["Chunk1", "Chunk2", "Chunk3"]
+    result = list(base_llm._stream_chat_response(answer))
+    assert result == answer
+
+
+def test_append_search_and_context(base_llm):
+    context = "Context"
+    web_search_result = "Web Search Result"
+    result = base_llm._append_search_and_context(context, web_search_result)
+    expected_result = "Context\nWeb Search Result: Web Search Result"
+    assert result == expected_result
+
+
+def test_access_search_and_get_results(base_llm, mocker):
+    base_llm.access_search_and_get_results = mocker.patch.object(
+        base_llm, "access_search_and_get_results", return_value="Search Results"
+    )
+    input_query = "Test query"
+    result = base_llm.access_search_and_get_results(input_query)
+    assert result == "Search Results"