Przeglądaj źródła

Add support for `dry_run` in `load_and_embed_v2` method (#634)

Dev Khant 1 rok temu
rodzic
commit
1db3e43adf
1 zmienionych plików z 8 dodań i 3 usunięć
  1. 8 3
      embedchain/embedchain.py

+ 8 - 3
embedchain/embedchain.py

@@ -313,9 +313,6 @@ class EmbedChain(JSONSerializable):
             ids = list(data_dict.keys())
             documents, metadatas = zip(*data_dict.values())
 
-        if dry_run:
-            return list(documents), metadatas, ids, 0
-
         # Loop through all metadatas and add extras.
         new_metadatas = []
         for m in metadatas:
@@ -334,6 +331,9 @@ class EmbedChain(JSONSerializable):
             new_metadatas.append(m)
         metadatas = new_metadatas
 
+        if dry_run:
+            return list(documents), metadatas, ids, 0
+
         # Count before, to calculate a delta in the end.
         chunks_before_addition = self.db.count()
 
@@ -410,6 +410,8 @@ class EmbedChain(JSONSerializable):
         remote sources or local content for local loaders.
         :param metadata: Optional. Metadata associated with the data source.
         :param source_id: Hexadecimal hash of the source.
+        :param dry_run: Optional. A dry run returns chunks and doesn't update DB.
+        :type dry_run: bool, defaults to False
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
         """
         existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
@@ -474,6 +476,9 @@ class EmbedChain(JSONSerializable):
             new_metadatas.append(m)
         metadatas = new_metadatas
 
+        if dry_run:
+            return list(documents), metadatas, ids, 0
+
         # Count before, to calculate a delta in the end.
         chunks_before_addition = self.count()