瀏覽代碼

feat: Add embedding manager (#570)

Taranjeet Singh 1 年之前
父節點
當前提交
2bd6881361

+ 6 - 2
embedchain/chunkers/base_chunker.py

@@ -22,14 +22,17 @@ class BaseChunker(JSONSerializable):
         documents = []
         ids = []
         idMap = {}
-        datas = loader.load_data(src)
+        data_result = loader.load_data(src)
+        data_records = data_result["data"]
+        doc_id = data_result["doc_id"]
         metadatas = []
-        for data in datas:
+        for data in data_records:
             content = data["content"]
 
             meta_data = data["meta_data"]
             # add data type to meta data to allow query using data type
             meta_data["data_type"] = self.data_type.value
+            meta_data["doc_id"] = doc_id
             url = meta_data["url"]
 
             chunks = self.get_chunks(content)
@@ -45,6 +48,7 @@ class BaseChunker(JSONSerializable):
             "documents": documents,
             "ids": ids,
             "metadatas": metadatas,
+            "doc_id": doc_id,
         }
 
     def get_chunks(self, content):

+ 110 - 2
embedchain/embedchain.py

@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import requests
 from dotenv import load_dotenv
+from langchain.docstore.document import Document
 from tenacity import retry, stop_after_attempt, wait_fixed
 
 from embedchain.chunkers.base_chunker import BaseChunker
@@ -179,7 +180,7 @@ class EmbedChain(JSONSerializable):
 
         data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([source, data_type.value, metadata])
