Quellcode durchsuchen

Fix skipped tests (#1385)

Dev Khant vor 1 Jahr
Ursprung
Commit
827d63d115

+ 17 - 13
tests/llm/test_vertex_ai.py

@@ -2,12 +2,17 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from langchain.schema import HumanMessage, SystemMessage
-from langchain_google_vertexai import ChatVertexAI
 
 from embedchain.config import BaseLlmConfig
+from embedchain.core.db.database import database_manager
 from embedchain.llm.vertex_ai import VertexAILlm
 
 
+@pytest.fixture(autouse=True)
+def setup_database():
+    database_manager.setup_engine()
+
+
 @pytest.fixture
 def vertexai_llm():
     config = BaseLlmConfig(temperature=0.6, model="chat-bison")
@@ -22,19 +27,18 @@ def test_get_llm_model_answer(vertexai_llm):
         mock_method.assert_called_once_with(prompt=prompt, config=vertexai_llm.config)
 
 
-@pytest.mark.skip(
-    reason="Requires mocking of Google Console Auth. Revisit later since don't want to block users right now."
-)
-def test_get_answer(vertexai_llm, caplog):
-    with patch.object(ChatVertexAI, "invoke", return_value=MagicMock(content="Test Response")) as mock_method:
-        config = vertexai_llm.config
-        prompt = "Test Prompt"
-        messages = vertexai_llm._get_messages(prompt)
-        response = vertexai_llm._get_answer(prompt, config)
-        mock_method.assert_called_once_with(messages)
+@patch("embedchain.llm.vertex_ai.ChatVertexAI")
+def test_get_answer(mock_chat_vertexai, vertexai_llm, caplog):
+    mock_chat_vertexai.return_value.invoke.return_value = MagicMock(content="Test Response")
+
+    config = vertexai_llm.config
+    prompt = "Test Prompt"
+    messages = vertexai_llm._get_messages(prompt)
+    response = vertexai_llm._get_answer(prompt, config)
+    mock_chat_vertexai.return_value.invoke.assert_called_once_with(messages)
 
-        assert response == "Test Response"  # Assertion corrected
-        assert "Config option `top_p` is not supported by this model." not in caplog.text
+    assert response == "Test Response"  # Assertion corrected
+    assert "Config option `top_p` is not supported by this model." not in caplog.text
 
 
 def test_get_messages(vertexai_llm):

+ 1 - 3
tests/telemetry/test_posthog.py

@@ -1,8 +1,6 @@
 import logging
 import os
 
-import pytest
-
 from embedchain.telemetry.posthog import AnonymousTelemetry
 
 
@@ -54,7 +52,6 @@ class TestAnonymousTelemetry:
             properties,
         )
 
-    @pytest.mark.skip(reason="Logging setup needs to be fixed to make this test to work")
     def test_capture_with_exception(self, mocker, caplog):
         os.environ["EC_TELEMETRY"] = "true"
         mock_posthog = mocker.patch("embedchain.telemetry.posthog.Posthog")
@@ -65,3 +62,4 @@ class TestAnonymousTelemetry:
         with caplog.at_level(logging.ERROR):
             telemetry.capture(event_name, properties)
         assert "Failed to send telemetry event" in caplog.text
+        caplog.clear()

+ 19 - 14
tests/vectordb/test_chroma_db.py

@@ -34,15 +34,16 @@ def cleanup_db():
         print("Error: %s - %s." % (e.filename, e.strerror))
 
 
-@pytest.mark.skip(reason="ChromaDB client needs to be mocked")
-def test_chroma_db_init_with_host_and_port(chroma_db):
-    settings = chroma_db.client.get_settings()
-    assert settings.chroma_server_host == "test-host"
-    assert settings.chroma_server_http_port == "1234"
+@patch("embedchain.vectordb.chroma.chromadb.Client")
+def test_chroma_db_init_with_host_and_port(mock_client):
+    chroma_db = ChromaDB(config=ChromaDbConfig(host="test-host", port="1234"))  # noqa
+    called_settings: Settings = mock_client.call_args[0][0]
+    assert called_settings.chroma_server_host == "test-host"
+    assert called_settings.chroma_server_http_port == "1234"
 
 
-@pytest.mark.skip(reason="ChromaDB client needs to be mocked")
-def test_chroma_db_init_with_basic_auth():
+@patch("embedchain.vectordb.chroma.chromadb.Client")
+def test_chroma_db_init_with_basic_auth(mock_client):
     chroma_config = {
         "host": "test-host",
         "port": "1234",
@@ -52,12 +53,17 @@ def test_chroma_db_init_with_basic_auth():
         },
     }
 
-    db = ChromaDB(config=ChromaDbConfig(**chroma_config))
-    settings = db.client.get_settings()
-    assert settings.chroma_server_host == "test-host"
-    assert settings.chroma_server_http_port == "1234"
-    assert settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"]
-    assert settings.chroma_client_auth_credentials == chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
+    ChromaDB(config=ChromaDbConfig(**chroma_config))
+    called_settings: Settings = mock_client.call_args[0][0]
+    assert called_settings.chroma_server_host == "test-host"
+    assert called_settings.chroma_server_http_port == "1234"
+    assert (
+        called_settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"]
+    )
+    assert (
+        called_settings.chroma_client_auth_credentials
+        == chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
+    )
 
 
 @patch("embedchain.vectordb.chroma.chromadb.Client")
@@ -84,7 +90,6 @@ def test_app_init_with_host_and_port_none(mock_client):
     assert called_settings.chroma_server_http_port is None
 
 
-@pytest.mark.skip(reason="Logging setup needs to be fixed to make this test to work")
 def test_chroma_db_duplicates_throw_warning(caplog):
     db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
     app = App(config=AppConfig(collect_metrics=False), db=db)

+ 1 - 3
tests/vectordb/test_qdrant.py

@@ -1,7 +1,6 @@
 import unittest
 import uuid
 
-import pytest
 from mock import patch
 from qdrant_client.http import models
 from qdrant_client.http.models import Batch
@@ -61,7 +60,6 @@ class TestQdrantDB(unittest.TestCase):
         resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
         self.assertEqual(resp2, {"ids": [], "metadatas": []})
 
-    @pytest.mark.skip(reason="Investigate the issue with the test case.")
     @patch("embedchain.vectordb.qdrant.QdrantClient")
     @patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
     def test_add(self, uuid_mock, qdrant_client_mock):
@@ -84,7 +82,7 @@ class TestQdrantDB(unittest.TestCase):
         qdrant_client_mock.return_value.upsert.assert_called_once_with(
             collection_name="embedchain-store-1536",
             points=Batch(
-                ids=["abc", "def"],
+                ids=["123", "456"],
                 payloads=[
                     {
                         "identifier": "123",