浏览代码

Improve tests (#800)

Sidharth Mohanty 1 年之前
父节点
当前提交
5ec12212e4
共有 6 个文件被更改,包括 280 次插入123 次删除
  1. 6 8
      embedchain/apps/person_app.py
  2. 58 60
      tests/apps/test_apps.py
  3. 80 0
      tests/apps/test_person_app.py
  4. 50 0
      tests/bots/test_poe.py
  5. 34 55
      tests/embedchain/test_add.py
  6. 52 0
      tests/llm/test_base_llm.py

+ 6 - 8
embedchain/apps/person_app.py

@@ -2,10 +2,8 @@ from string import Template
 
 
 from embedchain.apps.app import App
 from embedchain.apps.app import App
 from embedchain.apps.open_source_app import OpenSourceApp
 from embedchain.apps.open_source_app import OpenSourceApp
-from embedchain.config import BaseLlmConfig
-from embedchain.config.apps.base_app_config import BaseAppConfig
-from embedchain.config.llm.base import (DEFAULT_PROMPT,
-                                        DEFAULT_PROMPT_WITH_HISTORY)
+from embedchain.config import BaseLlmConfig, AppConfig
+from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.helper.json_serializable import register_deserializable
 
 
 
 
@@ -16,16 +14,16 @@ class EmbedChainPersonApp:
     This bot behaves and speaks like a person.
     This bot behaves and speaks like a person.
 
 
     :param person: name of the person, better if its a well known person.
     :param person: name of the person, better if its a well known person.
-    :param config: BaseAppConfig instance to load as configuration.
+    :param config: AppConfig instance to load as configuration.
     """
     """
 
 
-    def __init__(self, person: str, config: BaseAppConfig = None):
+    def __init__(self, person: str, config: AppConfig = None):
         """Initialize a new person app
         """Initialize a new person app
 
 
         :param person: Name of the person that's imitated.
         :param person: Name of the person that's imitated.
         :type person: str
         :type person: str
         :param config: Configuration class instance, defaults to None
         :param config: Configuration class instance, defaults to None
-        :type config: BaseAppConfig, optional
+        :type config: AppConfig, optional
         """
         """
         self.person = person
         self.person = person
         self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style."  # noqa:E501
         self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style."  # noqa:E501
@@ -70,7 +68,7 @@ class PersonApp(EmbedChainPersonApp, App):
     """
     """
 
 
     def query(self, input_query, config: BaseLlmConfig = None, dry_run=False):
     def query(self, input_query, config: BaseLlmConfig = None, dry_run=False):
-        config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None)
+        config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
         return super().query(input_query, config, dry_run, where=None)
         return super().query(input_query, config, dry_run, where=None)
 
 
     def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
     def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):

+ 58 - 60
tests/apps/test_apps.py

@@ -1,108 +1,106 @@
 import os
 import os
-import unittest
-
+import pytest
 import yaml
 import yaml
 
 
 from embedchain import App, CustomApp, Llama2App, OpenSourceApp
 from embedchain import App, CustomApp, Llama2App, OpenSourceApp
-from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
-                               BaseLlmConfig, ChromaDbConfig)
+from embedchain.config import AddConfig, AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChromaDbConfig
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.llm.base import BaseLlm
 from embedchain.llm.base import BaseLlm
 from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
 from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
 from embedchain.vectordb.chroma import ChromaDB
 from embedchain.vectordb.chroma import ChromaDB
 
 
 
 
-class TestApps(unittest.TestCase):
+@pytest.fixture
+def app():
+    os.environ["OPENAI_API_KEY"] = "test_api_key"
+    return App()
+
+
+@pytest.fixture
+def custom_app():
+    os.environ["OPENAI_API_KEY"] = "test_api_key"
+    return CustomApp()
+
+
+@pytest.fixture
+def opensource_app():
+    os.environ["OPENAI_API_KEY"] = "test_api_key"
+    return OpenSourceApp()
+
+
+@pytest.fixture
+def llama2_app():
     os.environ["OPENAI_API_KEY"] = "test_api_key"
     os.environ["OPENAI_API_KEY"] = "test_api_key"
