浏览代码

[BugFix] Fix issue of chunks not getting embedded in opensearch index (#983)

Deshraj Yadav 1 年之前
父节点
当前提交
b02e8feeda
共有 3 个文件被更改,包括 34 次插入33 次删除
  1. 1 1
      embedchain/loaders/discord.py
  2. 32 31
      embedchain/vectordb/opensearch.py
  3. 1 1
      pyproject.toml

+ 1 - 1
embedchain/loaders/discord.py

@@ -1,6 +1,6 @@
+import hashlib
 import logging
 import os
-import hashlib
 
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader

+ 32 - 31
embedchain/vectordb/opensearch.py

@@ -122,41 +122,42 @@ class OpenSearchDB(BaseVectorDB):
         ids: List[str],
         skip_embedding: bool,
     ):
-        """add data in vector database
-
-        :param embeddings: list of embeddings to add
-        :type embeddings: List[List[str]]
-        :param documents: list of texts to add
-        :type documents: List[str]
-        :param metadatas: list of metadata associated with docs
-        :type metadatas: List[object]
-        :param ids: ids of docs
-        :type ids: List[str]
-        :param skip_embedding: Optional. If True, then the embeddings are assumed to be already generated.
-        :type skip_embedding: bool
+        """Add data in vector database.
+
+        Args:
+            embeddings (List[List[str]]): List of embeddings to add.
+            documents (List[str]): List of texts to add.
+            metadatas (List[object]): List of metadata associated with docs.
+            ids (List[str]): IDs of docs.
+            skip_embedding (bool): If True, then embeddings are assumed to be already generated.
         """
+        for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"):
+            batch_end = batch_start + self.BATCH_SIZE
+            batch_documents = documents[batch_start:batch_end]
 
-        for i in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"):
+            # Generate embeddings for the batch if not skipping embedding
             if not skip_embedding:
-                embeddings = self.embedder.embedding_fn(documents[i : i + self.BATCH_SIZE])
-
-            docs = []
-            for id, text, metadata, embeddings in zip(
-                ids[i : i + self.BATCH_SIZE],
-                documents[i : i + self.BATCH_SIZE],
-                metadatas[i : i + self.BATCH_SIZE],
-                embeddings[i : i + self.BATCH_SIZE],
-            ):
-                docs.append(
-                    {
-                        "_index": self._get_index(),
-                        "_id": id,
-                        "_source": {"text": text, "metadata": metadata, "embeddings": embeddings},
-                    }
+                batch_embeddings = self.embedder.embedding_fn(batch_documents)
+            else:
+                batch_embeddings = embeddings[batch_start:batch_end]
+
+            # Create document entries for bulk upload
+            batch_entries = [
+                {
+                    "_index": self._get_index(),
+                    "_id": doc_id,
+                    "_source": {"text": text, "metadata": metadata, "embeddings": embedding},
+                }
+                for doc_id, text, metadata, embedding in zip(
+                    ids[batch_start:batch_end], batch_documents, metadatas[batch_start:batch_end], batch_embeddings
                 )
-            bulk(self.client, docs)
+            ]
+
+            # Perform bulk operation
+            bulk(self.client, batch_entries)
             self.client.indices.refresh(index=self._get_index())
-            # Sleep for 0.1 seconds to avoid rate limiting
+
+            # Sleep to avoid rate limiting
             time.sleep(0.1)
 
     def query(
@@ -250,7 +251,7 @@ class OpenSearchDB(BaseVectorDB):
         """
         # Delete all data from the database
         if self.client.indices.exists(index=self._get_index()):
-            # delete index in Es
+            # delete index in ES
             self.client.indices.delete(index=self._get_index())
 
     def delete(self, where):

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.23"
+version = "0.1.24"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",