-        documents, metadatas, _ids, new_chunks = self.load_and_embed(
+        documents, metadatas, _ids, new_chunks = self.load_and_embed_v2(
             data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
         )
         if data_type in {DataType.DOCS_SITE}:
@@ -271,10 +272,11 @@ class EmbedChain(JSONSerializable):
         # get existing ids, and discard doc if any common id exist.
         where = {"app_id": self.config.id} if self.config.id is not None else {}
         # where={"url": src}
-        existing_ids = self.db.get(
+        db_result = self.db.get(
             ids=ids,
             where=where,  # optional filter
         )
+        existing_ids = set(db_result["ids"])
 
         if len(existing_ids):
             data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
@@ -317,6 +319,112 @@ class EmbedChain(JSONSerializable):
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
 
+    def load_and_embed_v2(
+            self,
+            loader: BaseLoader,
+            chunker: BaseChunker,
+            src: Any,
+            metadata: Optional[Dict[str, Any]] = None,
+            source_id: Optional[str] = None,
+            dry_run = False
+        ):
+        """
+        Loads the data from the given URL, chunks it, and adds it to database.
+
+        :param loader: The loader to use to load the data.
+        :param chunker: The chunker to use to chunk the data.
+        :param src: The data to be handled by the loader. Can be a URL for
+        remote sources or local content for local loaders.
+        :param metadata: Optional. Metadata associated with the data source.
+        :param source_id: Hexadecimal hash of the source.
+        :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
+        """
+        existing_embeddings_data = self.db.get(
+            where={
+                "url": src,
+            },
+            limit=1,
+        )
+        try:
+            existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
+        except Exception:
+            existing_doc_id = None
+        embeddings_data = chunker.create_chunks(loader, src)
+
+        # spread chunking results
+        documents = embeddings_data["documents"]
+        metadatas = embeddings_data["metadatas"]
+        ids = embeddings_data["ids"]
+        new_doc_id = embeddings_data["doc_id"]
+
+        if existing_doc_id and existing_doc_id == new_doc_id:
+            print("Doc content has not changed. Skipping creating chunks and embeddings")
+            return [], [], [], 0
+
+        # this means that doc content has changed.
+        if existing_doc_id and existing_doc_id != new_doc_id:
+            print("Doc content has changed. Recomputing chunks and embeddings intelligently.")
+            self.db.delete({
+                "doc_id": existing_doc_id
+            })
+
+        # get existing ids, and discard doc if any common id exist.
+        where = {"app_id": self.config.id} if self.config.id is not None else {}
+        # where={"url": src}
+        db_result = self.db.get(
+            ids=ids,
+            where=where,  # optional filter
+        )
+        existing_ids = set(db_result["ids"])
+
+        if len(existing_ids):
+            data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
+            data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
+
+            if not data_dict:
+                print(f"All data from {src} already exists in the database.")
+                # Make sure to return a matching return type
+                return [], [], [], 0
+
+            ids = list(data_dict.keys())
+            documents, metadatas = zip(*data_dict.values())
+
+        # Loop though all metadatas and add extras.
+        new_metadatas = []
+        for m in metadatas:
+            # Add app id in metadatas so that they can be queried on later
+            if self.config.id:
+                m["app_id"] = self.config.id
+
+            # Add hashed source
+            m["hash"] = source_id
+
+            # Note: Metadata is the function argument
+            if metadata:
+                # Spread whatever is in metadata into the new object.
+                m.update(metadata)
+
+            new_metadatas.append(m)
+        metadatas = new_metadatas
+
+        # Count before, to calculate a delta in the end.
+        chunks_before_addition = self.count()
+
+        self.db.add(documents=documents, metadatas=metadatas, ids=ids)
+        count_new_chunks = self.count() - chunks_before_addition
+        print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
+        return list(documents), metadatas, ids, count_new_chunks
+
+    def _format_result(self, results):
+        return [
+            (Document(page_content=result[0], metadata=result[1] or {}), result[2])
+            for result in zip(
+                results["documents"][0],
+                results["metadatas"][0],
+                results["distances"][0],
+            )
+        ]
+
     def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
         """
         Queries the vector database based on the given input query.

+ 8 - 2
embedchain/loaders/csv.py

@@ -1,4 +1,5 @@
 import csv
+import hashlib
 from io import StringIO
 from urllib.parse import urlparse
 
@@ -34,7 +35,7 @@ class CsvLoader(BaseLoader):
     def load_data(content):
         """Load a csv file with headers. Each line is a document"""
         result = []
-
+        lines = []
         with CsvLoader._get_file_content(content) as file:
             first_line = file.readline()
             delimiter = CsvLoader._detect_delimiter(first_line)
@@ -42,5 +43,10 @@ class CsvLoader(BaseLoader):
             reader = csv.DictReader(file, delimiter=delimiter)
             for i, row in enumerate(reader):
                 line = ", ".join([f"{field}: {value}" for field, value in row.items()])
+                lines.append(line)
                 result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
-        return result
+        doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": result
+        }

+ 6 - 1
embedchain/loaders/docs_site_loader.py

@@ -1,3 +1,4 @@
+import hashlib
 import logging
 from urllib.parse import urljoin, urlparse
 
@@ -99,4 +100,8 @@ class DocsSiteLoader(BaseLoader):
         output = []
         for u in all_urls:
             output.extend(self._load_data_from_url(u))
-        return output
+        doc_id = hashlib.sha256((" ".join(all_urls) + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": output,
+        }

+ 7 - 1
embedchain/loaders/docx_file.py

@@ -1,3 +1,5 @@
+import hashlib
+
 from langchain.document_loaders import Docx2txtLoader
 
 from embedchain.helper.json_serializable import register_deserializable
@@ -15,4 +17,8 @@ class DocxFileLoader(BaseLoader):
         meta_data = data[0].metadata
         meta_data["url"] = "local"
         output.append({"content": content, "meta_data": meta_data})
-        return output
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": output,
+        }

+ 14 - 7
embedchain/loaders/local_qna_pair.py

@@ -1,3 +1,5 @@
+import hashlib
+
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 
@@ -8,12 +10,17 @@ class LocalQnaPairLoader(BaseLoader):
         """Load data from a local QnA pair."""
         question, answer = content
         content = f"Q: {question}\nA: {answer}"
+        url = "local"
         meta_data = {
-            "url": "local",
+            "url": url,
+        }
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
+                {
+                    "content": content,
+                    "meta_data": meta_data,
+                }
+            ]
         }
-        return [
-            {
-                "content": content,
-                "meta_data": meta_data,
-            }
-        ]

+ 14 - 7
embedchain/loaders/local_text.py

@@ -1,3 +1,5 @@
+import hashlib
+
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 
@@ -6,12 +8,17 @@ from embedchain.loaders.base_loader import BaseLoader
 class LocalTextLoader(BaseLoader):
     def load_data(self, content):
         """Load data from a local text file."""
+        url = "local"
         meta_data = {
-            "url": "local",
+            "url": url,
+        }
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
+                {
+                    "content": content,
+                    "meta_data": meta_data,
+                }
+            ]
         }
-        return [
-            {
-                "content": content,
-                "meta_data": meta_data,
-            }
-        ]

+ 7 - 3
embedchain/loaders/notion.py

@@ -1,3 +1,4 @@
+import hashlib
 import logging
 import os
 
@@ -34,10 +35,13 @@ class NotionLoader(BaseLoader):
 
         # Clean text
         text = clean_string(raw_text)
-
-        return [
+        doc_id = hashlib.sha256((text + source).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
             {
                 "content": text,
                 "meta_data": {"url": f"notion-{formatted_id}"},
             }
-        ]
+        ],
+        }

+ 11 - 3
embedchain/loaders/pdf_file.py

@@ -1,3 +1,5 @@
+import hashlib
+
 from langchain.document_loaders import PyPDFLoader
 
 from embedchain.helper.json_serializable import register_deserializable
@@ -10,7 +12,8 @@ class PdfFileLoader(BaseLoader):
     def load_data(self, url):
         """Load data from a PDF file."""
         loader = PyPDFLoader(url)
-        output = []
+        data = []
+        all_content = []
         pages = loader.load_and_split()
         if not len(pages):
             raise ValueError("No data found")
@@ -19,10 +22,15 @@ class PdfFileLoader(BaseLoader):
             content = clean_string(content)
             meta_data = page.metadata
             meta_data["url"] = url
-            output.append(
+            data.append(
                 {
                     "content": content,
                     "meta_data": meta_data,
                 }
             )
-        return output
+            all_content.append(content)
+        doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": data,
+        }

+ 7 - 1
embedchain/loaders/sitemap.py

@@ -1,3 +1,4 @@
+import hashlib
 import logging
 
 import requests
@@ -30,6 +31,8 @@ class SitemapLoader(BaseLoader):
             # Get all <loc> tags as a fallback. This might include images.
             links = [link.text for link in soup.find_all("loc")]
 
+        doc_id = hashlib.sha256((" ".join(links) + sitemap_url).encode()).hexdigest()
+
         for link in links:
             try:
                 each_load_data = web_page_loader.load_data(link)
@@ -40,4 +43,7 @@ class SitemapLoader(BaseLoader):
                     logging.warning(f"Page is not readable (too many invalid characters): {link}")
             except ParserRejectedMarkup as e:
                 logging.error(f"Failed to parse {link}: {e}")
-        return [data[0] for data in output]
+        return {
+            "doc_id": doc_id,
+            "data": [data[0] for data in output]
+        }

+ 12 - 7
embedchain/loaders/web_page.py

@@ -1,3 +1,4 @@
+import hashlib
 import logging
 
 import requests
@@ -63,10 +64,14 @@ class WebPageLoader(BaseLoader):
         meta_data = {
             "url": url,
         }
-
-        return [
-            {
-                "content": content,
-                "meta_data": meta_data,
-            }
-        ]
+        content = content
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return  {
+            "doc_id": doc_id,
+            "data": [
+                {
+                    "content": content,
+                    "meta_data": meta_data,
+                }
+            ],
+        }

+ 8 - 1
embedchain/loaders/youtube_video.py

@@ -1,3 +1,5 @@
+import hashlib
+
 from langchain.document_loaders import YoutubeLoader
 
 from embedchain.helper.json_serializable import register_deserializable
@@ -18,10 +20,15 @@ class YoutubeVideoLoader(BaseLoader):
         content = clean_string(content)
         meta_data = doc[0].metadata
         meta_data["url"] = url
+
         output.append(
             {
                 "content": content,
                 "meta_data": meta_data,
             }
         )
-        return output
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": output,
+        }

+ 50 - 0
embedchain/vectordb/base_vector_db.py

@@ -0,0 +1,50 @@
+from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
+from embedchain.embedder.base_embedder import BaseEmbedder
+from embedchain.helper_classes.json_serializable import JSONSerializable
+
+
+class BaseVectorDB(JSONSerializable):
+    """Base class for vector database."""
+
+    def __init__(self, config: BaseVectorDbConfig):
+        self.client = self._get_or_create_db()
+        self.config: BaseVectorDbConfig = config
+
+    def _initialize(self):
+        """
+        This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
+
+        So it's can't be done in __init__ in one step.
+        """
+        raise NotImplementedError
+
+    def _get_or_create_db(self):
+        """Get or create the database."""
+        raise NotImplementedError
+
+    def _get_or_create_collection(self):
+        raise NotImplementedError
+
+    def _set_embedder(self, embedder: BaseEmbedder):
+        self.embedder = embedder
+
+    def get(self):
+        raise NotImplementedError
+
+    def add(self):
+        raise NotImplementedError
+
+    def query(self):
+        raise NotImplementedError
+
+    def count(self):
+        raise NotImplementedError
+
+    def delete(self):
+        raise NotImplementedError
+
+    def reset(self):
+        raise NotImplementedError
+
+    def set_collection_name(self, name: str):
+        raise NotImplementedError

+ 20 - 10
embedchain/vectordb/chroma.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Any
 
 from chromadb import Collection, QueryResult
 from langchain.docstore.document import Document
@@ -87,25 +87,32 @@ class ChromaDB(BaseVectorDB):
         )
         return self.collection
 
-    def get(self, ids: List[str], where: Dict[str, any]) -> List[str]:
+    def get(self, ids=None, where=None, limit=None):
         """
         Get existing doc ids present in vector database
 
         :param ids: list of doc ids to check for existence
         :type ids: List[str]
         :param where: Optional. to filter data
-        :type where: Dict[str, any]
+        :type where: Dict[str, Any]
         :return: Existing documents.
         :rtype: List[str]
         """
-        existing_docs = self.collection.get(
-            ids=ids,
-            where=where,  # optional filter
+        args = {}
+        if ids:
+            args["ids"] = ids
+        if where:
+            args["where"] = where
+        if limit:
+            args["limit"] = limit
+        return self.collection.get(
+            **args
         )
 
-        return set(existing_docs["ids"])
+    def get_advanced(self, where):
+        return self.collection.get(where=where, limit=1)
 
-    def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
+    def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
         """
         Add vectors to chroma database
 
@@ -136,7 +143,7 @@ class ChromaDB(BaseVectorDB):
             )
         ]
 