+    os.environ["REPLICATE_API_TOKEN"] = "-"
+    return Llama2App()
 
 
-    def test_app(self):
-        app = App()
-        self.assertIsInstance(app.llm, BaseLlm)
-        self.assertIsInstance(app.db, BaseVectorDB)
-        self.assertIsInstance(app.embedder, BaseEmbedder)
 
 
-        wrong_llm = "wrong_llm"
-        with self.assertRaises(TypeError):
-            App(llm=wrong_llm)
+def test_app(app):
+    assert isinstance(app.llm, BaseLlm)
+    assert isinstance(app.db, BaseVectorDB)
+    assert isinstance(app.embedder, BaseEmbedder)
 
 
-        wrong_db = "wrong_db"
-        with self.assertRaises(TypeError):
-            App(db=wrong_db)
 
 
-        wrong_embedder = "wrong_embedder"
-        with self.assertRaises(TypeError):
-            App(embedder=wrong_embedder)
+def test_custom_app(custom_app):
+    assert isinstance(custom_app.llm, BaseLlm)
+    assert isinstance(custom_app.db, BaseVectorDB)
+    assert isinstance(custom_app.embedder, BaseEmbedder)
 
 
-    def test_custom_app(self):
-        app = CustomApp()
-        self.assertIsInstance(app.llm, BaseLlm)
-        self.assertIsInstance(app.db, BaseVectorDB)
-        self.assertIsInstance(app.embedder, BaseEmbedder)
 
 
-    def test_opensource_app(self):
-        app = OpenSourceApp()
-        self.assertIsInstance(app.llm, BaseLlm)
-        self.assertIsInstance(app.db, BaseVectorDB)
-        self.assertIsInstance(app.embedder, BaseEmbedder)
+def test_opensource_app(opensource_app):
+    assert isinstance(opensource_app.llm, BaseLlm)
+    assert isinstance(opensource_app.db, BaseVectorDB)
+    assert isinstance(opensource_app.embedder, BaseEmbedder)
 
 
-    def test_llama2_app(self):
-        os.environ["REPLICATE_API_TOKEN"] = "-"
-        app = Llama2App()
-        self.assertIsInstance(app.llm, BaseLlm)
-        self.assertIsInstance(app.db, BaseVectorDB)
-        self.assertIsInstance(app.embedder, BaseEmbedder)
 
 
+def test_llama2_app(llama2_app):
+    assert isinstance(llama2_app.llm, BaseLlm)
+    assert isinstance(llama2_app.db, BaseVectorDB)
+    assert isinstance(llama2_app.embedder, BaseEmbedder)
 
 
-class TestConfigForAppComponents(unittest.TestCase):
-    collection_name = "my-test-collection"
 
 
+class TestConfigForAppComponents:
     def test_constructor_config(self):
     def test_constructor_config(self):
-        """
-        Test that app can be configured through the app constructor.
-        """
-        app = App(db_config=ChromaDbConfig(collection_name=self.collection_name))
-        self.assertEqual(app.db.config.collection_name, self.collection_name)
+        collection_name = "my-test-collection"
+        app = App(db_config=ChromaDbConfig(collection_name=collection_name))
+        assert app.db.config.collection_name == collection_name
 
 
     def test_component_config(self):
     def test_component_config(self):
-        """
-        Test that app can also be configured by passing a configured component to the app
-        """
-        database = ChromaDB(config=ChromaDbConfig(collection_name=self.collection_name))
+        collection_name = "my-test-collection"
+        database = ChromaDB(config=ChromaDbConfig(collection_name=collection_name))
         app = App(db=database)
         app = App(db=database)
-        self.assertEqual(app.db.config.collection_name, self.collection_name)
+        assert app.db.config.collection_name == collection_name
 
 
     def test_different_configs_are_proper_instances(self):
     def test_different_configs_are_proper_instances(self):
