Jelajahi Sumber

[Bug fix] Fix embedding issue for opensearch and some other vector databases (#1163)

Deshraj Yadav 1 tahun lalu
induk
melakukan
862ff6cca6

+ 1 - 1
embedchain/chunkers/base_chunker.py

@@ -27,7 +27,7 @@ class BaseChunker(JSONSerializable):
         chunk_ids = []
         id_map = {}
         min_chunk_size = config.min_chunk_size if config is not None else 1
-        logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
+        logging.info(f"Skipping chunks smaller than {min_chunk_size} characters")
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]

+ 2 - 8
embedchain/embedchain.py

@@ -369,7 +369,7 @@ class EmbedChain(JSONSerializable):
         metadatas = embeddings_data["metadatas"]
         ids = embeddings_data["ids"]
         new_doc_id = embeddings_data["doc_id"]
-        embeddings = embeddings_data.get("embeddings")
+
         if existing_doc_id and existing_doc_id == new_doc_id:
             print("Doc content has not changed. Skipping creating chunks and embeddings")
             return [], [], [], 0
@@ -433,13 +433,7 @@ class EmbedChain(JSONSerializable):
         # Count before, to calculate a delta in the end.
         chunks_before_addition = self.db.count()
 
-        self.db.add(
-            embeddings=embeddings,
-            documents=documents,
-            metadatas=metadatas,
-            ids=ids,
-            **kwargs,
-        )
+        self.db.add(documents=documents, metadatas=metadatas, ids=ids, **kwargs)
         count_new_chunks = self.db.count() - chunks_before_addition
 
         print(f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}")

+ 0 - 4
embedchain/vectordb/chroma.py

@@ -129,17 +129,13 @@ class ChromaDB(BaseVectorDB):
 
     def add(
         self,
-        embeddings: list[list[float]],
         documents: list[str],
         metadatas: list[object],
         ids: list[str],
-        **kwargs: Optional[dict[str, Any]],
     ) -> Any:
         """
         Add vectors to chroma database
 
-        :param embeddings: list of embeddings to add
-        :type embeddings: list[list[str]]
         :param documents: Documents
         :type documents: list[str]
         :param metadatas: Metadatas

+ 0 - 3
embedchain/vectordb/elasticsearch.py

@@ -110,7 +110,6 @@ class ElasticsearchDB(BaseVectorDB):
 
     def add(
         self,
-        embeddings: list[list[float]],
         documents: list[str],
         metadatas: list[object],
         ids: list[str],
@@ -118,8 +117,6 @@ class ElasticsearchDB(BaseVectorDB):
     ) -> Any:
         """
         add data in vector database
-        :param embeddings: list of embeddings to add
-        :type embeddings: list[list[str]]
         :param documents: list of texts to add
         :type documents: list[str]
         :param metadatas: list of metadata associated with docs

+ 4 - 16
embedchain/vectordb/opensearch.py

@@ -114,22 +114,10 @@ class OpenSearchDB(BaseVectorDB):
             result["metadatas"].append({"doc_id": doc_id})
         return result
 
-    def add(
-        self,
-        embeddings: list[list[str]],
-        documents: list[str],
-        metadatas: list[object],
-        ids: list[str],
-        **kwargs: Optional[dict[str, any]],
-    ):
-        """Add data in vector database.
-
-        Args:
-            embeddings (list[list[str]]): list of embeddings to add.
-            documents (list[str]): list of texts to add.
-            metadatas (list[object]): list of metadata associated with docs.
-            ids (list[str]): IDs of docs.
-        """
+    def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]):
+        """Adds documents to the opensearch index"""
+
+        embeddings = self.embedder.embedding_fn(documents)
         for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"):
             batch_end = batch_start + self.BATCH_SIZE
             batch_documents = documents[batch_start:batch_end]

+ 0 - 1
embedchain/vectordb/pinecone.py

