Ver código fonte

Fix batch_size for vectordb (#1449)

Dev Khant 1 ano atrás
pai
commit
50c0285cb2

+ 1 - 1
Makefile

@@ -11,7 +11,7 @@ install:
 
 install_all:
 	poetry install --all-extras
-	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7 langchain-huggingface
+	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7 langchain-huggingface psutil
 
 install_es:
 	poetry install --extras elasticsearch

+ 0 - 4
embedchain/config/vectordb/base.py

@@ -10,7 +10,6 @@ class BaseVectorDbConfig(BaseConfig):
         dir: str = "db",
         host: Optional[str] = None,
         port: Optional[str] = None,
-        batch_size: Optional[int] = 100,
         **kwargs,
     ):
         """
@@ -24,8 +23,6 @@ class BaseVectorDbConfig(BaseConfig):
         :type host: Optional[str], optional
         :param host: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
         :type port: Optional[str], optional
-        :param batch_size: Number of items to insert in one batch, defaults to 100
-        :type batch_size: Optional[int], optional
         :param kwargs: Additional keyword arguments
         :type kwargs: dict
         """
@@ -33,7 +30,6 @@ class BaseVectorDbConfig(BaseConfig):
         self.dir = dir
         self.host = host
         self.port = port
-        self.batch_size = batch_size
         # Assign additional keyword arguments
         if kwargs:
             for key, value in kwargs.items():

+ 4 - 0
embedchain/config/vectordb/chroma.py

@@ -12,6 +12,7 @@ class ChromaDbConfig(BaseVectorDbConfig):
         dir: Optional[str] = None,
         host: Optional[str] = None,
         port: Optional[str] = None,
+        batch_size: Optional[int] = 100,
         allow_reset=False,
         chroma_settings: Optional[dict] = None,
     ):
@@ -26,6 +27,8 @@ class ChromaDbConfig(BaseVectorDbConfig):
         :type host: Optional[str], optional
         :param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
         :type port: Optional[str], optional
+        :param batch_size: Number of items to insert in one batch, defaults to 100
+        :type batch_size: Optional[int], optional
         :param allow_reset: Resets the database. defaults to False
         :type allow_reset: bool
         :param chroma_settings: Chroma settings dict, defaults to None
@@ -34,4 +37,5 @@ class ChromaDbConfig(BaseVectorDbConfig):
 
         self.chroma_settings = chroma_settings
         self.allow_reset = allow_reset
+        self.batch_size = batch_size
         super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)

+ 7 - 0
embedchain/config/vectordb/elasticsearch.py

@@ -13,6 +13,7 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
         dir: Optional[str] = None,
         es_url: Union[str, list[str]] = None,
         cloud_id: Optional[str] = None,
+        batch_size: Optional[int] = 100,
         **ES_EXTRA_PARAMS: dict[str, any],
     ):
         """
@@ -24,6 +25,10 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
         :type dir: Optional[str], optional
         :param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
         :type es_url: Union[str, list[str]], optional
+        :param cloud_id: cloud id of the elasticsearch cluster, defaults to None
+        :type cloud_id: Optional[str], optional
+        :param batch_size: Number of items to insert in one batch, defaults to 100
+        :type batch_size: Optional[int], optional
         :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
         :type ES_EXTRA_PARAMS: dict[str, Any], optional
         """
@@ -46,4 +51,6 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
             and not self.ES_EXTRA_PARAMS.get("bearer_auth")
         ):
             self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
+
+        self.batch_size = batch_size
         super().__init__(collection_name=collection_name, dir=dir)

+ 4 - 0
embedchain/config/vectordb/opensearch.py

@@ -13,6 +13,7 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
         vector_dimension: int = 1536,
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
+        batch_size: Optional[int] = 100,
         **extra_params: dict[str, any],
     ):
         """
@@ -28,10 +29,13 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
         :type vector_dimension: int, optional
         :param dir: Path to the database directory, where the database is stored, defaults to None
         :type dir: Optional[str], optional
+        :param batch_size: Number of items to insert in one batch, defaults to 100
+        :type batch_size: Optional[int], optional
         """
         self.opensearch_url = opensearch_url
         self.http_auth = http_auth
         self.vector_dimension = vector_dimension
         self.extra_params = extra_params
+        self.batch_size = batch_size
 
         super().__init__(collection_name=collection_name, dir=dir)

+ 2 - 0
embedchain/config/vectordb/pinecone.py

@@ -17,6 +17,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
         serverless_config: Optional[dict[str, any]] = None,
         hybrid_search: bool = False,
         bm25_encoder: any = None,
+        batch_size: Optional[int] = 100,
         **extra_params: dict[str, any],
     ):
         self.metric = metric
@@ -26,6 +27,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
         self.extra_params = extra_params
         self.hybrid_search = hybrid_search
         self.bm25_encoder = bm25_encoder
+        self.batch_size = batch_size
         if pod_config is None and serverless_config is None:
             # If no config is provided, use the default pod spec config
             pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")

