瀏覽代碼

[OpenSearch] Add chunks specific to an app_id if present (#765)

Deshraj Yadav 1 年之前
父節點
當前提交
64a34cac32

+ 9 - 4
embedchain/chunkers/base_chunker.py

@@ -10,7 +10,7 @@ class BaseChunker(JSONSerializable):
         self.text_splitter = text_splitter
         self.text_splitter = text_splitter
         self.data_type = None
         self.data_type = None
 
 
-    def create_chunks(self, loader, src):
+    def create_chunks(self, loader, src, app_id=None):
         """
         """
         Loads data and chunks it.
         Loads data and chunks it.
 
 
@@ -18,13 +18,18 @@ class BaseChunker(JSONSerializable):
         the raw data.
         the raw data.
         :param src: The data to be handled by the loader. Can be a URL for
         :param src: The data to be handled by the loader. Can be a URL for
         remote sources or local content for local loaders.
         remote sources or local content for local loaders.
+        :param app_id: App id used to generate the doc_id.
         """
         """
         documents = []
         documents = []
-        ids = []
+        chunk_ids = []
         idMap = {}
         idMap = {}
         data_result = loader.load_data(src)
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]
         doc_id = data_result["doc_id"]
+        # Prefix app_id in the document id if app_id is not None to
+        # distinguish between different documents stored in the same
+        # elasticsearch or opensearch index
+        doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
         metadatas = []
         metadatas = []
         for data in data_records:
         for data in data_records:
             content = data["content"]
             content = data["content"]
@@ -41,12 +46,12 @@ class BaseChunker(JSONSerializable):
                 chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                 chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                 if idMap.get(chunk_id) is None:
                 if idMap.get(chunk_id) is None:
                     idMap[chunk_id] = True
                     idMap[chunk_id] = True
-                    ids.append(chunk_id)
+                    chunk_ids.append(chunk_id)
                     documents.append(chunk)
                     documents.append(chunk)
                     metadatas.append(meta_data)
                     metadatas.append(meta_data)
         return {
         return {
             "documents": documents,
             "documents": documents,
-            "ids": ids,
+            "ids": chunk_ids,
             "metadatas": metadatas,
             "metadatas": metadatas,
             "doc_id": doc_id,
             "doc_id": doc_id,
         }
         }

+ 2 - 1
embedchain/chunkers/images.py

@@ -20,7 +20,7 @@ class ImagesChunker(BaseChunker):
         )
         )
         super().__init__(image_splitter)
         super().__init__(image_splitter)
 
 
-    def create_chunks(self, loader, src):
+    def create_chunks(self, loader, src, app_id=None):
         """
         """
         Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image
         Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image
 
 
@@ -35,6 +35,7 @@ class ImagesChunker(BaseChunker):
         data_result = loader.load_data(src)
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]
         doc_id = data_result["doc_id"]
+        doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
         metadatas = []
         metadatas = []
         for data in data_records:
         for data in data_records:
             meta_data = data["meta_data"]
             meta_data = data["meta_data"]

+ 23 - 19
embedchain/embedchain.py

@@ -268,14 +268,16 @@ class EmbedChain(JSONSerializable):
         elif chunker.data_type.value in [item.value for item in IndirectDataType]:
         elif chunker.data_type.value in [item.value for item in IndirectDataType]:
             # These types have a indirect source reference
             # These types have a indirect source reference
             # As long as the reference is the same, they can be updated.
             # As long as the reference is the same, they can be updated.
-            existing_embeddings_data = self.db.get(
-                where={
-                    "url": src,
-                },
+            where = {"url": src}
+            if self.config.id is not None:
+                where.update({"app_id": self.config.id})
+
+            existing_embeddings = self.db.get(
+                where=where,
                 limit=1,
                 limit=1,
             )
             )
-            if len(existing_embeddings_data.get("metadatas", [])) > 0:
-                return existing_embeddings_data["metadatas"][0]["doc_id"]
+            if len(existing_embeddings.get("metadatas", [])) > 0:
+                return existing_embeddings["metadatas"][0]["doc_id"]
             else:
             else:
                 return None
                 return None
         elif chunker.data_type.value in [item.value for item in SpecialDataType]:
         elif chunker.data_type.value in [item.value for item in SpecialDataType]:
@@ -283,14 +285,16 @@ class EmbedChain(JSONSerializable):
             # Through custom logic, they can be attributed to a source and be updated.
             # Through custom logic, they can be attributed to a source and be updated.
             if chunker.data_type == DataType.QNA_PAIR:
             if chunker.data_type == DataType.QNA_PAIR:
                 # QNA_PAIRs update the answer if the question already exists.
                 # QNA_PAIRs update the answer if the question already exists.
-                existing_embeddings_data = self.db.get(
-                    where={
-                        "question": src[0],
-                    },
+                where = {"question": src[0]}
+                if self.config.id is not None:
+                    where.update({"app_id": self.config.id})
+
+                existing_embeddings = self.db.get(
+                    where=where,
                     limit=1,
                     limit=1,
                 )
                 )
