瀏覽代碼

Add metadata support to added data sources (#253)

Rayhan Patel 2 年之前
父節點
當前提交
8f42ced9b5
共有 1 個文件被更改,包括 14 次插入 和 6 次删除
  1. 14 6
      embedchain/embedchain.py

+ 14 - 6
embedchain/embedchain.py

@@ -36,7 +36,7 @@ class EmbedChain:
         self.collection = self.config.db.collection
         self.user_asks = []
 
-    def add(self, data_type, url, config: AddConfig = None):
+    def add(self, data_type, url, metadata=None, config: AddConfig = None):
         """
         Adds the data from the given URL to the vector db.
         Loads the data, chunks it, creates an embedding for each chunk
@@ -44,6 +44,7 @@ class EmbedChain:
 
         :param data_type: The type of the data to add.
         :param url: The URL where the data is located.
+        :param metadata: Optional. Metadata associated with the data source.
         :param config: Optional. The `AddConfig` instance to use as configuration
         options.
         """
@@ -51,10 +52,10 @@ class EmbedChain:
             config = AddConfig()
 
         data_formatter = DataFormatter(data_type, config)
-        self.user_asks.append([data_type, url])
-        self.load_and_embed(data_formatter.loader, data_formatter.chunker, url)
+        self.user_asks.append([data_type, url, metadata])
+        self.load_and_embed(data_formatter.loader, data_formatter.chunker, url, metadata)
 
-    def add_local(self, data_type, content, config: AddConfig = None):
+    def add_local(self, data_type, content, metadata=None, config: AddConfig = None):
         """
         Adds the data you supply to the vector db.
         Loads the data, chunks it, creates an embedding for each chunk
@@ -62,6 +63,7 @@ class EmbedChain:
 
         :param data_type: The type of the data to add.
         :param content: The local data. Refer to the `README` for formatting.
+        :param metadata: Optional. Metadata associated with the data source.
         :param config: Optional. The `AddConfig` instance to use as
         configuration options.
         """
@@ -74,9 +76,10 @@ class EmbedChain:
             data_formatter.loader,
             data_formatter.chunker,
             content,
+            metadata,
         )
 
-    def load_and_embed(self, loader, chunker, src):
+    def load_and_embed(self, loader, chunker, src, metadata=None):
         """
         Loads the data from the given URL, chunks it, and adds it to database.
 
@@ -84,6 +87,7 @@ class EmbedChain:
         :param chunker: The chunker to use to chunk the data.
         :param src: The data to be handled by the loader. Can be a URL for
         remote sources or local content for local loaders.
+        :param metadata: Optional. Metadata associated with the data source.
         """
         embeddings_data = chunker.create_chunks(loader, src)
         documents = embeddings_data["documents"]
@@ -112,7 +116,11 @@ class EmbedChain:
             documents, metadatas = zip(*data_dict.values())
 
         chunks_before_addition = self.count()
-        self.collection.add(documents=documents, metadatas=list(metadatas), ids=ids)
+
+        # Fall back to the source-level metadata for any chunk whose own metadata is empty
+        metadatas_with_metadata = [meta or metadata for meta in metadatas]
+
+        self.collection.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids)
         print(
             (
                 f"Successfully saved {src}. New chunks count: "