فهرست منبع

Improve tests (#795)

Sidharth Mohanty 1 سال پیش
والد
کامیت
4820ea15d6

+ 4 - 5
embedchain/bots/base.py

@@ -1,10 +1,9 @@
 from typing import Any
 
-from embedchain import CustomApp
-from embedchain.config import AddConfig, CustomAppConfig, LlmConfig
+from embedchain import App
+from embedchain.config import AddConfig, AppConfig, LlmConfig
 from embedchain.embedder.openai import OpenAIEmbedder
-from embedchain.helper.json_serializable import (JSONSerializable,
-                                                 register_deserializable)
+from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
 from embedchain.llm.openai import OpenAILlm
 from embedchain.vectordb.chroma import ChromaDB
 
@@ -12,7 +11,7 @@ from embedchain.vectordb.chroma import ChromaDB
 @register_deserializable
 class BaseBot(JSONSerializable):
     def __init__(self):
-        self.app = CustomApp(config=CustomAppConfig(), llm=OpenAILlm(), db=ChromaDB(), embedder=OpenAIEmbedder())
+        self.app = App(config=AppConfig(), llm=OpenAILlm(), db=ChromaDB(), embedder=OpenAIEmbedder())
 
     def add(self, data: Any, config: AddConfig = None):
         """

+ 48 - 6
tests/apps/test_apps.py

@@ -2,18 +2,15 @@ import os
 import unittest
 
 from embedchain import App, CustomApp, Llama2App, OpenSourceApp
-from embedchain.config import ChromaDbConfig
+from embedchain.config import ChromaDbConfig, AppConfig, AddConfig, BaseLlmConfig, BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.llm.base import BaseLlm
-from embedchain.vectordb.base import BaseVectorDB
+from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
 from embedchain.vectordb.chroma import ChromaDB
 
 
 class TestApps(unittest.TestCase):
-    try:
-        del os.environ["OPENAI_KEY"]
-    except KeyError:
-        pass
+    os.environ["OPENAI_API_KEY"] = "test_api_key"
 
     def test_app(self):
         app = App()
@@ -21,6 +18,18 @@ class TestApps(unittest.TestCase):
         self.assertIsInstance(app.db, BaseVectorDB)
         self.assertIsInstance(app.embedder, BaseEmbedder)
 
+        wrong_llm = "wrong_llm"
+        with self.assertRaises(TypeError):
+            App(llm=wrong_llm)
+
+        wrong_db = "wrong_db"
+        with self.assertRaises(TypeError):
+            App(db=wrong_db)
+
+        wrong_embedder = "wrong_embedder"
+        with self.assertRaises(TypeError):
+            App(embedder=wrong_embedder)
+
     def test_custom_app(self):
         app = CustomApp()
         self.assertIsInstance(app.llm, BaseLlm)
@@ -58,3 +67,36 @@ class TestConfigForAppComponents(unittest.TestCase):
         database = ChromaDB(config=ChromaDbConfig(collection_name=self.collection_name))
         app = App(db=database)
         self.assertEqual(app.db.config.collection_name, self.collection_name)
+
+    def test_different_configs_are_proper_instances(self):
+        config = AppConfig()
+        wrong_app_config = AddConfig()
+
+        with self.assertRaises(TypeError):
+            App(config=wrong_app_config)
+
+        self.assertIsInstance(config, AppConfig)
+
+        llm_config = BaseLlmConfig()
+        wrong_llm_config = "wrong_llm_config"
+
+        with self.assertRaises(TypeError):
+            App(llm_config=wrong_llm_config)
+
+        self.assertIsInstance(llm_config, BaseLlmConfig)
+
+        db_config = BaseVectorDbConfig()
+        wrong_db_config = "wrong_db_config"
+
+        with self.assertRaises(TypeError):
+            App(db_config=wrong_db_config)
+
+        self.assertIsInstance(db_config, BaseVectorDbConfig)
+
+        embedder_config = BaseEmbedderConfig()
+        wrong_embedder_config = "wrong_embedder_config"
+
+        with self.assertRaises(TypeError):
+            App(embedder_config=wrong_embedder_config)
+
+        self.assertIsInstance(embedder_config, BaseEmbedderConfig)

+ 49 - 0
tests/bots/test_base.py

@@ -0,0 +1,49 @@
+import os
+import pytest
+from embedchain.config import AddConfig, BaseLlmConfig
+from embedchain.bots.base import BaseBot
+from unittest.mock import patch
+
+
+@pytest.fixture
+def base_bot():
+    os.environ["OPENAI_API_KEY"] = "test_api_key"  # needed by App
+    return BaseBot()
+
+
+def test_add(base_bot):
+    data = "Test data"
+    config = AddConfig()
+
+    with patch.object(base_bot.app, "add") as mock_add:
+        base_bot.add(data, config)
+        mock_add.assert_called_with(data, config=config)
+
+
+def test_query(base_bot):
+    query = "Test query"
+    config = BaseLlmConfig()
+
+    with patch.object(base_bot.app, "query") as mock_query:
+        mock_query.return_value = "Query result"
+
+        result = base_bot.query(query, config)
+
+    assert isinstance(result, str)
+    assert result == "Query result"
+
+
+def test_start():
+    class TestBot(BaseBot):
+        def start(self):
+            return "Bot started"
+
+    bot = TestBot()
+    result = bot.start()
+    assert result == "Bot started"
+
+
+def test_start_not_implemented():
+    bot = BaseBot()
+    with pytest.raises(NotImplementedError):
+        bot.start()

+ 54 - 8
tests/embedder/test_embedder.py

@@ -1,11 +1,57 @@
-import unittest
-
+import pytest
+from unittest.mock import MagicMock
 from embedchain.embedder.base import BaseEmbedder
+from embedchain.config.embedder.base import BaseEmbedderConfig
+from chromadb.api.types import Documents, Embeddings
+
+
+@pytest.fixture
+def base_embedder():
+    return BaseEmbedder()
+
+
+def test_initialization(base_embedder):
+    assert isinstance(base_embedder.config, BaseEmbedderConfig)
+    # not initialized
+    assert not hasattr(base_embedder, "embedding_fn")
+    assert not hasattr(base_embedder, "vector_dimension")
+
+
+def test_set_embedding_fn(base_embedder):
+    def embedding_function(texts: Documents) -> Embeddings:
+        return [f"Embedding for {text}" for text in texts]
+
+    base_embedder.set_embedding_fn(embedding_function)
+    assert hasattr(base_embedder, "embedding_fn")
+    assert callable(base_embedder.embedding_fn)
+    embeddings = base_embedder.embedding_fn(["text1", "text2"])
+    assert embeddings == ["Embedding for text1", "Embedding for text2"]
+
+
+def test_set_embedding_fn_when_not_a_function(base_embedder):
+    with pytest.raises(ValueError):
+        base_embedder.set_embedding_fn(None)
+
+
+def test_set_vector_dimension(base_embedder):
+    base_embedder.set_vector_dimension(256)
+    assert hasattr(base_embedder, "vector_dimension")
+    assert base_embedder.vector_dimension == 256
+
+
+def test_set_vector_dimension_type_error(base_embedder):
+    with pytest.raises(TypeError):
+        base_embedder.set_vector_dimension(None)
+
+
+def test_langchain_default_concept():
+    embeddings = MagicMock()
+    embeddings.embed_documents.return_value = ["Embedding1", "Embedding2"]
+    embed_function = BaseEmbedder._langchain_default_concept(embeddings)
+    result = embed_function(["text1", "text2"])
+    assert result == ["Embedding1", "Embedding2"]
 
 
-class TestEmbedder(unittest.TestCase):
-    def test_init_with_invalid_vector_dim(self):
-        # Test if an exception is raised when an invalid vector_dim is provided
-        embedder = BaseEmbedder()
-        with self.assertRaises(TypeError):
-            embedder.set_vector_dimension(None)
+def test_embedder_with_config():
+    embedder = BaseEmbedder(BaseEmbedderConfig())
+    assert isinstance(embedder.config, BaseEmbedderConfig)

+ 64 - 0
tests/llm/test_antrophic.py

@@ -0,0 +1,64 @@
+import pytest
+from unittest.mock import MagicMock, patch
+
+from embedchain.llm.antrophic import AntrophicLlm
+from embedchain.config import BaseLlmConfig
+from langchain.schema import HumanMessage, SystemMessage
+
+
+@pytest.fixture
+def antrophic_llm():
+    config = BaseLlmConfig(temperature=0.5, model="gpt2")
+    return AntrophicLlm(config)
+
+
+def test_get_llm_model_answer(antrophic_llm):
+    with patch.object(AntrophicLlm, "_get_answer", return_value="Test Response") as mock_method:
+        prompt = "Test Prompt"
+        response = antrophic_llm.get_llm_model_answer(prompt)
+        assert response == "Test Response"
+        mock_method.assert_called_once_with(prompt=prompt, config=antrophic_llm.config)
+
+
+def test_get_answer(antrophic_llm):
+    with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        response = antrophic_llm._get_answer(prompt, antrophic_llm.config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(
+            temperature=antrophic_llm.config.temperature, model=antrophic_llm.config.model
+        )
+        mock_chat_instance.assert_called_once_with(
+            antrophic_llm._get_messages(prompt, system_prompt=antrophic_llm.config.system_prompt)
+        )
+
+
+def test_get_messages(antrophic_llm):
+    prompt = "Test Prompt"
+    system_prompt = "Test System Prompt"
+    messages = antrophic_llm._get_messages(prompt, system_prompt)
+    assert messages == [
+        SystemMessage(content="Test System Prompt", additional_kwargs={}),
+        HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
+    ]
+
+
+def test_get_answer_max_tokens_is_provided(antrophic_llm, caplog):
+    with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        config = antrophic_llm.config
+        config.max_tokens = 500
+
+        response = antrophic_llm._get_answer(prompt, config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)
+
+        assert "Config option `max_tokens` is not supported by this model." in caplog.text

+ 91 - 0
tests/llm/test_azure_openai.py

@@ -0,0 +1,91 @@
+import pytest
+from unittest.mock import MagicMock, patch
+from embedchain.llm.azure_openai import AzureOpenAILlm
+from embedchain.config import BaseLlmConfig
+from langchain.schema import HumanMessage, SystemMessage
+
+
+@pytest.fixture
+def azure_openai_llm():
+    config = BaseLlmConfig(
+        deployment_name="azure_deployment",
+        temperature=0.7,
+        model="gpt-3.5-turbo",
+        max_tokens=50,
+        system_prompt="System Prompt",
+    )
+    return AzureOpenAILlm(config)
+
+
+def test_get_llm_model_answer(azure_openai_llm):
+    with patch.object(AzureOpenAILlm, "_get_answer", return_value="Test Response") as mock_method:
+        prompt = "Test Prompt"
+        response = azure_openai_llm.get_llm_model_answer(prompt)
+        assert response == "Test Response"
+        mock_method.assert_called_once_with(prompt=prompt, config=azure_openai_llm.config)
+
+
+def test_get_answer(azure_openai_llm):
+    with patch("langchain.chat_models.AzureChatOpenAI") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        response = azure_openai_llm._get_answer(prompt, azure_openai_llm.config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(
+            deployment_name=azure_openai_llm.config.deployment_name,
+            openai_api_version="2023-05-15",
+            model_name=azure_openai_llm.config.model or "gpt-3.5-turbo",
+            temperature=azure_openai_llm.config.temperature,
+            max_tokens=azure_openai_llm.config.max_tokens,
+            streaming=azure_openai_llm.config.stream,
+        )
+        mock_chat_instance.assert_called_once_with(
+            azure_openai_llm._get_messages(prompt, system_prompt=azure_openai_llm.config.system_prompt)
+        )
+
+
+def test_get_messages(azure_openai_llm):
+    prompt = "Test Prompt"
+    system_prompt = "Test System Prompt"
+    messages = azure_openai_llm._get_messages(prompt, system_prompt)
+    assert messages == [
+        SystemMessage(content="Test System Prompt", additional_kwargs={}),
+        HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
+    ]
+
+
+def test_get_answer_top_p_is_provided(azure_openai_llm, caplog):
+    with patch("langchain.chat_models.AzureChatOpenAI") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        config = azure_openai_llm.config
+        config.top_p = 0.5
+
+        response = azure_openai_llm._get_answer(prompt, config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(
+            deployment_name=config.deployment_name,
+            openai_api_version="2023-05-15",
+            model_name=config.model or "gpt-3.5-turbo",
+            temperature=config.temperature,
+            max_tokens=config.max_tokens,
+            streaming=config.stream,
+        )
+        mock_chat_instance.assert_called_once_with(
+            azure_openai_llm._get_messages(prompt, system_prompt=config.system_prompt)
+        )
+
+        assert "Config option `top_p` is not supported by this model." in caplog.text
+
+
+def test_when_no_deployment_name_provided():
+    config = BaseLlmConfig(temperature=0.7, model="gpt-3.5-turbo", max_tokens=50, system_prompt="System Prompt")
+    with pytest.raises(ValueError):
+        llm = AzureOpenAILlm(config)
+        llm.get_llm_model_answer("Test Prompt")

+ 63 - 0
tests/llm/test_vertex_ai.py

@@ -0,0 +1,63 @@
+import pytest
+from unittest.mock import MagicMock, patch
+from embedchain.llm.vertex_ai import VertexAiLlm
+from embedchain.config import BaseLlmConfig
+from langchain.schema import HumanMessage, SystemMessage
+
+
+@pytest.fixture
+def vertexai_llm():
+    config = BaseLlmConfig(temperature=0.6, model="vertexai_model", system_prompt="System Prompt")
+    return VertexAiLlm(config)
+
+
+def test_get_llm_model_answer(vertexai_llm):
+    with patch.object(VertexAiLlm, "_get_answer", return_value="Test Response") as mock_method:
+        prompt = "Test Prompt"
+        response = vertexai_llm.get_llm_model_answer(prompt)
+        assert response == "Test Response"
+        mock_method.assert_called_once_with(prompt=prompt, config=vertexai_llm.config)
+
+
+def test_get_answer_with_warning(vertexai_llm, caplog):
+    with patch("langchain.chat_models.ChatVertexAI") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        config = vertexai_llm.config
+        config.top_p = 0.5
+
+        response = vertexai_llm._get_answer(prompt, config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)
+
+        assert "Config option `top_p` is not supported by this model." in caplog.text
+
+
+def test_get_answer_no_warning(vertexai_llm, caplog):
+    with patch("langchain.chat_models.ChatVertexAI") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        config = vertexai_llm.config
+        config.top_p = 1.0
+
+        response = vertexai_llm._get_answer(prompt, config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)
+
+        assert "Config option `top_p` is not supported by this model." not in caplog.text
+
+
+def test_get_messages(vertexai_llm):
+    prompt = "Test Prompt"
+    system_prompt = "Test System Prompt"
+    messages = vertexai_llm._get_messages(prompt, system_prompt)
+    assert messages == [
+        SystemMessage(content="Test System Prompt", additional_kwargs={}),
+        HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
+    ]