-                if len(existing_embeddings_data.get("metadatas", [])) > 0:
-                    return existing_embeddings_data["metadatas"][0]["doc_id"]
+                if len(existing_embeddings.get("metadatas", [])) > 0:
+                    return existing_embeddings["metadatas"][0]["doc_id"]
                 else:
                 else:
                     return None
                     return None
             else:
             else:
@@ -326,9 +330,10 @@ class EmbedChain(JSONSerializable):
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
         """
         """
         existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
         existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
+        app_id = self.config.id if self.config is not None else None
 
 
         # Create chunks
         # Create chunks
-        embeddings_data = chunker.create_chunks(loader, src)
+        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id)
         # spread chunking results
         # spread chunking results
         documents = embeddings_data["documents"]
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
         metadatas = embeddings_data["metadatas"]
@@ -345,12 +350,11 @@ class EmbedChain(JSONSerializable):
             self.db.delete({"doc_id": existing_doc_id})
             self.db.delete({"doc_id": existing_doc_id})
 
 
         # get existing ids, and discard doc if any common id exist.
         # get existing ids, and discard doc if any common id exist.
-        where = {"app_id": self.config.id} if self.config.id is not None else {}
-        # where={"url": src}
-        db_result = self.db.get(
-            ids=ids,
-            where=where,  # optional filter
-        )
+        where = {"url": src}
+        if self.config.id is not None:
+            where.update({"metadata.app_id": self.config.id})
+
+        db_result = self.db.get(ids=ids, where=where)  # optional filter
         existing_ids = set(db_result["ids"])
         existing_ids = set(db_result["ids"])
 
 
         if len(existing_ids):
         if len(existing_ids):

+ 22 - 4
embedchain/vectordb/opensearch.py

@@ -85,19 +85,29 @@ class OpenSearchDB(BaseVectorDB):
         :return: ids
         :return: ids
         :type: Set[str]
         :type: Set[str]
         """
         """
+        query = {}
         if ids:
         if ids:
-            query = {"query": {"bool": {"must": [{"ids": {"values": ids}}]}}}
+            query["query"] = {"bool": {"must": [{"ids": {"values": ids}}]}}
         else:
         else:
-            query = {"query": {"bool": {"must": []}}}
+            query["query"] = {"bool": {"must": []}}
+
         if "app_id" in where:
         if "app_id" in where:
             app_id = where["app_id"]
             app_id = where["app_id"]
             query["query"]["bool"]["must"].append({"term": {"metadata.app_id": app_id}})
             query["query"]["bool"]["must"].append({"term": {"metadata.app_id": app_id}})
 
 
         # OpenSearch syntax is different from Elasticsearch
         # OpenSearch syntax is different from Elasticsearch
-        response = self.client.search(index=self._get_index(), body=query, _source=False, size=limit)
+        response = self.client.search(index=self._get_index(), body=query, _source=True, size=limit)
         docs = response["hits"]["hits"]
         docs = response["hits"]["hits"]
         ids = [doc["_id"] for doc in docs]
         ids = [doc["_id"] for doc in docs]
-        return {"ids": set(ids)}
+        doc_ids = [doc["_source"]["metadata"]["doc_id"] for doc in docs]
+
+        # Result is modified for compatibility with other vector databases
+        # TODO: Add method in vector database to return result in a standard format
+        result = {"ids": ids, "metadatas": []}
+
+        for doc_id in doc_ids:
+            result["metadatas"].append({"doc_id": doc_id})
+        return result
 
 
     def add(
     def add(
         self, embeddings: List[str], documents: List[str], metadatas: List[object], ids: List[str], skip_embedding: bool
         self, embeddings: List[str], documents: List[str], metadatas: List[object], ids: List[str], skip_embedding: bool
@@ -204,6 +214,14 @@ class OpenSearchDB(BaseVectorDB):
             # delete index in Es
             # delete index in Es
             self.client.indices.delete(index=self._get_index())
             self.client.indices.delete(index=self._get_index())
 
 
+    def delete(self, where):
+        """Deletes a document from the OpenSearch index"""
+        if "doc_id" not in where:
+            raise ValueError("doc_id is required to delete a document")
+
+        query = {"query": {"bool": {"must": [{"term": {"metadata.doc_id": where["doc_id"]}}]}}}
+        self.client.delete_by_query(index=self._get_index(), body=query)
+
     def _get_index(self) -> str:
     def _get_index(self) -> str:
         """Get the OpenSearch index for a collection
         """Get the OpenSearch index for a collection
 
 

+ 8 - 6
tests/chunkers/test_image_chunker.py

@@ -17,14 +17,15 @@ class TestImageChunker(unittest.TestCase):
         chunker.set_data_type(DataType.IMAGES)
         chunker.set_data_type(DataType.IMAGES)
 
 
         image_path = "./tmp/image.jpeg"
         image_path = "./tmp/image.jpeg"
-        result = chunker.create_chunks(MockLoader(), image_path)
+        app_id = "app1"
+        result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)
 
 
         expected_chunks = {
         expected_chunks = {
-            "doc_id": "123",
+            "doc_id": f"{app_id}--123",
             "documents": [image_path],
             "documents": [image_path],
             "embeddings": ["embedding"],
             "embeddings": ["embedding"],
             "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
             "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
-            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+            "metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
         }
         }
         self.assertEqual(expected_chunks, result)
         self.assertEqual(expected_chunks, result)
 
 
@@ -37,14 +38,15 @@ class TestImageChunker(unittest.TestCase):
         chunker.set_data_type(DataType.IMAGES)
         chunker.set_data_type(DataType.IMAGES)
 
 
         image_path = "./tmp/image.jpeg"
         image_path = "./tmp/image.jpeg"
-        result = chunker.create_chunks(MockLoader(), image_path)
+        app_id = "app1"
+        result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)
 
 
         expected_chunks = {
         expected_chunks = {
-            "doc_id": "123",
+            "doc_id": f"{app_id}--123",
             "documents": [image_path],
             "documents": [image_path],
             "embeddings": ["embedding"],
             "embeddings": ["embedding"],
             "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
             "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
-            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+            "metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
         }
         }
         self.assertEqual(expected_chunks, result)
         self.assertEqual(expected_chunks, result)
 
 