+ 4 - 0
embedchain/config/vectordb/qdrant.py

@@ -18,6 +18,7 @@ class QdrantDBConfig(BaseVectorDbConfig):
         hnsw_config: Optional[dict[str, any]] = None,
         quantization_config: Optional[dict[str, any]] = None,
         on_disk: Optional[bool] = None,
+        batch_size: Optional[int] = 10,
         **extra_params: dict[str, any],
     ):
         """
@@ -36,9 +37,12 @@ class QdrantDBConfig(BaseVectorDbConfig):
                 This setting saves RAM by (slightly) increasing the response time.
                 Note: those payload values that are involved in filtering and are indexed - remain in RAM.
         :type on_disk: bool, optional, defaults to None
+        :param batch_size: Number of items to insert in one batch, defaults to 10
+        :type batch_size: Optional[int], optional
         """
         self.hnsw_config = hnsw_config
         self.quantization_config = quantization_config
         self.on_disk = on_disk
+        self.batch_size = batch_size
         self.extra_params = extra_params
         super().__init__(collection_name=collection_name, dir=dir)

+ 2 - 0
embedchain/config/vectordb/weaviate.py

@@ -10,7 +10,9 @@ class WeaviateDBConfig(BaseVectorDbConfig):
         self,
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
+        batch_size: Optional[int] = 100,
         **extra_params: dict[str, any],
     ):
+        self.batch_size = batch_size
         self.extra_params = extra_params
         super().__init__(collection_name=collection_name, dir=dir)

+ 6 - 4
embedchain/vectordb/chroma.py

@@ -42,6 +42,7 @@ class ChromaDB(BaseVectorDB):
 
         self.settings = Settings(anonymized_telemetry=False)
         self.settings.allow_reset = self.config.allow_reset if hasattr(self.config, "allow_reset") else False
+        self.batch_size = self.config.batch_size
         if self.config.chroma_settings:
             for key, value in self.config.chroma_settings.items():
                 if hasattr(self.settings, key):
@@ -153,12 +154,13 @@ class ChromaDB(BaseVectorDB):
                 " Ids size: {}".format(len(documents), len(metadatas), len(ids))
             )
 
-        for i in tqdm(range(0, len(documents), self.config.batch_size), desc="Inserting batches in chromadb"):
+        for i in tqdm(range(0, len(documents), self.batch_size), desc="Inserting batches in chromadb"):
             self.collection.add(
-                documents=documents[i : i + self.config.batch_size],
-                metadatas=metadatas[i : i + self.config.batch_size],
-                ids=ids[i : i + self.config.batch_size],
+                documents=documents[i : i + self.batch_size],
+                metadatas=metadatas[i : i + self.batch_size],
+                ids=ids[i : i + self.batch_size],
             )
+        self.config
 
     @staticmethod
     def _format_result(results: QueryResult) -> list[tuple[Document, float]]:

+ 2 - 1
embedchain/vectordb/elasticsearch.py

@@ -55,6 +55,7 @@ class ElasticsearchDB(BaseVectorDB):
                 "Something is wrong with your config. Please check again - `https://docs.embedchain.ai/components/vector-databases#elasticsearch`"  # noqa: E501
             )
 
+        self.batch_size = self.config.batch_size
         # Call parent init here because embedder is needed
         super().__init__(config=self.config)
 
@@ -139,7 +140,7 @@ class ElasticsearchDB(BaseVectorDB):
 
         for chunk in chunks(
             list(zip(ids, documents, metadatas, embeddings)),
-            self.config.batch_size,
+            self.batch_size,
             desc="Inserting batches in elasticsearch",
         ):  # noqa: E501
             ids, docs, metadatas, embeddings = [], [], [], []

+ 3 - 4
embedchain/vectordb/opensearch.py

@@ -37,6 +37,7 @@ class OpenSearchDB(BaseVectorDB):
         if config is None:
             raise ValueError("OpenSearchDBConfig is required")
         self.config = config