@@ -88,7 +88,6 @@ class PineconeDB(BaseVectorDB):
 
     def add(
         self,
-        embeddings: list[list[float]],
         documents: list[str],
         metadatas: list[object],
         ids: list[str],

+ 0 - 3
embedchain/vectordb/qdrant.py

@@ -122,15 +122,12 @@ class QdrantDB(BaseVectorDB):
 
     def add(
         self,
-        embeddings: list[list[float]],
         documents: list[str],
         metadatas: list[object],
         ids: list[str],
         **kwargs: Optional[dict[str, any]],
     ):
         """add data in vector database
-        :param embeddings: list of embeddings for the corresponding documents to be added
-        :type documents: list[list[float]]
         :param documents: list of texts to add
         :type documents: list[str]
         :param metadatas: list of metadata associated with docs

+ 3 - 17
embedchain/vectordb/weaviate.py

@@ -1,6 +1,6 @@
 import copy
 import os
-from typing import Any, Optional, Union
+from typing import Optional, Union
 
 try:
     import weaviate
@@ -151,17 +151,8 @@ class WeaviateDB(BaseVectorDB):
 
         return {"ids": existing_ids}
 
-    def add(
-        self,
-        embeddings: list[list[float]],
-        documents: list[str],
-        metadatas: list[object],
-        ids: list[str],
-        **kwargs: Optional[dict[str, any]],
-    ):
+    def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]):
         """add data in vector database
-        :param embeddings: list of embeddings for the corresponding documents to be added
-        :type documents: list[list[float]]
         :param documents: list of texts to add
         :type documents: list[str]
         :param metadatas: list of metadata associated with docs
@@ -191,12 +182,7 @@ class WeaviateDB(BaseVectorDB):
                 )
 
     def query(
-        self,
-        input_query: list[str],
-        n_results: int,
-        where: dict[str, any],
-        citations: bool = False,
-        **kwargs: Optional[dict[str, Any]],
+        self, input_query: list[str], n_results: int, where: dict[str, any], citations: bool = False
     ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         query contents from vector database based on vector similarity

+ 0 - 1
embedchain/vectordb/zilliz.py

@@ -108,7 +108,6 @@ class ZillizVectorDB(BaseVectorDB):
 
     def add(
         self,
-        embeddings: list[list[float]],
         documents: list[str],
         metadatas: list[object],
         ids: list[str],

+ 2 - 3
tests/vectordb/test_elasticsearch_db.py

@@ -28,14 +28,13 @@ class TestEsDB(unittest.TestCase):
         # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
         self.assertEqual(self.db.client, mock_client.return_value)
 
-        # Create some dummy data.
-        embeddings = [[1, 2, 3], [4, 5, 6]]
+        # Create some dummy data
         documents = ["This is a document.", "This is another document."]
         metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
         ids = ["doc_1", "doc_2"]
 
         # Add the data to the database.
-        self.db.add(embeddings, documents, metadatas, ids)
+        self.db.add(documents, metadatas, ids)
 
         search_response = {
             "hits": {

+ 3 - 3
tests/vectordb/test_pinecone.py

@@ -43,8 +43,8 @@ class TestPinecone:
         embedding_function = mock.Mock()
         base_embedder = BaseEmbedder()
         base_embedder.set_embedding_fn(embedding_function)
-        vectors = [[0, 0, 0], [1, 1, 1]]
-        embedding_function.return_value = vectors
+        embedding_function.return_value = [[0, 0, 0], [1, 1, 1]]
+
         # Create a PineconeDb instance
         db = PineconeDB()
         app_config = AppConfig(collect_metrics=False)
@@ -54,7 +54,7 @@ class TestPinecone:
         documents = ["This is a document.", "This is another document."]
         metadatas = [{}, {}]
         ids = ["doc1", "doc2"]
-        db.add(vectors, documents, metadatas, ids)
+        db.add(documents, metadatas, ids)
 
         expected_pinecone_upsert_args = [
             {"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},

+ 2 - 3
tests/vectordb/test_qdrant.py

@@ -75,11 +75,10 @@ class TestQdrantDB(unittest.TestCase):
         app_config = AppConfig(collect_metrics=False)
         App(config=app_config, db=db, embedding_model=embedder)
 
-        embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
         documents = ["This is a test document.", "This is another test document."]
         metadatas = [{}, {}]
         ids = ["123", "456"]
-        db.add(embeddings, documents, metadatas, ids)
+        db.add(documents, metadatas, ids)
         qdrant_client_mock.return_value.upsert.assert_called_once_with(
             collection_name="embedchain-store-1526",
             points=Batch(
@@ -96,7 +95,7 @@ class TestQdrantDB(unittest.TestCase):
                         "metadata": {"text": "This is another test document."},
                     },
                 ],
-                vectors=embeddings,
+                vectors=[[1, 2, 3], [4, 5, 6]],
             ),
         )
 

+ 24 - 33
tests/vectordb/test_weaviate.py

@@ -29,7 +29,7 @@ class TestWeaviateDb(unittest.TestCase):
         weaviate_client_schema_mock.exists.return_value = False
         # Set the embedder
         embedder = BaseEmbedder()
-        embedder.set_vector_dimension(1526)
+        embedder.set_vector_dimension(1536)
         embedder.set_embedding_fn(mock_embedding_fn)
 
         # Create a Weaviate instance
@@ -40,7 +40,7 @@ class TestWeaviateDb(unittest.TestCase):
         expected_class_obj = {
             "classes": [
                 {
-                    "class": "Embedchain_store_1526",
+                    "class": "Embedchain_store_1536",
                     "vectorizer": "none",
                     "properties": [
                         {
@@ -53,12 +53,12 @@ class TestWeaviateDb(unittest.TestCase):
                         },
                         {
                             "name": "metadata",
-                            "dataType": ["Embedchain_store_1526_metadata"],
+                            "dataType": ["Embedchain_store_1536_metadata"],
                         },
                     ],
                 },
                 {
-                    "class": "Embedchain_store_1526_metadata",
+                    "class": "Embedchain_store_1536_metadata",
                     "vectorizer": "none",
                     "properties": [
                         {
@@ -88,7 +88,7 @@ class TestWeaviateDb(unittest.TestCase):
 
         # Assert that the Weaviate client was initialized
         weaviate_mock.Client.assert_called_once()
-        self.assertEqual(db.index_name, "Embedchain_store_1526")
+        self.assertEqual(db.index_name, "Embedchain_store_1536")
         weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj)
 
     @patch("embedchain.vectordb.weaviate.weaviate")
@@ -97,7 +97,7 @@ class TestWeaviateDb(unittest.TestCase):
         weaviate_client_mock = weaviate_mock.Client.return_value
 
         embedder = BaseEmbedder()
-        embedder.set_vector_dimension(1526)
+        embedder.set_vector_dimension(1536)
         embedder.set_embedding_fn(mock_embedding_fn)
 
         # Create a Weaviate instance
@@ -117,7 +117,7 @@ class TestWeaviateDb(unittest.TestCase):
 
         # Set the embedder
         embedder = BaseEmbedder()
-        embedder.set_vector_dimension(1526)
+        embedder.set_vector_dimension(1536)
         embedder.set_embedding_fn(mock_embedding_fn)
 
         # Create a Weaviate instance
@@ -126,30 +126,21 @@ class TestWeaviateDb(unittest.TestCase):
         App(config=app_config, db=db, embedding_model=embedder)
         db.BATCH_SIZE = 1
 
-        embeddings = [[1, 2, 3], [4, 5, 6]]
-        documents = ["This is a test document.", "This is another test document."]
-        metadatas = [None, None]
-        ids = ["123", "456"]
-        db.add(embeddings, documents, metadatas, ids)
+        documents = ["This is test document"]
+        metadatas = [None]
+        ids = ["id_1"]
+        db.add(documents, metadatas, ids)
 
         # Check if the document was added to the database.
         weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
         weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
-            data_object={"text": documents[0]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[0]
-        )
-        weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
-            data_object={"text": documents[1]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[1]
+            data_object={"text": documents[0]}, class_name="Embedchain_store_1536_metadata", vector=[1, 2, 3]
         )
 
         weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
-            data_object={"identifier": ids[0], "text": documents[0]},
-            class_name="Embedchain_store_1526",
-            vector=embeddings[0],
-        )
-        weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
-            data_object={"identifier": ids[1], "text": documents[1]},
-            class_name="Embedchain_store_1526",
-            vector=embeddings[1],
+            data_object={"text": documents[0]},
+            class_name="Embedchain_store_1536_metadata",
+            vector=[1, 2, 3],
         )
 
     @patch("embedchain.vectordb.weaviate.weaviate")
@@ -161,7 +152,7 @@ class TestWeaviateDb(unittest.TestCase):
 
         # Set the embedder
         embedder = BaseEmbedder()
-        embedder.set_vector_dimension(1526)
+        embedder.set_vector_dimension(1536)
         embedder.set_embedding_fn(mock_embedding_fn)
 
         # Create a Weaviate instance
@@ -172,7 +163,7 @@ class TestWeaviateDb(unittest.TestCase):
         # Query for the document.
         db.query(input_query=["This is a test document."], n_results=1, where={})
 
-        weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
+        weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
         weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
 
     @patch("embedchain.vectordb.weaviate.weaviate")
@@ -185,7 +176,7 @@ class TestWeaviateDb(unittest.TestCase):
 
         # Set the embedder
         embedder = BaseEmbedder()
-        embedder.set_vector_dimension(1526)
+        embedder.set_vector_dimension(1536)
         embedder.set_embedding_fn(mock_embedding_fn)
 
         # Create a Weaviate instance
@@ -196,9 +187,9 @@ class TestWeaviateDb(unittest.TestCase):
         # Query for the document.
         db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
 
-        weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
+        weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
         weaviate_client_query_get_mock.with_where.assert_called_once_with(
-            {"operator": "Equal", "path": ["metadata", "Embedchain_store_1526_metadata", "doc_id"], "valueText": "123"}
+            {"operator": "Equal", "path": ["metadata", "Embedchain_store_1536_metadata", "doc_id"], "valueText": "123"}
         )
         weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
 
@@ -210,7 +201,7 @@ class TestWeaviateDb(unittest.TestCase):
 
         # Set the embedder
         embedder = BaseEmbedder()
-        embedder.set_vector_dimension(1526)
+        embedder.set_vector_dimension(1536)
         embedder.set_embedding_fn(mock_embedding_fn)
 
         # Create a Weaviate instance
@@ -222,7 +213,7 @@ class TestWeaviateDb(unittest.TestCase):
         db.reset()
 
         weaviate_client_batch_mock.delete_objects.assert_called_once_with(
-            "Embedchain_store_1526", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
+            "Embedchain_store_1536", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
         )
 
     @patch("embedchain.vectordb.weaviate.weaviate")
@@ -233,7 +224,7 @@ class TestWeaviateDb(unittest.TestCase):
 
         # Set the embedder
         embedder = BaseEmbedder()
-        embedder.set_vector_dimension(1526)
+        embedder.set_vector_dimension(1536)
         embedder.set_embedding_fn(mock_embedding_fn)
 
         # Create a Weaviate instance
@@ -244,4 +235,4 @@ class TestWeaviateDb(unittest.TestCase):
         # Reset the database.
         db.count()
 
-        weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1526")
+        weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1536")