Browse Source

Add metadata support to added data sources (#253)

Rayhan Patel 2 years ago
parent
commit
8f42ced9b5
1 changed file with 14 additions and 6 deletions
  1. 14 6
      embedchain/embedchain.py

+ 14 - 6
embedchain/embedchain.py

@@ -36,7 +36,7 @@ class EmbedChain:
         self.collection = self.config.db.collection
         self.collection = self.config.db.collection
         self.user_asks = []
         self.user_asks = []
 
 
-    def add(self, data_type, url, config: AddConfig = None):
+    def add(self, data_type, url, metadata=None, config: AddConfig = None):
         """
         """
         Adds the data from the given URL to the vector db.
         Adds the data from the given URL to the vector db.
         Loads the data, chunks it, create embedding for each chunk
         Loads the data, chunks it, create embedding for each chunk
@@ -44,6 +44,7 @@ class EmbedChain:
 
 
         :param data_type: The type of the data to add.
         :param data_type: The type of the data to add.
         :param url: The URL where the data is located.
         :param url: The URL where the data is located.
+        :param metadata: Optional. Metadata associated with the data source.
         :param config: Optional. The `AddConfig` instance to use as configuration
         :param config: Optional. The `AddConfig` instance to use as configuration
         options.
         options.
         """
         """
@@ -51,10 +52,10 @@ class EmbedChain:
             config = AddConfig()
             config = AddConfig()
 
 
         data_formatter = DataFormatter(data_type, config)
         data_formatter = DataFormatter(data_type, config)
-        self.user_asks.append([data_type, url])
-        self.load_and_embed(data_formatter.loader, data_formatter.chunker, url)
+        self.user_asks.append([data_type, url, metadata])
+        self.load_and_embed(data_formatter.loader, data_formatter.chunker, url, metadata)
 
 
-    def add_local(self, data_type, content, config: AddConfig = None):
+    def add_local(self, data_type, content, metadata=None, config: AddConfig = None):
         """
         """
         Adds the data you supply to the vector db.
         Adds the data you supply to the vector db.
         Loads the data, chunks it, create embedding for each chunk
         Loads the data, chunks it, create embedding for each chunk
@@ -62,6 +63,7 @@ class EmbedChain:
 
 
         :param data_type: The type of the data to add.
         :param data_type: The type of the data to add.
         :param content: The local data. Refer to the `README` for formatting.
         :param content: The local data. Refer to the `README` for formatting.
+        :param metadata: Optional. Metadata associated with the data source.
         :param config: Optional. The `AddConfig` instance to use as
         :param config: Optional. The `AddConfig` instance to use as
         configuration options.
         configuration options.
         """
         """
@@ -74,9 +76,10 @@ class EmbedChain:
             data_formatter.loader,
             data_formatter.loader,
             data_formatter.chunker,
             data_formatter.chunker,
             content,
             content,
+            metadata,
         )
         )
 
 
-    def load_and_embed(self, loader, chunker, src):
+    def load_and_embed(self, loader, chunker, src, metadata=None):
         """
         """
         Loads the data from the given URL, chunks it, and adds it to database.
         Loads the data from the given URL, chunks it, and adds it to database.
 
 
@@ -84,6 +87,7 @@ class EmbedChain:
         :param chunker: The chunker to use to chunk the data.
         :param chunker: The chunker to use to chunk the data.
         :param src: The data to be handled by the loader. Can be a URL for
         :param src: The data to be handled by the loader. Can be a URL for
         remote sources or local content for local loaders.
         remote sources or local content for local loaders.
+        :param metadata: Optional. Metadata associated with the data source.
         """
         """
         embeddings_data = chunker.create_chunks(loader, src)
         embeddings_data = chunker.create_chunks(loader, src)
         documents = embeddings_data["documents"]
         documents = embeddings_data["documents"]
@@ -112,7 +116,11 @@ class EmbedChain:
             documents, metadatas = zip(*data_dict.values())
             documents, metadatas = zip(*data_dict.values())
 
 
         chunks_before_addition = self.count()
         chunks_before_addition = self.count()
-        self.collection.add(documents=documents, metadatas=list(metadatas), ids=ids)
+
+         # Add metadata to each document
+        metadatas_with_metadata = [meta or metadata for meta in metadatas]
+
+        self.collection.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids)
         print(
         print(
             (
             (
                 f"Successfully saved {src}. New chunks count: "
                 f"Successfully saved {src}. New chunks count: "