+        self.batch_size = self.config.batch_size
         self.client = OpenSearch(
             hosts=[self.config.opensearch_url],
             http_auth=self.config.http_auth,
@@ -118,10 +119,8 @@ class OpenSearchDB(BaseVectorDB):
         """Adds documents to the opensearch index"""
 
         embeddings = self.embedder.embedding_fn(documents)
-        for batch_start in tqdm(
-            range(0, len(documents), self.config.batch_size), desc="Inserting batches in opensearch"
-        ):
-            batch_end = batch_start + self.config.batch_size
+        for batch_start in tqdm(range(0, len(documents), self.batch_size), desc="Inserting batches in opensearch"):
+            batch_end = batch_start + self.batch_size
             batch_documents = documents[batch_start:batch_end]
             batch_embeddings = embeddings[batch_start:batch_end]
 

+ 4 - 3
embedchain/vectordb/pinecone.py

@@ -48,6 +48,7 @@ class PineconeDB(BaseVectorDB):
 
         # Setup BM25Encoder if sparse vectors are to be used
         self.bm25_encoder = None
+        self.batch_size = self.config.batch_size
         if self.config.hybrid_search:
             logger.info("Initializing BM25Encoder for sparse vectors..")
             self.bm25_encoder = self.config.bm25_encoder if self.config.bm25_encoder else BM25Encoder.default()
@@ -102,8 +103,8 @@ class PineconeDB(BaseVectorDB):
         metadatas = []
 
         if ids is not None:
-            for i in range(0, len(ids), self.config.batch_size):
-                result = self.pinecone_index.fetch(ids=ids[i : i + self.config.batch_size])
+            for i in range(0, len(ids), self.batch_size):
+                result = self.pinecone_index.fetch(ids=ids[i : i + self.batch_size])
                 vectors = result.get("vectors")
                 batch_existing_ids = list(vectors.keys())
                 existing_ids.extend(batch_existing_ids)
@@ -142,7 +143,7 @@ class PineconeDB(BaseVectorDB):
                 },
             )
 
-        for chunk in chunks(docs, self.config.batch_size, desc="Adding chunks in batches"):
+        for chunk in chunks(docs, self.batch_size, desc="Adding chunks in batches"):
             self.pinecone_index.upsert(chunk, **kwargs)
 
     def query(

+ 6 - 5
embedchain/vectordb/qdrant.py

@@ -35,6 +35,7 @@ class QdrantDB(BaseVectorDB):
                     "Please make sure the type is right and that you are passing an instance."
                 )
         self.config = config
+        self.batch_size = self.config.batch_size
         self.client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
         # Call parent init here because embedder is needed
         super().__init__(config=self.config)
@@ -114,7 +115,7 @@ class QdrantDB(BaseVectorDB):
                 collection_name=self.collection_name,
                 scroll_filter=models.Filter(must=qdrant_must_filters),
                 offset=offset,
-                limit=self.config.batch_size,
+                limit=self.batch_size,
             )
             offset = response[1]
             for doc in response[0]:
@@ -146,13 +147,13 @@ class QdrantDB(BaseVectorDB):
             qdrant_ids.append(id)
             payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
 
-        for i in tqdm(range(0, len(qdrant_ids), self.config.batch_size), desc="Adding data in batches"):
+        for i in tqdm(range(0, len(qdrant_ids), self.batch_size), desc="Adding data in batches"):
             self.client.upsert(
                 collection_name=self.collection_name,
                 points=Batch(
-                    ids=qdrant_ids[i : i + self.config.batch_size],
-                    payloads=payloads[i : i + self.config.batch_size],
-                    vectors=embeddings[i : i + self.config.batch_size],
+                    ids=qdrant_ids[i : i + self.batch_size],
+                    payloads=payloads[i : i + self.batch_size],
+                    vectors=embeddings[i : i + self.batch_size],
                 ),
                 **kwargs,
             )

+ 3 - 2
embedchain/vectordb/weaviate.py

@@ -38,6 +38,7 @@ class WeaviateDB(BaseVectorDB):
                     "Please make sure the type is right and that you are passing an instance."
                 )
             self.config = config
+        self.batch_size = self.config.batch_size
         self.client = weaviate.Client(
             url=os.environ.get("WEAVIATE_ENDPOINT"),
             auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
@@ -167,7 +168,7 @@ class WeaviateDB(BaseVectorDB):
                 )
                 .with_where(weaviate_where_clause)
                 .with_additional(["id"])
-                .with_limit(limit or self.config.batch_size),
+                .with_limit(limit or self.batch_size),
                 offset,
             )
 
@@ -196,7 +197,7 @@ class WeaviateDB(BaseVectorDB):
         :type ids: list[str]
         """
         embeddings = self.embedder.embedding_fn(documents)
-        self.client.batch.configure(batch_size=self.config.batch_size, timeout_retries=3)  # Configure batch
+        self.client.batch.configure(batch_size=self.batch_size, timeout_retries=3)  # Configure batch
         with self.client.batch as batch:  # Initialize a batch process
             for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
                 doc = {"identifier": id, "text": text}

+ 1 - 2
tests/vectordb/test_weaviate.py

@@ -124,7 +124,6 @@ class TestWeaviateDb(unittest.TestCase):
         db = WeaviateDB()
         app_config = AppConfig(collect_metrics=False)
         App(config=app_config, db=db, embedding_model=embedder)
-        db.config.batch_size = 1
 
         documents = ["This is test document"]
         metadatas = [None]
@@ -132,7 +131,7 @@ class TestWeaviateDb(unittest.TestCase):
         db.add(documents, metadatas, ids)
 
         # Check if the document was added to the database.
-        weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
+        weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=100, timeout_retries=3)
         weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
             data_object={"text": documents[0]}, class_name="Embedchain_store_1536_metadata", vector=[1, 2, 3]
         )