-        config = AppConfig()
-        wrong_app_config = AddConfig()
-
-        with self.assertRaises(TypeError):
-            App(config=wrong_app_config)
+        app_config = AppConfig()
+        wrong_config = AddConfig()
+        with pytest.raises(TypeError):
+            App(config=wrong_config)
 
 
-        self.assertIsInstance(config, AppConfig)
+        assert isinstance(app_config, AppConfig)
 
 
         llm_config = BaseLlmConfig()
         llm_config = BaseLlmConfig()
         wrong_llm_config = "wrong_llm_config"
         wrong_llm_config = "wrong_llm_config"
 
 
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             App(llm_config=wrong_llm_config)
             App(llm_config=wrong_llm_config)
 
 
-        self.assertIsInstance(llm_config, BaseLlmConfig)
+        assert isinstance(llm_config, BaseLlmConfig)
 
 
         db_config = BaseVectorDbConfig()
         db_config = BaseVectorDbConfig()
         wrong_db_config = "wrong_db_config"
         wrong_db_config = "wrong_db_config"
 
 
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             App(db_config=wrong_db_config)
             App(db_config=wrong_db_config)
 
 
-        self.assertIsInstance(db_config, BaseVectorDbConfig)
+        assert isinstance(db_config, BaseVectorDbConfig)
 
 
         embedder_config = BaseEmbedderConfig()
         embedder_config = BaseEmbedderConfig()
         wrong_embedder_config = "wrong_embedder_config"
         wrong_embedder_config = "wrong_embedder_config"
-
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             App(embedder_config=wrong_embedder_config)
             App(embedder_config=wrong_embedder_config)
 
 
-        self.assertIsInstance(embedder_config, BaseEmbedderConfig)
+        assert isinstance(embedder_config, BaseEmbedderConfig)
 
 
 
 
 class TestAppFromConfig:
 class TestAppFromConfig:

+ 80 - 0
tests/apps/test_person_app.py

@@ -0,0 +1,80 @@
+import pytest
+from embedchain.apps.app import App
+from embedchain.apps.person_app import PersonApp, PersonOpenSourceApp
+from embedchain.config import BaseLlmConfig, AppConfig
+from embedchain.config.llm.base import DEFAULT_PROMPT
+
+
+@pytest.fixture
+def person_app():
+    config = AppConfig()
+    return PersonApp("John Doe", config)
+
+
+@pytest.fixture
+def opensource_person_app():
+    config = AppConfig()
+    return PersonOpenSourceApp("John Doe", config)
+
+
+def test_person_app_initialization(person_app):
+    assert person_app.person == "John Doe"
+    assert f"You are {person_app.person}" in person_app.person_prompt
+    assert isinstance(person_app.config, AppConfig)
+
+
+def test_person_app_add_person_template_to_config_with_invalid_template():
+    app = PersonApp("John Doe")
+    default_prompt = "Input Prompt"
+    with pytest.raises(ValueError):
+        # as prompt doesn't contain $context and $query
+        app.add_person_template_to_config(default_prompt)
+
+
+def test_person_app_add_person_template_to_config_with_valid_template():
+    app = PersonApp("John Doe")
+    config = app.add_person_template_to_config(DEFAULT_PROMPT)
+    assert (
+        config.template.template
+        == f"You are John Doe. Whatever you say, you will always say in John Doe style. {DEFAULT_PROMPT}"
+    )
+
+
+def test_person_app_query(mocker, person_app):
+    input_query = "Hello, how are you?"
+    config = BaseLlmConfig()
+
+    mocker.patch.object(App, "query", return_value="Mocked response")
+
+    result = person_app.query(input_query, config)
+    assert result == "Mocked response"
+
+
+def test_person_app_chat(mocker, person_app):
+    input_query = "Hello, how are you?"
+    config = BaseLlmConfig()
+
+    mocker.patch.object(App, "chat", return_value="Mocked chat response")
+
+    result = person_app.chat(input_query, config)
+    assert result == "Mocked chat response"
+
+
+def test_opensource_person_app_query(mocker, opensource_person_app):
+    input_query = "Hello, how are you?"
+    config = BaseLlmConfig()
+
+    mocker.patch.object(App, "query", return_value="Mocked response")
+
+    result = opensource_person_app.query(input_query, config)
+    assert result == "Mocked response"
+
+
+def test_opensource_person_app_chat(mocker, opensource_person_app):
+    input_query = "Hello, how are you?"
+    config = BaseLlmConfig()
+
+    mocker.patch.object(App, "chat", return_value="Mocked chat response")
+
+    result = opensource_person_app.chat(input_query, config)
+    assert result == "Mocked chat response"

