소스 검색

Add batch_size in config for VectorDB (#1448)

Dev Khant 1 년 전
부모
커밋
0a78198bb5

+ 1 - 0
docs/api-reference/advanced/configuration.mdx

@@ -217,6 +217,7 @@ Alright, let's dive into what each key means in the yaml config above:
         - `collection_name` (String): The initial collection name for the vectordb, set to 'full-stack-app'.
         - `dir` (String): The directory for the local database, set to 'db'.
         - `allow_reset` (Boolean): Indicates whether resetting the vectordb is allowed, set to true.
+        - `batch_size` (Integer): The batch size for docs insertion in vectordb, defaults to `100`
     <Note>We recommend you to checkout vectordb specific config [here](https://docs.embedchain.ai/components/vector-databases)</Note>
 4. `embedder` Section:
     - `provider` (String): The provider for the embedder, set to 'openai'. You can find the full list of embedding model providers in [our docs](/components/embedding-models).

+ 4 - 0
embedchain/config/vectordb/base.py

@@ -10,6 +10,7 @@ class BaseVectorDbConfig(BaseConfig):
         dir: str = "db",
         host: Optional[str] = None,
         port: Optional[str] = None,
+        batch_size: Optional[int] = 100,
         **kwargs,
     ):
         """
@@ -23,6 +24,8 @@ class BaseVectorDbConfig(BaseConfig):
         :type host: Optional[str], optional
         :param host: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
         :type port: Optional[str], optional
+        :param batch_size: Number of items to insert in one batch, defaults to 100
+        :type batch_size: Optional[int], optional
         :param kwargs: Additional keyword arguments
         :type kwargs: dict
         """
@@ -30,6 +33,7 @@ class BaseVectorDbConfig(BaseConfig):
         self.dir = dir
         self.host = host
         self.port = port
+        self.batch_size = batch_size
         # Assign additional keyword arguments
         if kwargs:
             for key, value in kwargs.items():

+ 4 - 6
embedchain/vectordb/chroma.py

@@ -29,8 +29,6 @@ logger = logging.getLogger(__name__)
 class ChromaDB(BaseVectorDB):
     """Vector database using ChromaDB."""
 
-    BATCH_SIZE = 100
-
     def __init__(self, config: Optional[ChromaDbConfig] = None):
         """Initialize a new ChromaDB instance
 
@@ -155,11 +153,11 @@ class ChromaDB(BaseVectorDB):
                 " Ids size: {}".format(len(documents), len(metadatas), len(ids))
             )
 
-        for i in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in chromadb"):
+        for i in tqdm(range(0, len(documents), self.config.batch_size), desc="Inserting batches in chromadb"):
             self.collection.add(
-                documents=documents[i : i + self.BATCH_SIZE],
-                metadatas=metadatas[i : i + self.BATCH_SIZE],
-                ids=ids[i : i + self.BATCH_SIZE],
+                documents=documents[i : i + self.config.batch_size],
+                metadatas=metadatas[i : i + self.config.batch_size],
+                ids=ids[i : i + self.config.batch_size],
             )
 
     @staticmethod

+ 3 - 3
embedchain/vectordb/elasticsearch.py

@@ -23,8 +23,6 @@ class ElasticsearchDB(BaseVectorDB):
     Elasticsearch as vector database
     """
 
-    BATCH_SIZE = 100
-
     def __init__(
         self,
         config: Optional[ElasticsearchDBConfig] = None,
@@ -140,7 +138,9 @@ class ElasticsearchDB(BaseVectorDB):
         embeddings = self.embedder.embedding_fn(documents)
 
         for chunk in chunks(
-            list(zip(ids, documents, metadatas, embeddings)), self.BATCH_SIZE, desc="Inserting batches in elasticsearch"
+            list(zip(ids, documents, metadatas, embeddings)),
+            self.config.batch_size,
+            desc="Inserting batches in elasticsearch",
         ):  # noqa: E501
             ids, docs, metadatas, embeddings = [], [], [], []
             for id, text, metadata, embedding in chunk:

+ 0 - 2
embedchain/vectordb/lancedb.py

@@ -18,8 +18,6 @@ class LanceDB(BaseVectorDB):
     LanceDB as vector database
     """
 
-    BATCH_SIZE = 100
-
     def __init__(
         self,
         config: Optional[LanceDBConfig] = None,

+ 4 - 4
embedchain/vectordb/opensearch.py

@@ -28,8 +28,6 @@ class OpenSearchDB(BaseVectorDB):
     OpenSearch as vector database
     """
 
-    BATCH_SIZE = 100
-
     def __init__(self, config: OpenSearchDBConfig):
         """OpenSearch as vector database.
 
@@ -120,8 +118,10 @@ class OpenSearchDB(BaseVectorDB):
         """Adds documents to the opensearch index"""
 
         embeddings = self.embedder.embedding_fn(documents)
-        for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"):
-            batch_end = batch_start + self.BATCH_SIZE
+        for batch_start in tqdm(
+            range(0, len(documents), self.config.batch_size), desc="Inserting batches in opensearch"
+        ):
+            batch_end = batch_start + self.config.batch_size
             batch_documents = documents[batch_start:batch_end]
             batch_embeddings = embeddings[batch_start:batch_end]
 

+ 3 - 6
embedchain/vectordb/pinecone.py

@@ -25,8 +25,6 @@ class PineconeDB(BaseVectorDB):
     Pinecone as vector database
     """
 
-    BATCH_SIZE = 100
-
     def __init__(
         self,
         config: Optional[PineconeDBConfig] = None,
@@ -103,10 +101,9 @@ class PineconeDB(BaseVectorDB):
         existing_ids = list()
         metadatas = []
 
-        batch_size = 100
         if ids is not None:
-            for i in range(0, len(ids), batch_size):
-                result = self.pinecone_index.fetch(ids=ids[i : i + batch_size])
+            for i in range(0, len(ids), self.config.batch_size):
+                result = self.pinecone_index.fetch(ids=ids[i : i + self.config.batch_size])
                 vectors = result.get("vectors")
                 batch_existing_ids = list(vectors.keys())
                 existing_ids.extend(batch_existing_ids)
@@ -145,7 +142,7 @@ class PineconeDB(BaseVectorDB):
                 },
             )
 
-        for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"):
+        for chunk in chunks(docs, self.config.batch_size, desc="Adding chunks in batches"):
             self.pinecone_index.upsert(chunk, **kwargs)
 
     def query(

+ 6 - 8
embedchain/vectordb/qdrant.py

@@ -21,8 +21,6 @@ class QdrantDB(BaseVectorDB):
     Qdrant as vector database
     """
 
-    BATCH_SIZE = 10
-
     def __init__(self, config: QdrantDBConfig = None):
         """
         Qdrant as vector database
@@ -116,7 +114,7 @@ class QdrantDB(BaseVectorDB):
                 collection_name=self.collection_name,
                 scroll_filter=models.Filter(must=qdrant_must_filters),
                 offset=offset,
-                limit=self.BATCH_SIZE,
+                limit=self.config.batch_size,
             )
             offset = response[1]
             for doc in response[0]:
@@ -148,13 +146,13 @@ class QdrantDB(BaseVectorDB):
             qdrant_ids.append(id)
             payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
 
-        for i in tqdm(range(0, len(qdrant_ids), self.BATCH_SIZE), desc="Adding data in batches"):
+        for i in tqdm(range(0, len(qdrant_ids), self.config.batch_size), desc="Adding data in batches"):
             self.client.upsert(
                 collection_name=self.collection_name,
                 points=Batch(
-                    ids=qdrant_ids[i : i + self.BATCH_SIZE],
-                    payloads=payloads[i : i + self.BATCH_SIZE],
-                    vectors=embeddings[i : i + self.BATCH_SIZE],
+                    ids=qdrant_ids[i : i + self.config.batch_size],
+                    payloads=payloads[i : i + self.config.batch_size],
+                    vectors=embeddings[i : i + self.config.batch_size],
                 ),
                 **kwargs,
             )
@@ -251,4 +249,4 @@ class QdrantDB(BaseVectorDB):
 
     def delete(self, where: dict):
         db_filter = self._generate_query(where)
-        self.client.delete(collection_name=self.collection_name, points_selector=db_filter)
+        self.client.delete(collection_name=self.collection_name, points_selector=db_filter)

+ 2 - 4
embedchain/vectordb/weaviate.py

@@ -20,8 +20,6 @@ class WeaviateDB(BaseVectorDB):
     Weaviate as vector database
     """
 
-    BATCH_SIZE = 100
-
     def __init__(
         self,
         config: Optional[WeaviateDBConfig] = None,
@@ -169,7 +167,7 @@ class WeaviateDB(BaseVectorDB):
                 )
                 .with_where(weaviate_where_clause)
                 .with_additional(["id"])
-                .with_limit(limit or self.BATCH_SIZE),
+                .with_limit(limit or self.config.batch_size),
                 offset,
             )
 
@@ -198,7 +196,7 @@ class WeaviateDB(BaseVectorDB):
         :type ids: list[str]
         """
         embeddings = self.embedder.embedding_fn(documents)
-        self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3)  # Configure batch
+        self.client.batch.configure(batch_size=self.config.batch_size, timeout_retries=3)  # Configure batch
         with self.client.batch as batch:  # Initialize a batch process
             for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
                 doc = {"identifier": id, "text": text}

+ 1 - 1
tests/vectordb/test_weaviate.py

@@ -124,7 +124,7 @@ class TestWeaviateDb(unittest.TestCase):
         db = WeaviateDB()
         app_config = AppConfig(collect_metrics=False)
         App(config=app_config, db=db, embedding_model=embedder)
-        db.BATCH_SIZE = 1
+        db.config.batch_size = 1
 
         documents = ["This is test document"]
         metadatas = [None]