|
@@ -133,7 +133,9 @@ class EmbedChain(JSONSerializable):
|
|
|
metadata: Optional[Dict[str, Any]] = None,
|
|
|
config: Optional[AddConfig] = None,
|
|
|
dry_run=False,
|
|
|
- **kwargs: Dict[str, Any],
|
|
|
+ loader: Optional[BaseLoader] = None,
|
|
|
+ chunker: Optional[BaseChunker] = None,
|
|
|
+ **kwargs: Optional[Dict[str, Any]],
|
|
|
):
|
|
|
"""
|
|
|
Adds the data from the given URL to the vector db.
|
|
@@ -192,9 +194,9 @@ class EmbedChain(JSONSerializable):
|
|
|
|
|
|
self.user_asks.append([source, data_type.value, metadata])
|
|
|
|
|
|
- data_formatter = DataFormatter(data_type, config, kwargs)
|
|
|
+ data_formatter = DataFormatter(data_type, config, loader, chunker)
|
|
|
documents, metadatas, _ids, new_chunks = self._load_and_embed(
|
|
|
- data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
|
|
|
+ data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run, **kwargs
|
|
|
)
|
|
|
if data_type in {DataType.DOCS_SITE}:
|
|
|
self.is_docs_site_instance = True
|
|
@@ -238,7 +240,7 @@ class EmbedChain(JSONSerializable):
|
|
|
data_type: Optional[DataType] = None,
|
|
|
metadata: Optional[Dict[str, Any]] = None,
|
|
|
config: Optional[AddConfig] = None,
|
|
|
- **kwargs: Dict[str, Any],
|
|
|
+ **kwargs: Optional[Dict[str, Any]],
|
|
|
):
|
|
|
"""
|
|
|
Adds the data from the given URL to the vector db.
|
|
@@ -269,7 +271,7 @@ class EmbedChain(JSONSerializable):
|
|
|
data_type=data_type,
|
|
|
metadata=metadata,
|
|
|
config=config,
|
|
|
- kwargs=kwargs,
|
|
|
+ **kwargs,
|
|
|
)
|
|
|
|
|
|
def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
|
|
@@ -338,6 +340,7 @@ class EmbedChain(JSONSerializable):
|
|
|
metadata: Optional[Dict[str, Any]] = None,
|
|
|
source_hash: Optional[str] = None,
|
|
|
dry_run=False,
|
|
|
+ **kwargs: Optional[Dict[str, Any]],
|
|
|
):
|
|
|
"""
|
|
|
Loads the data from the given URL, chunks it, and adds it to database.
|
|
@@ -431,6 +434,7 @@ class EmbedChain(JSONSerializable):
|
|
|
metadatas=metadatas,
|
|
|
ids=ids,
|
|
|
skip_embedding=(chunker.data_type == DataType.IMAGES),
|
|
|
+ **kwargs,
|
|
|
)
|
|
|
count_new_chunks = self.db.count() - chunks_before_addition
|
|
|
|
|
@@ -448,7 +452,12 @@ class EmbedChain(JSONSerializable):
|
|
|
]
|
|
|
|
|
|
def _retrieve_from_database(
|
|
|
- self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False
|
|
|
+ self,
|
|
|
+ input_query: str,
|
|
|
+ config: Optional[BaseLlmConfig] = None,
|
|
|
+ where=None,
|
|
|
+ citations: bool = False,
|
|
|
+ **kwargs: Optional[Dict[str, Any]],
|
|
|
) -> Union[List[Tuple[str, str, str]], List[str]]:
|
|
|
"""
|
|
|
Queries the vector database based on the given input query.
|
|
@@ -492,6 +501,7 @@ class EmbedChain(JSONSerializable):
|
|
|
where=where,
|
|
|
skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
|
|
|
citations=citations,
|
|
|
+ **kwargs,
|
|
|
)
|
|
|
|
|
|
return contexts
|
|
@@ -526,9 +536,13 @@ class EmbedChain(JSONSerializable):
|
|
|
or the dry run result
|
|
|
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
|
|
|
"""
|
|
|
- citations = kwargs.get("citations", False)
|
|
|
+ if "citations" in kwargs:
|
|
|
+ citations = kwargs.pop("citations")
|
|
|
+ else:
|
|
|
+ citations = False
|
|
|
+
|
|
|
contexts = self._retrieve_from_database(
|
|
|
- input_query=input_query, config=config, where=where, citations=citations
|
|
|
+ input_query=input_query, config=config, where=where, citations=citations, **kwargs
|
|
|
)
|
|
|
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
|
|
|
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
|
|
@@ -579,9 +593,13 @@ class EmbedChain(JSONSerializable):
|
|
|
or the dry run result
|
|
|
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
|
|
|
"""
|
|
|
- citations = kwargs.get("citations", False)
|
|
|
+ if "citations" in kwargs:
|
|
|
+ citations = kwargs.pop("citations")
|
|
|
+ else:
|
|
|
+ citations = False
|
|
|
+
|
|
|
contexts = self._retrieve_from_database(
|
|
|
- input_query=input_query, config=config, where=where, citations=citations
|
|
|
+ input_query=input_query, config=config, where=where, citations=citations, **kwargs
|
|
|
)
|
|
|
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
|
|
|
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
|