123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238 |
- import unittest
- from unittest.mock import patch
- from embedchain import App
- from embedchain.config import AppConfig
- from embedchain.config.vectordb.pinecone import PineconeDBConfig
- from embedchain.embedder.base import BaseEmbedder
- from embedchain.vectordb.weaviate import WeaviateDB
def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
    """Stand-in embedding function that ignores its input.

    Always yields the same two 3-dimensional vectors, so tests can rely on the
    first embedding being ``[1, 2, 3]``. Fresh lists are built on every call,
    so mutating a returned value never affects later calls.

    Args:
        texts: The texts to "embed" (unused).

    Returns:
        ``[[1, 2, 3], [4, 5, 6]]``.
    """
    fixed_vectors = ([1, 2, 3], [4, 5, 6])
    return [list(vector) for vector in fixed_vectors]
class TestWeaviateDb(unittest.TestCase):
    """Unit tests for ``WeaviateDB`` that run against a mocked Weaviate client.

    Tests that need a client patch ``embedchain.vectordb.weaviate.weaviate``
    (the ``weaviate`` module as seen by ``WeaviateDB``), so ``weaviate.Client``
    returns a ``MagicMock`` and no server is contacted. Assertions are made
    against the calls recorded on that mock.
    """

    def _create_db(self):
        """Create a ``WeaviateDB`` attached to an ``App`` with a 1536-dim mock embedder.

        Call this only from inside a test decorated with
        ``@patch("embedchain.vectordb.weaviate.weaviate")`` so the db talks to
        the mocked client. The tests in this class expect the resulting index
        name to be ``Embedchain_store_1536``, which reflects the 1536 vector
        dimension set here.

        Returns:
            The ``WeaviateDB`` instance after it has been passed to ``App``.
        """
        embedder = BaseEmbedder()
        embedder.set_vector_dimension(1536)
        embedder.set_embedding_fn(mock_embedding_fn)
        db = WeaviateDB()
        # collect_metrics is disabled, presumably so tests never emit telemetry.
        app_config = AppConfig(collect_metrics=False)
        # The App object itself is discarded. Constructing it is what gives
        # ``db`` its embedder (``WeaviateDB()`` above is created without one).
        App(config=app_config, db=db, embedding_model=embedder)
        return db

    def test_incorrect_config_throws_error(self):
        """Passing a config meant for another vector DB must raise ``TypeError``."""
        with self.assertRaises(TypeError):
            WeaviateDB(config=PineconeDBConfig())

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_initialize(self, weaviate_mock):
        """Test that initialisation creates the client and, when absent, the schema."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_schema_mock = weaviate_client_mock.schema
        # Mock that the schema doesn't already exist so that a new schema is
        # created. This has to be configured before the db is built.
        weaviate_client_schema_mock.exists.return_value = False

        db = self._create_db()

        # Two classes are expected: the main store (identifier, text and a
        # reference to its metadata object) and the companion metadata class.
        expected_class_obj = {
            "classes": [
                {
                    "class": "Embedchain_store_1536",
                    "vectorizer": "none",
                    "properties": [
                        {
                            "name": "identifier",
                            "dataType": ["text"],
                        },
                        {
                            "name": "text",
                            "dataType": ["text"],
                        },
                        {
                            "name": "metadata",
                            "dataType": ["Embedchain_store_1536_metadata"],
                        },
                    ],
                },
                {
                    "class": "Embedchain_store_1536_metadata",
                    "vectorizer": "none",
                    "properties": [
                        {
                            "name": "data_type",
                            "dataType": ["text"],
                        },
                        {
                            "name": "doc_id",
                            "dataType": ["text"],
                        },
                        {
                            "name": "url",
                            "dataType": ["text"],
                        },
                        {
                            "name": "hash",
                            "dataType": ["text"],
                        },
                        {
                            "name": "app_id",
                            "dataType": ["text"],
                        },
                    ],
                },
            ]
        }

        # Assert that the Weaviate client was initialized.
        weaviate_mock.Client.assert_called_once()
        self.assertEqual(db.index_name, "Embedchain_store_1536")
        weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj)

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_get_or_create_db(self, weaviate_mock):
        """Test that ``_get_or_create_db`` returns the (mocked) Weaviate client."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        db = self._create_db()

        client = db._get_or_create_db()

        self.assertEqual(client, weaviate_client_mock)

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_add(self, weaviate_mock):
        """Test that ``add`` batches both the document object and its metadata object."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_batch_mock = weaviate_client_mock.batch
        # The batch is used as a context manager, so objects are added on
        # whatever ``__enter__`` returns.
        weaviate_client_batch_enter_mock = weaviate_client_batch_mock.__enter__.return_value
        db = self._create_db()
        db.BATCH_SIZE = 1

        documents = ["This is test document"]
        metadatas = [None]
        ids = ["id_1"]
        db.add(documents, metadatas, ids)

        # Check if the document was added to the database.
        weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
        # The document itself goes into the main class, keyed by its identifier
        # (the main class schema holds ``identifier`` and ``text``).
        weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
            data_object={"identifier": ids[0], "text": documents[0]},
            class_name="Embedchain_store_1536",
            vector=[1, 2, 3],
        )
        # Its metadata (only the text, since metadatas[0] is None) goes into the
        # companion ``_metadata`` class with the same vector.
        weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
            data_object={"text": documents[0]},
            class_name="Embedchain_store_1536_metadata",
            vector=[1, 2, 3],
        )

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_query_without_where(self, weaviate_mock):
        """Test that an empty ``where`` queries by vector without any filter."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_query_mock = weaviate_client_mock.query
        weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
        db = self._create_db()

        # Query for the document.
        db.query(input_query="This is a test document.", n_results=1, where={})

        weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
        # No filter should be applied when ``where`` is empty.
        weaviate_client_query_get_mock.with_where.assert_not_called()
        weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_query_with_where(self, weaviate_mock):
        """Test that a metadata ``where`` clause is applied as a Weaviate filter."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_query_mock = weaviate_client_mock.query
        weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
        weaviate_client_query_get_where_mock = weaviate_client_query_get_mock.with_where.return_value
        db = self._create_db()

        # Query for the document.
        db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})

        weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
        # The filter reaches the metadata class through the ``metadata``
        # reference property of the main class.
        weaviate_client_query_get_mock.with_where.assert_called_once_with(
            {"operator": "Equal", "path": ["metadata", "Embedchain_store_1536_metadata", "doc_id"], "valueText": "123"}
        )
        weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_reset(self, weaviate_mock):
        """Test that ``reset`` batch-deletes every object in the main class."""
        weaviate_client_batch_mock = weaviate_mock.Client.return_value.batch
        db = self._create_db()

        # Reset the database.
        db.reset()

        # A ``Like`` match on ``.*`` against ``identifier`` selects all objects.
        weaviate_client_batch_mock.delete_objects.assert_called_once_with(
            "Embedchain_store_1536", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
        )

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_count(self, weaviate_mock):
        """Test that ``count`` runs an aggregate query on the main class."""
        weaviate_client_query = weaviate_mock.Client.return_value.query
        db = self._create_db()

        # Count the stored documents.
        db.count()

        weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1536")
|