@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple

import requests
from dotenv import load_dotenv
+from langchain.docstore.document import Document
from tenacity import retry, stop_after_attempt, wait_fixed

from embedchain.chunkers.base_chunker import BaseChunker
@@ -179,7 +180,7 @@ class EmbedChain(JSONSerializable):
        data_formatter = DataFormatter(data_type, config)
        self.user_asks.append([source, data_type.value, metadata])
-        documents, metadatas, _ids, new_chunks = self.load_and_embed(
+        documents, metadatas, _ids, new_chunks = self.load_and_embed_v2(
            data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
        )
        if data_type in {DataType.DOCS_SITE}:
@@ -271,10 +272,11 @@ class EmbedChain(JSONSerializable):
        # get existing ids, and discard doc if any common id exist.
        where = {"app_id": self.config.id} if self.config.id is not None else {}
        # where={"url": src}
-        existing_ids = self.db.get(
+        db_result = self.db.get(
            ids=ids,
            where=where, # optional filter
        )
+        existing_ids = set(db_result["ids"])

        if len(existing_ids):
            data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
@@ -317,6 +319,117 @@ class EmbedChain(JSONSerializable):
        print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
        return list(documents), metadatas, ids, count_new_chunks

+    def load_and_embed_v2(
+        self,
+        loader: BaseLoader,
+        chunker: BaseChunker,
+        src: Any,
+        metadata: Optional[Dict[str, Any]] = None,
+        source_id: Optional[str] = None,
+        dry_run=False
+    ):
+        """
+        Loads the data from the given URL, chunks it, and adds it to the database.
+
+        :param loader: The loader to use to load the data.
+        :param chunker: The chunker to use to chunk the data.
+        :param src: The data to be handled by the loader. Can be a URL for
+        remote sources or local content for local loaders.
+        :param metadata: Optional. Metadata associated with the data source.
+        :param source_id: Hexadecimal hash of the source.
+        :return: (List) documents (embedded text), (List) metadata, (List) ids, (int) number of chunks
+        """
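+        # Look up any embedding already stored for this source URL; its metadata
+        # carries the doc_id of the document version that was previously ingested.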
+        existing_embeddings_data = self.db.get(
+            where={
+                "url": src,
+            },
+            limit=1,
+        )
+        try:
+            existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
+        except Exception:
+            existing_doc_id = None
+        embeddings_data = chunker.create_chunks(loader, src)
+
+        # spread chunking results
+        documents = embeddings_data["documents"]
+        metadatas = embeddings_data["metadatas"]
+        ids = embeddings_data["ids"]
+        new_doc_id = embeddings_data["doc_id"]
+
+        if existing_doc_id and existing_doc_id == new_doc_id:
+            print("Doc content has not changed. Skipping creating chunks and embeddings")
+            return [], [], [], 0
+
+        # this means that doc content has changed.
+        if existing_doc_id and existing_doc_id != new_doc_id:
+            print("Doc content has changed. Recomputing chunks and embeddings intelligently.")
+            self.db.delete({
+                "doc_id": existing_doc_id
+            })
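+            # Stale chunks for the old version are removed; fresh chunks are added below.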
+
+        # get existing ids, and discard doc if any common id exist.
+        where = {"app_id": self.config.id} if self.config.id is not None else {}
+        # where={"url": src}
+        db_result = self.db.get(
+            ids=ids,
+            where=where, # optional filter
+        )
+        existing_ids = set(db_result["ids"])
+
+        if len(existing_ids):
+            data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
+            data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
+
+            if not data_dict:
+                print(f"All data from {src} already exists in the database.")
+                # Make sure to return a matching return type
+                return [], [], [], 0
+
+            ids = list(data_dict.keys())
+            documents, metadatas = zip(*data_dict.values())
+
+        # Loop through all metadatas and add extras.
+        new_metadatas = []
+        for m in metadatas:
+            # Add app id in metadatas so that they can be queried on later
+            if self.config.id:
+                m["app_id"] = self.config.id
+
+            # Add hashed source
+            m["hash"] = source_id
+
+            # Note: Metadata is the function argument
+            if metadata:
+                # Spread whatever is in metadata into the new object.
+                m.update(metadata)
+
+            new_metadatas.append(m)
+        metadatas = new_metadatas
+
+        # Count before, to calculate a delta in the end.
+        chunks_before_addition = self.count()
+
+        self.db.add(documents=documents, metadatas=metadatas, ids=ids)
+        count_new_chunks = self.count() - chunks_before_addition
+        print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
+        return list(documents), metadatas, ids, count_new_chunks
+
+ def _format_result(self, results):
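+        # Pair each document with its metadata and distance, taken from the
+        # parallel "documents" / "metadatas" / "distances" lists in the DB response.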
+        return [
+            (Document(page_content=result[0], metadata=result[1] or {}), result[2])
+            for result in zip(
+                results["documents"][0],
+                results["metadatas"][0],
+                results["distances"][0],
+            )
+        ]
+
    def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
        """
        Queries the vector database based on the given input query.