-    def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
+    def query(self, input_query: List[str], n_results: int, where: Dict[str, Any]) -> List[str]:
         """
         Query contents from vector data base based on vector similarity
 
@@ -145,7 +152,7 @@ class ChromaDB(BaseVectorDB):
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: to filter data
-        :type where: Dict[str, any]
+        :type where: Dict[str, Any]
         :raises InvalidDimensionException: Dimensions do not match.
         :return: The content of the document that matched your query.
         :rtype: List[str]
@@ -187,6 +194,9 @@ class ChromaDB(BaseVectorDB):
         """
         return self.collection.count()
 
+    def delete(self, where):
+        return self.collection.delete(where=where)
+
     def reset(self):
         """
         Resets the database. Deletes all embeddings irreversibly.

+ 9 - 6
tests/chunkers/test_text.py

@@ -69,9 +69,12 @@ class MockLoader:
         Mock loader that returns a list of data dictionaries.
         Adjust this method to return different data for testing.
         """
-        return [
-            {
-                "content": src,
-                "meta_data": {"url": "none"},
-            }
-        ]
+        return {
+            "doc_id": "123",
+            "data": [
+                {
+                    "content": src,
+                    "meta_data": {"url": "none"},
+                }
+            ]
+        }

+ 22 - 20
tests/loaders/test_csv.py

