@@ -6,11 +6,10 @@ import os
 import threading
 import uuid
 from pathlib import Path
-from typing import Dict, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import requests
 from dotenv import load_dotenv
-from langchain.docstore.document import Document
 from tenacity import retry, stop_after_attempt, wait_fixed
 
 from embedchain.chunkers.base_chunker import BaseChunker
@@ -46,8 +45,17 @@ class EmbedChain(JSONSerializable):
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.
 
-        :param config: BaseAppConfig instance to load as configuration.
-        :param system_prompt: Optional. System prompt string.
+        :param config: Configuration just for the app, not the db or llm or embedder.
+        :type config: BaseAppConfig
+        :param llm: Instance of the LLM you want to use.
+        :type llm: BaseLlm
+        :param db: Instance of the Database to use, defaults to None
+        :type db: BaseVectorDB, optional
+        :param embedder: Instance of the embedder to use, defaults to None
+        :type embedder: BaseEmbedder, optional
+        :param system_prompt: System prompt to use in the llm query, defaults to None
+        :type system_prompt: Optional[str], optional
+        :raises ValueError: No database or embedder provided.
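+
+        Example (illustrative sketch; assumes the default `App` subclass)::
+
+            from embedchain import App
+            app = App()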
         """
 
         self.config = config
@@ -88,10 +96,13 @@ class EmbedChain(JSONSerializable):
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
         thread_telemetry.start()
 
-    def _load_or_generate_user_id(self):
+    def _load_or_generate_user_id(self) -> str:
         """
         Loads the user id from the config file if it exists, otherwise generates a new
         one and saves it to the config file.
+
+        :return: user id
+        :rtype: str
         """
         if not os.path.exists(CONFIG_DIR):
             os.makedirs(CONFIG_DIR)
@@ -110,9 +121,9 @@ class EmbedChain(JSONSerializable):
 
     def add(
         self,
-        source,
+        source: Any,
         data_type: Optional[DataType] = None,
-        metadata: Optional[Dict] = None,
+        metadata: Optional[Dict[str, Any]] = None,
         config: Optional[AddConfig] = None,
     ):
         """
@@ -121,12 +132,17 @@ class EmbedChain(JSONSerializable):
         and then stores the embedding to vector database.
 
         :param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
-        :param data_type: Optional. Automatically detected, but can be forced with this argument.
-            The type of the data to add.
-        :param metadata: Optional. Metadata associated with the data source.
-        :param config: Optional. The `AddConfig` instance to use as configuration
-            options.
+        :type source: Any
+        :param data_type: Automatically detected, but can be forced with this argument. The type of the data to add,
+            defaults to None
+        :type data_type: Optional[DataType], optional
+        :param metadata: Metadata associated with the data source, defaults to None
+        :type metadata: Optional[Dict[str, Any]], optional
+        :param config: The `AddConfig` instance to use as configuration options, defaults to None
+        :type config: Optional[AddConfig], optional
+        :raises ValueError: Invalid data type
         :return: source_id, a md5-hash of the source, in hexadecimal representation.
+        :rtype: str
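+
+        Example (illustrative sketch; the URL and metadata are placeholders)::
+
+            source_id = app.add("https://example.com/docs", metadata={"tag": "docs"})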
         """
         if config is None:
             config = AddConfig()
@@ -177,39 +193,62 @@ class EmbedChain(JSONSerializable):
 
         return source_id
 
-    def add_local(self, source, data_type=None, metadata=None, config: AddConfig = None):
+    def add_local(
+        self,
+        source: Any,
+        data_type: Optional[DataType] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        config: Optional[AddConfig] = None,
+    ):
         """
-        Warning:
-            This method is deprecated and will be removed in future versions. Use `add` instead.
-
         Adds the data from the given URL to the vector db.
         Loads the data, chunks it, create embedding for each chunk
         and then stores the embedding to vector database.
 
+        Warning:
+            This method is deprecated and will be removed in future versions. Use `add` instead.
+
         :param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
-        :param data_type: Optional. Automatically detected, but can be forced with this argument.
-            The type of the data to add.
-        :param metadata: Optional. Metadata associated with the data source.
-        :param config: Optional. The `AddConfig` instance to use as configuration
-            options.
-        :return: md5-hash of the source, in hexadecimal representation.
+        :type source: Any
+        :param data_type: Automatically detected, but can be forced with this argument. The type of the data to add,
+            defaults to None
+        :type data_type: Optional[DataType], optional
+        :param metadata: Metadata associated with the data source, defaults to None
+        :type metadata: Optional[Dict[str, Any]], optional
+        :param config: The `AddConfig` instance to use as configuration options, defaults to None
+        :type config: Optional[AddConfig], optional
+        :raises ValueError: Invalid data type
+        :return: source_id, an md5-hash of the source, in hexadecimal representation.
+        :rtype: str
         """
         logging.warning(
             "The `add_local` method is deprecated and will be removed in future versions. Please use the `add` method for both local and remote files."  # noqa: E501
         )
         return self.add(source=source, data_type=data_type, metadata=metadata, config=config)
 
-    def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None, source_id=None):
-        """
-        Loads the data from the given URL, chunks it, and adds it to database.
+    def load_and_embed(
+        self,
+        loader: BaseLoader,
+        chunker: BaseChunker,
+        src: Any,
+        metadata: Optional[Dict[str, Any]] = None,
+        source_id: Optional[str] = None,
+    ) -> Tuple[List[str], Dict[str, Any], List[str], int]:
+        """Loads the data from the given source, chunks it, and adds it to the database.
 
         :param loader: The loader to use to load the data.
+        :type loader: BaseLoader
         :param chunker: The chunker to use to chunk the data.
-        :param src: The data to be handled by the loader. Can be a URL for
-            remote sources or local content for local loaders.
-        :param metadata: Optional. Metadata associated with the data source.
-        :param source_id: Hexadecimal hash of the source.
+        :type chunker: BaseChunker
+        :param src: The data to be handled by the loader.
+            Can be a URL for remote sources or local content for local loaders.
+        :type src: Any
+        :param metadata: Metadata associated with the data source, defaults to None
+        :type metadata: Optional[Dict[str, Any]], optional
+        :param source_id: Hexadecimal hash of the source, defaults to None
+        :type source_id: Optional[str], optional
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
+        :rtype: Tuple[List[str], Dict[str, Any], List[str], int]
         """
         embeddings_data = chunker.create_chunks(loader, src)
@@ -264,25 +303,19 @@ class EmbedChain(JSONSerializable):
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
 
-    def _format_result(self, results):
-        return [
-            (Document(page_content=result[0], metadata=result[1] or {}), result[2])
-            for result in zip(
-                results["documents"][0],
-                results["metadatas"][0],
-                results["distances"][0],
-            )
-        ]
-
-    def retrieve_from_database(self, input_query, config: Optional[BaseLlmConfig] = None, where=None):
+    def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query
 
         :param input_query: The query to use.
-        :param config: The query configuration.
-        :param where: Optional. A dictionary of key-value pairs to filter the database results.
-        :return: The content of the document that matched your query.
+        :type input_query: str
+        :param config: The query configuration, defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        :param where: A dictionary of key-value pairs to filter the database results, defaults to None
+        :type where: Optional[Dict[str, str]], optional
+        :return: List of contents of the documents that matched your query
+        :rtype: List[str]
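+
+        Example (illustrative sketch; assumes chunks were stored with matching metadata)::
+
+            contexts = app.retrieve_from_database("What is Embedchain?", where={"tag": "docs"})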
         """
         query_config = config or self.llm.config
@@ -304,23 +337,24 @@ class EmbedChain(JSONSerializable):
 
         return contents
 
-    def query(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
+    def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.
 
         :param input_query: The query to use.
-        :param config: Optional. The `LlmConfig` instance to use as configuration options.
-            This is used for one method call. To persistently use a config, declare it during app init.
-        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
-            the LLM. The purpose is to test the prompt, not the response.
-            You can use it to test your prompt, including the context provided
-            by the vector database's doc retrieval.
-            The only thing the dry run does not consider is the cut-off due to
-            the `max_tokens` parameter.
-        :param where: Optional. A dictionary of key-value pairs to filter the database results.
-        :return: The answer to the query.
+        :type input_query: str
+        :param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
+            To persistently use a config, declare it during app init, defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        :param dry_run: A dry run does everything except send the resulting prompt to
+            the LLM. The purpose is to test the prompt, not the response, defaults to False
+        :type dry_run: bool, optional
+        :param where: A dictionary of key-value pairs to filter the database results, defaults to None
+        :type where: Optional[Dict[str, str]], optional
+        :return: The answer to the query or the dry run result
+        :rtype: str
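+
+        Example (illustrative sketch; assumes data was added beforehand)::
+
+            answer = app.query("What is Embedchain?")
+            prompt = app.query("What is Embedchain?", dry_run=True)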
         """
         contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
         answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
@@ -331,24 +365,32 @@ class EmbedChain(JSONSerializable):
 
         return answer
 
-    def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
+    def chat(
+        self,
+        input_query: str,
+        config: Optional[BaseLlmConfig] = None,
+        dry_run=False,
+        where: Optional[Dict[str, str]] = None,
+    ) -> str:
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.
 
         Maintains the whole conversation in memory.
+
         :param input_query: The query to use.
-        :param config: Optional. The `LlmConfig` instance to use as configuration options.
-            This is used for one method call. To persistently use a config, declare it during app init.
-        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
-            the LLM. The purpose is to test the prompt, not the response.
-            You can use it to test your prompt, including the context provided
-            by the vector database's doc retrieval.
-            The only thing the dry run does not consider is the cut-off due to
-            the `max_tokens` parameter.
-        :param where: Optional. A dictionary of key-value pairs to filter the database results.
-        :return: The answer to the query.
+        :type input_query: str
+        :param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
+            To persistently use a config, declare it during app init, defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        :param dry_run: A dry run does everything except send the resulting prompt to
+            the LLM. The purpose is to test the prompt, not the response, defaults to False
+        :type dry_run: bool, optional
+        :param where: A dictionary of key-value pairs to filter the database results, defaults to None
+        :type where: Optional[Dict[str, str]], optional
+        :return: The answer to the query or the dry run result
+        :rtype: str
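+
+        Example (illustrative sketch; the follow-up relies on the conversation memory)::
+
+            app.chat("Who is the founder?")
+            app.chat("When was the company founded?")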
         """
         contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
         answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
@@ -359,15 +401,18 @@ class EmbedChain(JSONSerializable):
 
         return answer
 
-    def set_collection(self, collection_name):
+    def set_collection_name(self, name: str):
         """
-        Set the collection to use.
+        Set the name of the collection. A collection is an isolated space for vectors.
+
+        Using the `app.db.set_collection_name` method is preferred to this.
 
-        :param collection_name: The name of the collection to use.
+        :param name: Name of the collection.
+        :type name: str
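+
+        Example (illustrative sketch; the collection name is a placeholder)::
+
+            app.set_collection_name("my-collection")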
         """
-        self.db.set_collection_name(collection_name)
+        self.db.set_collection_name(name)
         # Create the collection if it does not exist
-        self.db._get_or_create_collection(collection_name)
+        self.db._get_or_create_collection(name)
         # TODO: Check whether it is necessary to assign to the `self.collection` attribute,
         # since the main purpose is the creation.
@@ -378,8 +423,9 @@ class EmbedChain(JSONSerializable):
         DEPRECATED IN FAVOR OF `db.count()`
 
         :return: The number of embeddings.
+        :rtype: int
         """
-        logging.warning("DEPRECATION WARNING: Please use `db.count()` instead of `count()`.")
+        logging.warning("DEPRECATION WARNING: Please use `app.db.count()` instead of `app.count()`.")
        return self.db.count()
 
     def reset(self):
@@ -393,11 +439,14 @@ class EmbedChain(JSONSerializable):
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
         thread_telemetry.start()
 
-        logging.warning("DEPRECATION WARNING: Please use `db.reset()` instead of `reset()`.")
+        logging.warning("DEPRECATION WARNING: Please use `app.db.reset()` instead of `app.reset()`.")
         self.db.reset()
 
     @retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
     def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):
+        """
+        Send a telemetry event to the embedchain server. This is anonymous. It can be toggled off in `AppConfig`.
+        """
         if not self.config.collect_metrics:
             return