+ 17 - 21
tests/chunkers/test_text.py

@@ -1,31 +1,35 @@
 # ruff: noqa: E501
 # ruff: noqa: E501
 
 
-import unittest
-
 from embedchain.chunkers.text import TextChunker
 from embedchain.chunkers.text import TextChunker
 from embedchain.config import ChunkerConfig
 from embedchain.config import ChunkerConfig
 from embedchain.models.data_type import DataType
 from embedchain.models.data_type import DataType
 
 
 
 
-class TestTextChunker(unittest.TestCase):
-    def test_chunks(self):
+class TestTextChunker:
+    def test_chunks_without_app_id(self):
         """
         """
         Test the chunks generated by TextChunker.
         Test the chunks generated by TextChunker.
-        # TODO: Not a very precise test.
         """
         """
         chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
         chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
         chunker = TextChunker(config=chunker_config)
         chunker = TextChunker(config=chunker_config)
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         # Data type must be set manually in the test
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
         chunker.set_data_type(DataType.TEXT)
-
         result = chunker.create_chunks(MockLoader(), text)
         result = chunker.create_chunks(MockLoader(), text)
-
         documents = result["documents"]
         documents = result["documents"]
+        assert len(documents) > 5
 
 
-        self.assertGreaterEqual(len(documents), 5)
-
-    # Additional test cases can be added to cover different scenarios
+    def test_chunks_with_app_id(self):
+        """
+        Test the chunks generated by TextChunker with app_id
+        """
+        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
+        chunker = TextChunker(config=chunker_config)
+        text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
+        chunker.set_data_type(DataType.TEXT)
+        result = chunker.create_chunks(MockLoader(), text)
+        documents = result["documents"]
+        assert len(documents) > 5
 
 
     def test_big_chunksize(self):
     def test_big_chunksize(self):
         """
         """
@@ -36,12 +40,9 @@ class TestTextChunker(unittest.TestCase):
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         # Data type must be set manually in the test
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
         chunker.set_data_type(DataType.TEXT)
-
         result = chunker.create_chunks(MockLoader(), text)
         result = chunker.create_chunks(MockLoader(), text)
-
         documents = result["documents"]
         documents = result["documents"]
-
-        self.assertEqual(len(documents), 1)
+        assert len(documents) == 1
 
 
     def test_small_chunksize(self):
     def test_small_chunksize(self):
         """
         """
@@ -53,14 +54,9 @@ class TestTextChunker(unittest.TestCase):
         text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"""
         text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"""
         # Data type must be set manually in the test
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
         chunker.set_data_type(DataType.TEXT)
-
         result = chunker.create_chunks(MockLoader(), text)
         result = chunker.create_chunks(MockLoader(), text)
-
         documents = result["documents"]
         documents = result["documents"]
-
-        print(documents)
-
-        self.assertEqual(len(documents), len(text))
+        assert len(documents) == len(text)
 
 
     def test_word_count(self):
     def test_word_count(self):
         chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
         chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
@@ -69,7 +65,7 @@ class TestTextChunker(unittest.TestCase):
 
 
         document = ["ab cd", "ef gh"]
         document = ["ab cd", "ef gh"]
         result = chunker.get_word_count(document)
         result = chunker.get_word_count(document)
-        self.assertEqual(result, 4)
+        assert result == 4
 
 
 
 
 class MockLoader:
 class MockLoader: