Преглед изворног кода

[Feature] Batch uploading in chromadb (#814)

Rupesh Bansal пре 1 године
родитељ
комит
d8a7d71344
2 измењена фајла са 63 додавања и 9 уклањања
  1. 29 9
      embedchain/vectordb/chroma.py
  2. 34 0
      tests/vectordb/test_chroma_db.py

+ 29 - 9
embedchain/vectordb/chroma.py

@@ -25,6 +25,8 @@ except RuntimeError:
 class ChromaDB(BaseVectorDB):
     """Vector database using ChromaDB."""
 
+    BATCH_SIZE = 100
+
     def __init__(self, config: Optional[ChromaDbConfig] = None):
         """Initialize a new ChromaDB instance
 
@@ -123,10 +125,6 @@ class ChromaDB(BaseVectorDB):
             args["limit"] = limit
         return self.collection.get(**args)
 
-    def get_advanced(self, where):
-        where_clause = self._generate_where_clause(where)
-        return self.collection.get(where=where_clause, limit=1)
-
     def add(
         self,
         embeddings: List[List[float]],
@@ -149,10 +147,31 @@ class ChromaDB(BaseVectorDB):
         :param skip_embedding: Optional. If True, then the embeddings are assumed to be already generated.
         :type skip_embedding: bool
         """
-        if skip_embedding:
-            self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)
-        else:
-            self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
+        size = len(documents)
+        if skip_embedding and (embeddings is None or len(embeddings) != len(documents)):
+            raise ValueError("Cannot add documents to chromadb with inconsistent embeddings")
+
+        if len(documents) != size or len(metadatas) != size or len(ids) != size:
+            raise ValueError(
+                "Cannot add documents to chromadb with inconsistent sizes. Documents size: {}, Metadata size: {},"
+                " Ids size: {}".format(len(documents), len(metadatas), len(ids))
+            )
+
+        for i in range(0, len(documents), self.BATCH_SIZE):
+            print("Inserting batches from {} to {} in chromadb".format(i, min(len(documents), i + self.BATCH_SIZE)))
+            if skip_embedding:
+                self.collection.add(
+                    embeddings=embeddings[i : i + self.BATCH_SIZE],
+                    documents=documents[i : i + self.BATCH_SIZE],
+                    metadatas=metadatas[i : i + self.BATCH_SIZE],
+                    ids=ids[i : i + self.BATCH_SIZE],
+                )
+            else:
+                self.collection.add(
+                    documents=documents[i : i + self.BATCH_SIZE],
+                    metadatas=metadatas[i : i + self.BATCH_SIZE],
+                    ids=ids[i : i + self.BATCH_SIZE],
+                )
 
     def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
         """
@@ -208,7 +227,8 @@ class ChromaDB(BaseVectorDB):
         except InvalidDimensionException as e:
             raise InvalidDimensionException(
                 e.message()
-                + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database."  # noqa E501
+                + ". This is commonly a side-effect when an embedding function, different from the one used to add the"
+                " embeddings, is used to retrieve an embedding from the database."
             ) from None
         results_formatted = self._format_result(result)
         contents = [result[0].page_content for result in results_formatted]

+ 34 - 0
tests/vectordb/test_chroma_db.py

@@ -228,6 +228,40 @@ class TestChromaDbCollection(unittest.TestCase):
         expected_value = ["document"]
         self.assertEqual(data, expected_value)
 
+    def test_add_with_invalid_inputs(self):
+        """
+        Test add fails with invalid inputs
+        """
+        # Start with a clean app
+        self.app_with_settings.reset()
+        # app = App(config=AppConfig(collect_metrics=False), db=db)
+
+        # Collection should be empty when created
+        self.assertEqual(self.app_with_settings.db.count(), 0)
+
+        with self.assertRaises(ValueError):
+            self.app_with_settings.db.add(
+                embeddings=[[0, 0, 0]],
+                documents=["document", "document2"],
+                metadatas=[{"value": "somevalue"}],
+                ids=["id"],
+                skip_embedding=True,
+            )
+        # After adding, should contain no item
+        self.assertEqual(self.app_with_settings.db.count(), 0)
+
+        with self.assertRaises(ValueError):
+            self.app_with_settings.db.add(
+                embeddings=None,
+                documents=["document", "document2"],
+                metadatas=[{"value": "somevalue"}],
+                ids=["id"],
+                skip_embedding=True,
+            )
+
+        # After adding, should contain no item
+        self.assertEqual(self.app_with_settings.db.count(), 0)
+
     def test_collections_are_persistent(self):
         """
         Test that a collection can be picked up later.