Переглянути джерело

refactor: get existing doc id method (#616)

cachho 1 рік тому
батько
коміт
3d0e4141bf
1 змінених файлів з 34 додано та 28 видалено
  1. 34 28
      embedchain/embedchain.py

+ 34 - 28
embedchain/embedchain.py

@@ -322,26 +322,10 @@ class EmbedChain(JSONSerializable):
         count_new_chunks = self.db.count() - chunks_before_addition
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
-
-    def load_and_embed_v2(
-        self,
-        loader: BaseLoader,
-        chunker: BaseChunker,
-        src: Any,
-        metadata: Optional[Dict[str, Any]] = None,
-        source_id: Optional[str] = None,
-        dry_run=False,
-    ):
+    
+    def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
         """
-        Loads the data from the given URL, chunks it, and adds it to database.
-
-        :param loader: The loader to use to load the data.
-        :param chunker: The chunker to use to chunk the data.
-        :param src: The data to be handled by the loader. Can be a URL for
-        remote sources or local content for local loaders.
-        :param metadata: Optional. Metadata associated with the data source.
-        :param source_id: Hexadecimal hash of the source.
-        :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
+        Get id of existing document for a given source, based on the data type
         """
         # Find existing embeddings for the source
         # Depending on the data type, existing embeddings are checked for.
@@ -350,7 +334,7 @@ class EmbedChain(JSONSerializable):
             # Think of a text:
             #   Either it's the same, then it won't change, so it's not an update.
             #   Or it's different, then it will be added as a new text.
-            existing_doc_id = None
+            return None
         elif chunker.data_type.value in [item.value for item in IndirectDataType]:
             # These types have a indirect source reference
             # As long as the reference is the same, they can be updated.
@@ -360,10 +344,10 @@ class EmbedChain(JSONSerializable):
                 },
                 limit=1,
             )
-            try:
-                existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
-            except Exception:
-                existing_doc_id = None
+            if len(existing_embeddings_data.get("metadatas", [])) > 0:
+                return existing_embeddings_data["metadatas"][0]["doc_id"]
+            else:
+                return None
         elif chunker.data_type.value in [item.value for item in SpecialDataType]:
             # These types don't contain indirect references.
             # Through custom logic, they can be attributed to a source and be updated.
@@ -375,10 +359,10 @@ class EmbedChain(JSONSerializable):
                     },
                     limit=1,
                 )
-                try:
-                    existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
-                except Exception:
-                    existing_doc_id = None
+                if len(existing_embeddings_data.get("metadatas", [])) > 0:
+                    return existing_embeddings_data["metadatas"][0]["doc_id"]
+                else:
+                    return None
             else:
                 raise NotImplementedError(
                     f"SpecialDataType {chunker.data_type} must have a custom logic to check for existing data"
@@ -389,6 +373,28 @@ class EmbedChain(JSONSerializable):
                 "When it should be  DirectDataType, IndirectDataType or SpecialDataType."
             )
 
+    def load_and_embed_v2(
+        self,
+        loader: BaseLoader,
+        chunker: BaseChunker,
+        src: Any,
+        metadata: Optional[Dict[str, Any]] = None,
+        source_id: Optional[str] = None,
+        dry_run=False,
+    ):
+        """
+        Loads the data from the given URL, chunks it, and adds it to database.
+
+        :param loader: The loader to use to load the data.
+        :param chunker: The chunker to use to chunk the data.
+        :param src: The data to be handled by the loader. Can be a URL for
+        remote sources or local content for local loaders.
+        :param metadata: Optional. Metadata associated with the data source.
+        :param source_id: Hexadecimal hash of the source.
+        :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
+        """
+        existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
+
         # Create chunks
         embeddings_data = chunker.create_chunks(loader, src)