Sfoglia il codice sorgente

feat: Adding app id in metadata while reading and writing to vector db (#189)

Shashank Srivastava 2 anni fa
parent
commit
d4b8542207
3 file modificati con 20 aggiunte e 12 eliminazioni
  1. + 9 - 7
      embedchain/config/InitConfig.py
  2. + 10 - 4
      embedchain/embedchain.py
  3. + 1 - 1
      setup.py

+ 9 - 7
embedchain/config/InitConfig.py

@@ -1,32 +1,34 @@
 import logging
 import os
-
 from chromadb.utils import embedding_functions
-
 from embedchain.config.BaseConfig import BaseConfig
 
-
 class InitConfig(BaseConfig):
     """
     Config to initialize an embedchain `App` instance.
     """
-
-    def __init__(self, log_level=None, ef=None, db=None, host=None, port=None):
+    def __init__(self, log_level=None, ef=None, db=None, host=None, port=None, id=None):
         """
         :param log_level: Optional. (String) Debug level
         ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
         :param ef: Optional. Embedding function to use.
         :param db: Optional. (Vector) database to use for embeddings.
+        :param id: Optional. ID of the app. Document metadata will have this id.
         :param host: Optional. Hostname for the database server.
         :param port: Optional. Port for the database server.
         """
         self._setup_logging(log_level)
 
+        if db is None:
+            from embedchain.vectordb.chroma_db import ChromaDB
+            self.db = ChromaDB(ef=self.ef)
+        else:
+            self.db = db
+        
         self.ef = ef
-        self.db = db
-
         self.host = host
         self.port = port
+        self.id = id
         return
 
     def _set_embedding_function(self, ef):

+ 10 - 4
embedchain/embedchain.py

@@ -97,9 +97,11 @@ class EmbedChain:
         metadatas = embeddings_data["metadatas"]
         ids = embeddings_data["ids"]
         # get existing ids, and discard doc if any common id exist.
+        where={"app_id": self.config.id} if self.config.id is not None else {}
+        # where={"url": src}
         existing_docs = self.collection.get(
             ids=ids,
-            # where={"url": src}
+            where=where, # optional filter
         )
         existing_ids = set(existing_docs["ids"])
 
@@ -113,6 +115,10 @@ class EmbedChain:
 
             ids = list(data_dict.keys())
             documents, metadatas = zip(*data_dict.values())
+        
+        # Add app id in metadatas so that they can be queried on later
+        if (self.config.id is not None):
+            metadatas = [{**m, "app_id": self.config.id} for m in metadatas]
 
         chunks_before_addition = self.count()
 
@@ -144,11 +150,11 @@ class EmbedChain:
         :param config: The query configuration.
         :return: The content of the document that matched your query.
         """
+        where = {"app_id": self.config.id} if self.config.id is not None else {} # optional filter
         result = self.collection.query(
-            query_texts=[
-                input_query,
-            ],
+            query_texts=[input_query,],
             n_results=config.number_documents,
+            where=where,
         )
         results_formatted = self._format_result(result)
         contents = [result[0].page_content for result in results_formatted]

+ 1 - 1
setup.py

@@ -33,7 +33,7 @@ setuptools.setup(
         "gpt4all",
         "sentence_transformers",
         "docx2txt",
-        "pydantic==1.10.8",
+        "pydantic==1.10.8"
     ],
     extras_require={"dev": ["black", "ruff", "isort", "pytest"]},
 )