Преглед на файлове

fix: url metadata for all datatypes (#613)

cachho преди 1 година
родител
ревизия
79efa51941
променени са 4 файла, в които са добавени 116 реда и са изтрити 16 реда
  1. 48 10
      embedchain/embedchain.py
  2. 1 3
      embedchain/loaders/local_qna_pair.py
  3. 35 3
      embedchain/models/data_type.py
  4. 32 0
      tests/models/test_data_type.py

+ 48 - 10
embedchain/embedchain.py

@@ -21,7 +21,8 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+                                         IndirectDataType, SpecialDataType)
 from embedchain.utils import detect_datatype
 from embedchain.vectordb.base import BaseVectorDB
 
@@ -339,16 +340,53 @@ class EmbedChain(JSONSerializable):
         :param source_id: Hexadecimal hash of the source.
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
         """
-        existing_embeddings_data = self.db.get(
-            where={
-                "url": src,
-            },
-            limit=1,
-        )
-        try:
-            existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
-        except Exception:
+        # Find existing embeddings for the source
+        # How existing embeddings are looked up depends on the data type.
+        if chunker.data_type.value in [item.value for item in DirectDataType]:
+            # DirectDataTypes can't be updated.
+            # Think of a text:
+            #   Either it's the same, then it won't change, so it's not an update.
+            #   Or it's different, then it will be added as a new text.
             existing_doc_id = None
+        elif chunker.data_type.value in [item.value for item in IndirectDataType]:
+            # These types have an indirect source reference.
+            # As long as the reference is the same, they can be updated.
+            existing_embeddings_data = self.db.get(
+                where={
+                    "url": src,
+                },
+                limit=1,
+            )
+            try:
+                existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
+            except Exception:
+                existing_doc_id = None
+        elif chunker.data_type.value in [item.value for item in SpecialDataType]:
+            # These types don't contain indirect references.
+            # Through custom logic, they can be attributed to a source and be updated.
+            if chunker.data_type == DataType.QNA_PAIR:
+                # QNA_PAIRs update the answer if the question already exists.
+                existing_embeddings_data = self.db.get(
+                    where={
+                        "question": src[0],
+                    },
+                    limit=1,
+                )
+                try:
+                    existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
+                except Exception:
+                    existing_doc_id = None
+            else:
+                raise NotImplementedError(
+                    f"SpecialDataType {chunker.data_type} must have a custom logic to check for existing data"
+                )
+        else:
+            raise TypeError(
+                f"{chunker.data_type} is type {type(chunker.data_type)}. "
+                "When it should be  DirectDataType, IndirectDataType or SpecialDataType."
+            )
+
+        # Create chunks
         embeddings_data = chunker.create_chunks(loader, src)
 
         # spread chunking results

+ 1 - 3
embedchain/loaders/local_qna_pair.py

@@ -11,9 +11,7 @@ class LocalQnaPairLoader(BaseLoader):
         question, answer = content
         content = f"Q: {question}\nA: {answer}"
         url = "local"
-        meta_data = {
-            "url": url,
-        }
+        meta_data = {"url": url, "question": question}
         doc_id = hashlib.sha256((content + url).encode()).hexdigest()
         return {
             "doc_id": doc_id,

+ 35 - 3
embedchain/models/data_type.py

@@ -1,15 +1,47 @@
 from enum import Enum
 
 
-class DataType(Enum):
+class DirectDataType(Enum):
+    """
+    DirectDataType enum contains data types that contain raw data directly.
+    """
+
+    TEXT = "text"
+
+
+class IndirectDataType(Enum):
+    """
+    IndirectDataType enum contains data types that contain references to data stored elsewhere.
+    """
+
     YOUTUBE_VIDEO = "youtube_video"
     PDF_FILE = "pdf_file"
     WEB_PAGE = "web_page"
     SITEMAP = "sitemap"
     DOCX = "docx"
     DOCS_SITE = "docs_site"
-    TEXT = "text"
-    QNA_PAIR = "qna_pair"
     NOTION = "notion"
     CSV = "csv"
     MDX = "mdx"
+
+
+class SpecialDataType(Enum):
+    """
+    SpecialDataType enum contains data types that are neither direct nor indirect, or simply require special attention.
+    """
+
+    QNA_PAIR = "qna_pair"
+
+
+class DataType(Enum):
+    TEXT = DirectDataType.TEXT.value
+    YOUTUBE_VIDEO = IndirectDataType.YOUTUBE_VIDEO.value
+    PDF_FILE = IndirectDataType.PDF_FILE.value
+    WEB_PAGE = IndirectDataType.WEB_PAGE.value
+    SITEMAP = IndirectDataType.SITEMAP.value
+    DOCX = IndirectDataType.DOCX.value
+    DOCS_SITE = IndirectDataType.DOCS_SITE.value
+    NOTION = IndirectDataType.NOTION.value
+    CSV = IndirectDataType.CSV.value
+    MDX = IndirectDataType.MDX.value
+    QNA_PAIR = SpecialDataType.QNA_PAIR.value

+ 32 - 0
tests/models/test_data_type.py

@@ -0,0 +1,32 @@
+import unittest
+
+from embedchain.models.data_type import (DataType, DirectDataType,
+                                         IndirectDataType, SpecialDataType)
+
+
+class TestDataTypeEnums(unittest.TestCase):
+    def test_subclass_types_in_data_type(self):
+        """Test that all data type category subclasses are contained in the composite data type"""
+        # Check if DirectDataType values are in DataType
+        for data_type in DirectDataType:
+            self.assertIn(data_type.value, DataType._value2member_map_)
+
+        # Check if IndirectDataType values are in DataType
+        for data_type in IndirectDataType:
+            self.assertIn(data_type.value, DataType._value2member_map_)
+
+        # Check if SpecialDataType values are in DataType
+        for data_type in SpecialDataType:
+            self.assertIn(data_type.value, DataType._value2member_map_)
+
+    def test_data_type_in_subclasses(self):
+        """Test that all data types in the composite data type are categorized in a subclass"""
+        for data_type in DataType:
+            if data_type.value in DirectDataType._value2member_map_:
+                self.assertIn(data_type.value, DirectDataType._value2member_map_)
+            elif data_type.value in IndirectDataType._value2member_map_:
+                self.assertIn(data_type.value, IndirectDataType._value2member_map_)
+            elif data_type.value in SpecialDataType._value2member_map_:
+                self.assertIn(data_type.value, SpecialDataType._value2member_map_)
+            else:
+                self.fail(f"{data_type.value} not found in any subclass enums")