+ 50 - 0
tests/bots/test_poe.py

@@ -0,0 +1,50 @@
+import argparse
+import pytest
+
+from embedchain.bots.poe import PoeBot, start_command
+from fastapi_poe.types import QueryRequest, ProtocolMessage
+
+
+@pytest.fixture
+def poe_bot(mocker):
+    bot = PoeBot()
+    mocker.patch("fastapi_poe.run")
+    return bot
+
+
+@pytest.mark.asyncio
+async def test_poe_bot_get_response(poe_bot, mocker):
+    query = QueryRequest(
+        version="test",
+        type="query",
+        query=[ProtocolMessage(role="system", content="Test content")],
+        user_id="test_user_id",
+        conversation_id="test_conversation_id",
+        message_id="test_message_id",
+    )
+
+    mocker.patch.object(poe_bot.app.llm, "set_history")
+
+    response_generator = poe_bot.get_response(query)
+
+    await response_generator.__anext__()
+    poe_bot.app.llm.set_history.assert_called_once()
+
+
+def test_poe_bot_handle_message(poe_bot, mocker):
+    mocker.patch.object(poe_bot, "ask_bot", return_value="Answer from the bot")
+
+    response_ask = poe_bot.handle_message("What is the answer?")
+    assert response_ask == "Answer from the bot"
+
+    # TODO: This test will fail because the add_data method is commented out.
+    # mocker.patch.object(poe_bot, 'add_data', return_value="Added data from: some_data")
+    # response_add = poe_bot.handle_message("/add some_data")
+    # assert response_add == "Added data from: some_data"
+
+
+def test_start_command(mocker):
+    mocker.patch("argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(api_key="test_api_key"))
+    mocker.patch("embedchain.bots.poe.run")
+
+    start_command()

+ 34 - 55
tests/embedchain/test_add.py

@@ -1,70 +1,49 @@
 import os
 import os
-import unittest
-from unittest.mock import MagicMock, patch
-
+import pytest
 from embedchain import App
 from embedchain import App
 from embedchain.config import AddConfig, AppConfig, ChunkerConfig
 from embedchain.config import AddConfig, AppConfig, ChunkerConfig
 from embedchain.models.data_type import DataType
 from embedchain.models.data_type import DataType
 
 
+os.environ["OPENAI_API_KEY"] = "test_key"
+
+
+@pytest.fixture
+def app(mocker):
+    mocker.patch("chromadb.api.models.Collection.Collection.add")
+    return App(config=AppConfig(collect_metrics=False))
+
 
 
-class TestApp(unittest.TestCase):
-    os.environ["OPENAI_API_KEY"] = "test_key"
+def test_add(app):
+    app.add("https://example.com", metadata={"meta": "meta-data"})
+    assert app.user_asks == [["https://example.com", "web_page", {"meta": "meta-data"}]]
 
 
-    def setUp(self):
-        self.app = App(config=AppConfig(collect_metrics=False))
 
 
-    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
-    def test_add(self):
-        """
-        This test checks the functionality of the 'add' method in the App class.
-        It begins by simulating the addition of a web page with a specific URL to the application instance.
-        The 'add' method is expected to append the input type and URL to the 'user_asks' attribute of the App instance.
-        By asserting that 'user_asks' is updated correctly after the 'add' method is called, we can confirm that the
-        method is working as intended.
-        The Collection.add method from the chromadb library is mocked during this test to isolate the behavior of the
-        'add' method.
-        """
-        self.app.add("https://example.com", metadata={"meta": "meta-data"})
-        self.assertEqual(self.app.user_asks, [["https://example.com", "web_page", {"meta": "meta-data"}]])
+def test_add_sitemap(app):
+    app.add("https://www.google.com/sitemap.xml", metadata={"meta": "meta-data"})
+    assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"meta": "meta-data"}]]
 
 
-    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
-    def test_add_sitemap(self):
-        """
-        In addition to the test_add function, this test checks that sitemaps can be added with the correct data type.
-        """
-        self.app.add("https://www.google.com/sitemap.xml", metadata={"meta": "meta-data"})
-        self.assertEqual(self.app.user_asks, [["https://www.google.com/sitemap.xml", "sitemap", {"meta": "meta-data"}]])
 
 
-    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
-    def test_add_forced_type(self):
-        """
-        Test that you can also force a data_type with `add`.
-        """
-        data_type = "text"
-        self.app.add("https://example.com", data_type=data_type, metadata={"meta": "meta-data"})
-        self.assertEqual(self.app.user_asks, [["https://example.com", data_type, {"meta": "meta-data"}]])
+def test_add_forced_type(app):
+    data_type = "text"
+    app.add("https://example.com", data_type=data_type, metadata={"meta": "meta-data"})
+    assert app.user_asks == [["https://example.com", data_type, {"meta": "meta-data"}]]
 
 
-    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
-    def test_dry_run(self):
-        """
-        Test that if dry_run == True then data chunks are returned.
-        """
 
 
-        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0)
-        # We can't test with lorem ipsum because chunks are deduped, so would be recurring characters.
-        text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""
+def test_dry_run(app):
+    chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0)
+    text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""
 
 
-        result = self.app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)
+    result = app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)
 
 
-        chunks = result["chunks"]
-        metadata = result["metadata"]
-        count = result["count"]
-        data_type = result["type"]
+    chunks = result["chunks"]
+    metadata = result["metadata"]
+    count = result["count"]
+    data_type = result["type"]
 
 
-        self.assertEqual(len(chunks), len(text))
-        self.assertEqual(count, len(text))
-        self.assertEqual(data_type, DataType.TEXT)
-        for item in metadata:
-            self.assertIsInstance(item, dict)
-            self.assertIn(item["url"], "local")
-            self.assertIn(item["data_type"], "text")
+    assert len(chunks) == len(text)
+    assert count == len(text)
+    assert data_type == DataType.TEXT
+    for item in metadata:
+        assert isinstance(item, dict)
+        assert "local" in item["url"]
+        assert "text" in item["data_type"]

+ 52 - 0
tests/llm/test_base_llm.py

@@ -0,0 +1,52 @@
+import pytest
+from embedchain.llm.base import BaseLlm, BaseLlmConfig
+
+
+@pytest.fixture
+def base_llm():
+    config = BaseLlmConfig()
+    return BaseLlm(config=config)
+
+
+def test_is_get_llm_model_answer_not_implemented(base_llm):
+    with pytest.raises(NotImplementedError):
+        base_llm.get_llm_model_answer()
+
+
+def test_is_get_llm_model_answer_implemented():
+    class TestLlm(BaseLlm):
+        def get_llm_model_answer(self):
+            return "Implemented"
+
+    config = BaseLlmConfig()
+    llm = TestLlm(config=config)
+    assert llm.get_llm_model_answer() == "Implemented"
+
+
+def test_stream_query_response(base_llm):
+    answer = ["Chunk1", "Chunk2", "Chunk3"]
+    result = list(base_llm._stream_query_response(answer))
+    assert result == answer
+
+
+def test_stream_chat_response(base_llm):
+    answer = ["Chunk1", "Chunk2", "Chunk3"]
+    result = list(base_llm._stream_chat_response(answer))
+    assert result == answer
+
+
+def test_append_search_and_context(base_llm):
+    context = "Context"
+    web_search_result = "Web Search Result"
+    result = base_llm._append_search_and_context(context, web_search_result)
+    expected_result = "Context\nWeb Search Result: Web Search Result"
+    assert result == expected_result
+
+
+def test_access_search_and_get_results(base_llm, mocker):
+    base_llm.access_search_and_get_results = mocker.patch.object(
+        base_llm, "access_search_and_get_results", return_value="Search Results"
+    )
+    input_query = "Test query"
+    result = base_llm.access_search_and_get_results(input_query)
+    assert result == "Search Results"