@@ -29,18 +29,19 @@ def test_load_data(delimiter):
         # Loading CSV using CsvLoader
         loader = CsvLoader()
         result = loader.load_data(filename)
+        data = result["data"]
 
         # Assertions
-        assert len(result) == 3
-        assert result[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
-        assert result[0]["meta_data"]["url"] == filename
-        assert result[0]["meta_data"]["row"] == 1
-        assert result[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
-        assert result[1]["meta_data"]["url"] == filename
-        assert result[1]["meta_data"]["row"] == 2
-        assert result[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
-        assert result[2]["meta_data"]["url"] == filename
-        assert result[2]["meta_data"]["row"] == 3
+        assert len(data) == 3
+        assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
+        assert data[0]["meta_data"]["url"] == filename
+        assert data[0]["meta_data"]["row"] == 1
+        assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
+        assert data[1]["meta_data"]["url"] == filename
+        assert data[1]["meta_data"]["row"] == 2
+        assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
+        assert data[2]["meta_data"]["url"] == filename
+        assert data[2]["meta_data"]["row"] == 3
 
         # Cleaning up the temporary file
         os.unlink(filename)
@@ -67,18 +68,19 @@ def test_load_data_with_file_uri(delimiter):
         # Loading CSV using CsvLoader
         loader = CsvLoader()
         result = loader.load_data(filename)
+        data = result["data"]
 
         # Assertions
-        assert len(result) == 3
-        assert result[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
-        assert result[0]["meta_data"]["url"] == filename
-        assert result[0]["meta_data"]["row"] == 1
-        assert result[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
-        assert result[1]["meta_data"]["url"] == filename
-        assert result[1]["meta_data"]["row"] == 2
-        assert result[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
-        assert result[2]["meta_data"]["url"] == filename
-        assert result[2]["meta_data"]["row"] == 3
+        assert len(data) == 3
+        assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
+        assert data[0]["meta_data"]["url"] == filename
+        assert data[0]["meta_data"]["row"] == 1
+        assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
+        assert data[1]["meta_data"]["url"] == filename
+        assert data[1]["meta_data"]["row"] == 2
+        assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
+        assert data[2]["meta_data"]["url"] == filename
+        assert data[2]["meta_data"]["row"] == 3
 
         # Cleaning up the temporary file
         os.unlink(tmpfile.name)