
docs: update docstrings (#565)

cachho 1 year ago
parent
commit
1ac8aef4de

+ 19 - 6
embedchain/apps/App.py

@@ -12,12 +12,13 @@ from embedchain.vectordb.chroma_db import ChromaDB
 @register_deserializable
 class App(EmbedChain):
     """
-    The EmbedChain app.
-    Has two functions: add and query.
+    The EmbedChain app in its simplest and most straightforward form.
+    An opinionated choice of LLM, vector database and embedding model.
 
-    adds(data_type, url): adds the data from the given URL to the vector db.
+    Methods:
+    add(source, data_type): adds the data from the given source to the vector db.
     query(query): finds answer to the given query using vector database and LLM.
-    dry_run(query): test your prompt without consuming tokens.
+    chat(query): finds answer to the given query using vector database and LLM, with conversation history.
     """
 
     def __init__(
@@ -28,8 +29,20 @@ class App(EmbedChain):
         system_prompt: Optional[str] = None,
     ):
         """
-        :param config: AppConfig instance to load as configuration. Optional.
-        :param system_prompt: System prompt string. Optional.
+        Initialize a new `App` instance. You only have a few choices to make.
+
+        :param config: Config for the app instance.
+        This is the most basic configuration, which does not fall into the LLM, database or embedder category,
+        defaults to None
+        :type config: AppConfig, optional
+        :param llm_config: Allows you to configure the LLM, e.g. how many documents to return,
+        example: `from embedchain.config import LlmConfig`, defaults to None
+        :type llm_config: BaseLlmConfig, optional
+        :param chromadb_config: Allows you to configure the vector database,
+        example: `from embedchain.config import ChromaDbConfig`, defaults to None
+        :type chromadb_config: Optional[ChromaDbConfig], optional
+        :param system_prompt: System prompt that will be provided to the LLM as such, defaults to None
+        :type system_prompt: Optional[str], optional
         """
         if config is None:
             config = AppConfig()
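
For orientation, a minimal usage sketch of the parameters documented in this hunk; it assumes `App` is exported from the package root, as `CustomApp` is in `bots/base.py` below:

    from embedchain import App
    from embedchain.config import AppConfig, LlmConfig

    # Opinionated app: LLM, vector database and embedding model are pre-chosen.
    app = App(config=AppConfig(log_level="INFO"), llm_config=LlmConfig(number_documents=3))
    app.add("https://example.com/article.html")      # data_type is auto-detected
    print(app.query("What is the article about?"))   # answer from vector DB context + LLM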

+ 24 - 8
embedchain/apps/CustomApp.py

@@ -11,26 +11,42 @@ from embedchain.vectordb.base_vector_db import BaseVectorDB
 @register_deserializable
 class CustomApp(EmbedChain):
     """
-    The custom EmbedChain app.
-    Has two functions: add and query.
+    Embedchain's custom app allows for most flexibility.
 
-    adds(data_type, url): adds the data from the given URL to the vector db.
+    You can craft your own mix of various LLMs, vector databases and embedding model/functions.
+
+    Methods:
+    add(source, data_type): adds the data from the given source to the vector db.
     query(query): finds answer to the given query using vector database and LLM.
-    dry_run(query): test your prompt without consuming tokens.
+    chat(query): finds answer to the given query using vector database and LLM, with conversation history.
     """
 
     def __init__(
         self,
-        config: CustomAppConfig = None,
+        config: Optional[CustomAppConfig] = None,
         llm: BaseLlm = None,
         db: BaseVectorDB = None,
         embedder: BaseEmbedder = None,
         system_prompt: Optional[str] = None,
     ):
         """
-        :param config: Optional. `CustomAppConfig` instance to load as configuration.
-        :raises ValueError: Config must be provided for custom app
-        :param system_prompt: Optional. System prompt string.
+        Initialize a new `CustomApp` instance. You have to choose an LLM, database and embedder.
+
+        :param config: Config for the app instance. This is the most basic configuration,
+        which does not fall into the LLM, database or embedder category, defaults to None
+        :type config: Optional[CustomAppConfig], optional
+        :param llm: LLM Class instance. example: `from embedchain.llm.openai_llm import OpenAiLlm`, defaults to None
+        :type llm: BaseLlm
+        :param db: The database to use for storing and retrieving embeddings,
+        example: `from embedchain.vectordb.chroma_db import ChromaDb`, defaults to None
+        :type db: BaseVectorDB
+        :param embedder: The embedder (embedding model and function) used to calculate embeddings,
+        example: `from embedchain.embedder.gpt4all_embedder import GPT4AllEmbedder`, defaults to None
+        :type embedder: BaseEmbedder
+        :param system_prompt: System prompt that will be provided to the LLM as such, defaults to None
+        :type system_prompt: Optional[str], optional
+        :raises ValueError: LLM, database or embedder has not been defined.
+        :raises TypeError: LLM, database or embedder is not a valid class instance.
         """
         # Config is not required, it has a default
         if config is None:
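
The `bots/base.py` hunk below wires up exactly this kind of mix; as a standalone sketch (import paths taken from the docstrings and imports shown in this diff):

    from embedchain import CustomApp
    from embedchain.config import CustomAppConfig
    from embedchain.embedder.openai_embedder import OpenAiEmbedder
    from embedchain.llm.openai_llm import OpenAiLlm
    from embedchain.vectordb.chroma_db import ChromaDB

    # Craft your own mix of LLM, vector database and embedder.
    app = CustomApp(config=CustomAppConfig(), llm=OpenAiLlm(), db=ChromaDB(), embedder=OpenAiEmbedder())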

+ 3 - 2
embedchain/apps/Llama2App.py

@@ -12,10 +12,11 @@ from embedchain.vectordb.chroma_db import ChromaDB
 class Llama2App(CustomApp):
     """
     The EmbedChain Llama2App class.
-    Has two functions: add and query.
 
-    adds(data_type, url): adds the data from the given URL to the vector db.
+    Methods:
+    add(source, data_type): adds the data from the given source to the vector db.
     query(query): finds answer to the given query using vector database and LLM.
+    chat(query): finds answer to the given query using vector database and LLM, with conversation history.
     """
 
     def __init__(self, config: CustomAppConfig = None, system_prompt: Optional[str] = None):

+ 35 - 14
embedchain/apps/OpenSourceApp.py

@@ -15,43 +15,64 @@ gpt4all_model = None
 @register_deserializable
 class OpenSourceApp(EmbedChain):
     """
-    The OpenSource app.
-    Same as App, but uses an open source embedding model and LLM.
+    The embedchain Open Source App.
+    Comes preconfigured with the best open source LLM, embedding model and database.
 
-    Has two function: add and query.
-
-    adds(data_type, url): adds the data from the given URL to the vector db.
+    Methods:
+    add(source, data_type): adds the data from the given source to the vector db.
     query(query): finds answer to the given query using vector database and LLM.
+    chat(query): finds answer to the given query using vector database and LLM, with conversation history.
     """
 
     def __init__(
         self,
         config: OpenSourceAppConfig = None,
+        llm_config: BaseLlmConfig = None,
         chromadb_config: Optional[ChromaDbConfig] = None,
         system_prompt: Optional[str] = None,
     ):
         """
-        :param config: OpenSourceAppConfig instance to load as configuration. Optional.
-        `ef` defaults to open source.
-        :param system_prompt: System prompt string. Optional.
+        Initialize a new `OpenSourceApp` instance.
+        Since it's opinionated, you don't have to choose an LLM, database and embedder.
+        However, you can configure each of them.
+
+        :param config: Config for the app instance. This is the most basic configuration,
+        which does not fall into the LLM, database or embedder category, defaults to None
+        :type config: OpenSourceAppConfig, optional
+        :param llm_config: Allows you to configure the LLM, e.g. how many documents to return.
+        example: `from embedchain.config import LlmConfig`, defaults to None
+        :type llm_config: BaseLlmConfig, optional
+        :param chromadb_config: Allows you to configure the open source database,
+        example: `from embedchain.config import ChromaDbConfig`, defaults to None
+        :type chromadb_config: Optional[ChromaDbConfig], optional
+        :param system_prompt: System prompt that will be provided to the LLM as such.
+        Please don't use it for the time being, as it's not supported; defaults to None
+        :type system_prompt: Optional[str], optional
+        :raises TypeError: `OpenSourceAppConfig` or `LlmConfig` invalid.
         """
         logging.info("Loading open source embedding model. This may take some time...")  # noqa:E501
         if not config:
             config = OpenSourceAppConfig()
 
         if not isinstance(config, OpenSourceAppConfig):
-            raise ValueError(
+            raise TypeError(
                 "OpenSourceApp needs a OpenSourceAppConfig passed to it. "
                 "You can import it with `from embedchain.config import OpenSourceAppConfig`"
             )
 
-        if not config.model:
-            raise ValueError("OpenSourceApp needs a model to be instantiated. Maybe you passed the wrong config type?")
-
-        logging.info("Successfully loaded open source embedding model.")
+        if not llm_config:
+            llm_config = BaseLlmConfig(model="orca-mini-3b.ggmlv3.q4_0.bin")
+        elif not isinstance(llm_config, BaseLlmConfig):
+            raise TypeError(
+                "The LlmConfig passed to OpenSourceApp is invalid. "
+                "You can import it with `from embedchain.config import LlmConfig`"
+            )
+        elif not llm_config.model:
+            llm_config.model = "orca-mini-3b.ggmlv3.q4_0.bin"
 
-        llm = GPT4ALLLlm(config=BaseLlmConfig(model="orca-mini-3b.ggmlv3.q4_0.bin"))
+        llm = GPT4ALLLlm(config=llm_config)
         embedder = GPT4AllEmbedder(config=BaseEmbedderConfig(model="all-MiniLM-L6-v2"))
+        logging.info("Successfully loaded open source embedding model.")
         database = ChromaDB(config=chromadb_config)
 
         super().__init__(config, llm=llm, db=database, embedder=embedder, system_prompt=system_prompt)
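
A sketch of the new `llm_config` override, using the default model name from this hunk (and assuming `OpenSourceApp` is exported from the package root like the other apps):

    from embedchain import OpenSourceApp
    from embedchain.config import LlmConfig

    # Override LLM settings while keeping the opinionated open source defaults.
    app = OpenSourceApp(llm_config=LlmConfig(model="orca-mini-3b.ggmlv3.q4_0.bin", max_tokens=500))
    app.add("https://example.com/article.html")
    print(app.query("Summarize the article."))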

+ 14 - 4
embedchain/apps/PersonApp.py

@@ -19,7 +19,14 @@ class EmbedChainPersonApp:
     :param config: BaseAppConfig instance to load as configuration.
     """
 
-    def __init__(self, person, config: BaseAppConfig = None):
+    def __init__(self, person: str, config: BaseAppConfig = None):
+        """Initialize a new person app
+
+        :param person: Name of the person to imitate.
+        :type person: str
+        :param config: Configuration class instance, defaults to None
+        :type config: BaseAppConfig, optional
+        """
         self.person = person
         self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style."  # noqa:E501
         super().__init__(config)
@@ -30,9 +37,12 @@ class EmbedChainPersonApp:
     if yes, it adds the person prompt to it and returns the updated config,
     else it creates a config object with the default prompt added to the person prompt
 
-        :param default_prompt: it is the default prompt for query or chat methods
-        :param config: Optional. The `ChatConfig` instance to use as
-        configuration options.
+        :param default_prompt: The default prompt used by the query or chat methods
+        :type default_prompt: str
+        :param config: The LLM configuration to which the person prompt is added, defaults to None
+        :type config: BaseLlmConfig, optional
+        :return: The config instance to use as configuration options, with the person prompt applied.
+        :rtype: BaseLlmConfig
         """
         template = Template(self.person_prompt + " " + default_prompt)
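
The composition above is plain `string.Template` concatenation; a self-contained illustration (the person and prompts are hypothetical):

    from string import Template

    person = "Yoda"
    person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style."
    default_prompt = "Answer the $query using the given $context."
    template = Template(person_prompt + " " + default_prompt)
    print(template.substitute(query="question", context="docs"))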
 

+ 23 - 4
embedchain/bots/base.py

@@ -1,3 +1,5 @@
+from typing import Any
+
 from embedchain import CustomApp
 from embedchain.config import AddConfig, CustomAppConfig, LlmConfig
 from embedchain.embedder.openai_embedder import OpenAiEmbedder
@@ -12,13 +14,30 @@ class BaseBot(JSONSerializable):
     def __init__(self):
         self.app = CustomApp(config=CustomAppConfig(), llm=OpenAiLlm(), db=ChromaDB(), embedder=OpenAiEmbedder())
 
-    def add(self, data, config: AddConfig = None):
-        """Add data to the bot"""
+    def add(self, data: Any, config: AddConfig = None):
+        """
+        Add data to the bot (to the vector database).
+        Auto-detects the data type, so some data types might not be usable.
+
+        :param data: data to embed
+        :type data: Any
+        :param config: configuration class instance, defaults to None
+        :type config: AddConfig, optional
+        """
         config = config if config else AddConfig()
         self.app.add(data, config=config)
 
-    def query(self, query, config: LlmConfig = None):
-        """Query bot"""
+    def query(self, query: str, config: LlmConfig = None) -> str:
+        """
+        Query the bot
+
+        :param query: the user query
+        :type query: str
+        :param config: configuration class instance, defaults to None
+        :type config: LlmConfig, optional
+        :return: Answer
+        :rtype: str
+        """
         config = config
         return self.app.query(query, config=config)
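
Usage sketch for the bot API documented in this hunk (the underlying `CustomApp` uses OpenAI components, so credentials are required):

    from embedchain.bots.base import BaseBot

    bot = BaseBot()
    bot.add("https://example.com/faq.html")                   # embeds the page into the vector database
    print(bot.query("What does the FAQ say about refunds?"))  # returns the answer as a string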
 

+ 8 - 0
embedchain/config/AddConfig.py

@@ -42,5 +42,13 @@ class AddConfig(BaseConfig):
         chunker: Optional[ChunkerConfig] = None,
         loader: Optional[LoaderConfig] = None,
     ):
+        """
+        Initializes a configuration class instance for the `add` method.
+
+        :param chunker: Chunker config, defaults to None
+        :type chunker: Optional[ChunkerConfig], optional
+        :param loader: Loader config, defaults to None
+        :type loader: Optional[LoaderConfig], optional
+        """
         self.loader = loader
         self.chunker = chunker

+ 9 - 1
embedchain/config/BaseConfig.py

@@ -1,3 +1,5 @@
+from typing import Any, Dict
+
 from embedchain.helper_classes.json_serializable import JSONSerializable
 
 
@@ -7,7 +9,13 @@ class BaseConfig(JSONSerializable):
     """
 
     def __init__(self):
+        """Initializes a configuration class for a class."""
         pass
 
-    def as_dict(self):
+    def as_dict(self) -> Dict[str, Any]:
+        """Return config object as a dict
+
+        :return: config object as dict
+        :rtype: Dict[str, Any]
+        """
         return vars(self)
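
`as_dict` simply exposes the instance attributes via `vars`, so inspecting any config is trivial:

    from embedchain.config import AppConfig

    config = AppConfig()
    print(config.as_dict())   # e.g. {'id': None, 'collect_metrics': True, ...}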

+ 14 - 6
embedchain/config/apps/AppConfig.py

@@ -13,15 +13,23 @@ class AppConfig(BaseAppConfig):
 
     def __init__(
         self,
-        log_level=None,
-        id=None,
+        log_level: str = "WARNING",
+        id: Optional[str] = None,
         collect_metrics: Optional[bool] = None,
         collection_name: Optional[str] = None,
     ):
         """
-        :param log_level: Optional. (String) Debug level
-        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
-        :param id: Optional. ID of the app. Document metadata will have this id.
-        :param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
+        Initializes a configuration class instance for an App. This is the simplest form of an embedchain app.
+        Most of the configuration is done in the `App` class itself.
+
+        :param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
+        :type log_level: str, optional
+        :param id: ID of the app. Document metadata will have this id, defaults to None
+        :type id: Optional[str], optional
+        :param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
+        :type collect_metrics: Optional[bool], optional
+        :param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
+        defaults to None
+        :type collection_name: Optional[str], optional
         """
         super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, collection_name=collection_name)

+ 17 - 12
embedchain/config/apps/BaseAppConfig.py

@@ -13,23 +13,28 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
 
     def __init__(
         self,
-        log_level=None,
+        log_level: str = "WARNING",
         db: Optional[BaseVectorDB] = None,
-        id=None,
+        id: Optional[str] = None,
         collect_metrics: bool = True,
         collection_name: Optional[str] = None,
     ):
         """
-        :param log_level: Optional. (String) Debug level
-        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
-        :param db: Optional. (Vector) database instance to use for embeddings. Deprecated in favor of app(..., db).
-        :param id: Optional. ID of the app. Document metadata will have this id.
-        :param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
-        :param db_type: Optional. Initializes a default vector database of the given type.
-        Using the `db` argument is preferred.
-        :param es_config: Optional. elasticsearch database config to be used for connection
-        :param collection_name: Optional. Default collection name.
-        It's recommended to use app.set_collection_name() instead.
+        Initializes a configuration class instance for an App.
+        Most of the configuration is done in the `App` class itself.
+
+        :param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
+        :type log_level: str, optional
+        :param db: A database instance. It is recommended to set this directly in the `App` class, not this config,
+        defaults to None
+        :type db: Optional[BaseVectorDB], optional
+        :param id: ID of the app. Document metadata will have this id, defaults to None
+        :type id: Optional[str], optional
+        :param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
+        :type collect_metrics: bool, optional
+        :param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
+        defaults to None
+        :type collection_name: Optional[str], optional
         """
         self._setup_logging(log_level)
         self.id = id

+ 19 - 13
embedchain/config/apps/CustomAppConfig.py

@@ -3,6 +3,7 @@ from typing import Optional
 from dotenv import load_dotenv
 
 from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.vectordb.base_vector_db import BaseVectorDB
 
 from .BaseAppConfig import BaseAppConfig
 
@@ -17,24 +18,29 @@ class CustomAppConfig(BaseAppConfig):
 
     def __init__(
         self,
-        log_level=None,
-        db=None,
-        id=None,
+        log_level: str = "WARNING",
+        db: Optional[BaseVectorDB] = None,
+        id: Optional[str] = None,
         collect_metrics: Optional[bool] = None,
         collection_name: Optional[str] = None,
     ):
         """
-        :param log_level: Optional. (String) Debug level
-        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
-        :param db: Optional. (Vector) database to use for embeddings.
-        :param id: Optional. ID of the app. Document metadata will have this id.
-        :param provider: Optional. (Providers): LLM Provider to use.
-        :param open_source_app_config: Optional. Config instance needed for open source apps.
-        :param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
-        :param collection_name: Optional. Default collection name.
-        It's recommended to use app.set_collection_name() instead.
+        Initializes a configuration class instance for a Custom App.
+        Most of the configuration is done in the `CustomApp` class itself.
+
+        :param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
+        :type log_level: str, optional
+        :param db: A database instance. It is recommended to set this directly in the `CustomApp` class, not this config,
+        defaults to None
+        :type db: Optional[BaseVectorDB], optional
+        :param id: ID of the app. Document metadata will have this id, defaults to None
+        :type id: Optional[str], optional
+        :param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
+        :type collect_metrics: Optional[bool], optional
+        :param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
+        defaults to None
+        :type collection_name: Optional[str], optional
         """
-
         super().__init__(
             log_level=log_level, db=db, id=id, collect_metrics=collect_metrics, collection_name=collection_name
         )

+ 17 - 11
embedchain/config/apps/OpenSourceAppConfig.py

@@ -13,21 +13,27 @@ class OpenSourceAppConfig(BaseAppConfig):
 
     def __init__(
         self,
-        log_level=None,
-        id=None,
+        log_level: str = "WARNING",
+        id: Optional[str] = None,
         collect_metrics: Optional[bool] = None,
-        model=None,
+        model: str = "orca-mini-3b.ggmlv3.q4_0.bin",
         collection_name: Optional[str] = None,
     ):
         """
-        :param log_level: Optional. (String) Debug level
-        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
-        :param id: Optional. ID of the app. Document metadata will have this id.
-        :param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
-        :param model: Optional. GPT4ALL uses the model to instantiate the class.
-        So unlike `App`, it has to be provided before querying.
-        :param collection_name: Optional. Default collection name.
-        It's recommended to use app.db.set_collection_name() instead.
+        Initializes a configuration class instance for an Open Source App.
+
+        :param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
+        :type log_level: str, optional
+        :param id: ID of the app. Document metadata will have this id, defaults to None
+        :type id: Optional[str], optional
+        :param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
+        :type collect_metrics: Optional[bool], optional
+        :param model: GPT4ALL uses the model to instantiate the class.
+        Unlike `App`, it has to be provided before querying, defaults to "orca-mini-3b.ggmlv3.q4_0.bin"
+        :type model: str, optional
+        :param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
+        defaults to None
+        :type collection_name: Optional[str], optional
         """
         self.model = model or "orca-mini-3b.ggmlv3.q4_0.bin"
 

+ 8 - 0
embedchain/config/embedder/BaseEmbedderConfig.py

@@ -6,5 +6,13 @@ from embedchain.helper_classes.json_serializable import register_deserializable
 @register_deserializable
 class BaseEmbedderConfig:
     def __init__(self, model: Optional[str] = None, deployment_name: Optional[str] = None):
+        """
+        Initialize a new instance of an embedder config class.
+
+        :param model: Name of the embedding model (not applicable to all providers), defaults to None
+        :type model: Optional[str], optional
+        :param deployment_name: Deployment name for the embedding model, defaults to None
+        :type deployment_name: Optional[str], optional
+        """
         self.model = model
         self.deployment_name = deployment_name

+ 53 - 41
embedchain/config/llm/base_llm_config.py

@@ -1,6 +1,6 @@
 import re
 from string import Template
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from embedchain.config.BaseConfig import BaseConfig
 from embedchain.helper_classes.json_serializable import register_deserializable
@@ -57,51 +57,59 @@ class BaseLlmConfig(BaseConfig):
 
     def __init__(
         self,
-        number_documents=None,
-        template: Template = None,
-        model=None,
-        temperature=None,
-        max_tokens=None,
-        top_p=None,
+        number_documents: int = 1,
+        template: Optional[Template] = None,
+        model: Optional[str] = None,
+        temperature: float = 0,
+        max_tokens: int = 1000,
+        top_p: float = 1,
         stream: bool = False,
-        deployment_name=None,
+        deployment_name: Optional[str] = None,
         system_prompt: Optional[str] = None,
-        where=None,
+        where: Optional[Dict[str, Any]] = None,
     ):
         """
-        Initializes the QueryConfig instance.
-
-        :param number_documents: Number of documents to pull from the database as
-        context.
-        :param template: Optional. The `Template` instance to use as a template for
-        prompt.
-        :param model: Optional. Controls the OpenAI model used.
-        :param temperature: Optional. Controls the randomness of the model's output.
-        Higher values (closer to 1) make output more random, lower values make it more
-        deterministic.
-        :param max_tokens: Optional. Controls how many tokens are generated.
-        :param top_p: Optional. Controls the diversity of words. Higher values
-        (closer to 1) make word selection more diverse, lower values make words less
-        diverse.
-        :param stream: Optional. Control if response is streamed back to user
-        :param deployment_name: t.b.a.
-        :param system_prompt: Optional. System prompt string.
-        :param where: Optional. A dictionary of key-value pairs to filter the database results.
+        Initializes a configuration class instance for the LLM.
+
+        Takes the place of the former `QueryConfig` or `ChatConfig`.
+        Use `LlmConfig` as an alias to `BaseLlmConfig`.
+
+        :param number_documents: Number of documents to pull from the database as
+        context, defaults to 1
+        :type number_documents: int, optional
+        :param template: The `Template` instance to use as a template for
+        prompt, defaults to None
+        :type template: Optional[Template], optional
+        :param model: Controls the OpenAI model used, defaults to None
+        :type model: Optional[str], optional
+        :param temperature: Controls the randomness of the model's output.
+        Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
+        :type temperature: float, optional
+        :param max_tokens: Controls how many tokens are generated, defaults to 1000
+        :type max_tokens: int, optional
+        :param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
+        defaults to 1
+        :type top_p: float, optional
+        :param stream: Control if response is streamed back to user, defaults to False
+        :type stream: bool, optional
+        :param deployment_name: t.b.a., defaults to None
+        :type deployment_name: Optional[str], optional
+        :param system_prompt: System prompt string, defaults to None
+        :type system_prompt: Optional[str], optional
+        :param where: A dictionary of key-value pairs to filter the database results, defaults to None
+        :type where: Optional[Dict[str, Any]], optional
         :raises ValueError: If the template is not valid as template should
-        contain $context and $query (and optionally $history).
+        contain $context and $query (and optionally $history)
+        :raises ValueError: Stream is not boolean
         """
-        if number_documents is None:
-            self.number_documents = 1
-        else:
-            self.number_documents = number_documents
-
         if template is None:
             template = DEFAULT_PROMPT_TEMPLATE
 
-        self.temperature = temperature if temperature else 0
-        self.max_tokens = max_tokens if max_tokens else 1000
+        self.number_documents = number_documents
+        self.temperature = temperature
+        self.max_tokens = max_tokens
         self.model = model
-        self.top_p = top_p if top_p else 1
+        self.top_p = top_p
         self.deployment_name = deployment_name
         self.system_prompt = system_prompt
 
@@ -115,20 +123,24 @@ class BaseLlmConfig(BaseConfig):
         self.stream = stream
         self.where = where
 
-    def validate_template(self, template: Template):
+    def validate_template(self, template: Template) -> bool:
         """
         validate the template
 
         :param template: the template to validate
-        :return: Boolean, valid (true) or invalid (false)
+        :type template: Template
+        :return: valid (true) or invalid (false)
+        :rtype: bool
         """
         return re.search(query_re, template.template) and re.search(context_re, template.template)
 
-    def _validate_template_history(self, template: Template):
+    def _validate_template_history(self, template: Template) -> bool:
         """
-        validate the history template for history
+        validate that the template includes conversation history
 
         :param template: the template to validate
-        :return: Boolean, valid (true) or invalid (false)
+        :type template: Template
+        :return: valid (true) or invalid (false)
+        :rtype: bool
         """
         return re.search(history_re, template.template)
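
A custom template must contain `$context` and `$query` (plus `$history` for chat); a sketch using the `LlmConfig` alias mentioned above:

    from string import Template
    from embedchain.config import LlmConfig   # alias for BaseLlmConfig

    template = Template("Use this context: $context\nQuestion: $query\nAnswer:")
    llm_config = LlmConfig(number_documents=3, temperature=0.2, max_tokens=500, template=template)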

+ 14 - 2
embedchain/config/vectordbs/BaseVectorDbConfig.py

@@ -7,11 +7,23 @@ class BaseVectorDbConfig(BaseConfig):
     def __init__(
         self,
         collection_name: Optional[str] = None,
-        dir: Optional[str] = None,
+        dir: str = "db",
         host: Optional[str] = None,
         port: Optional[str] = None,
     ):
+        """
+        Initializes a configuration class instance for the vector database.
+
+        :param collection_name: Default name for the collection, defaults to None
+        :type collection_name: Optional[str], optional
+        :param dir: Path to the database directory, where the database is stored, defaults to "db"
+        :type dir: str, optional
+        :param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None
+        :type host: Optional[str], optional
+        :param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
+        :type port: Optional[str], optional
+        """
         self.collection_name = collection_name or "embedchain_store"
-        self.dir = dir or "db"
+        self.dir = dir
         self.host = host
         self.port = port

+ 14 - 0
embedchain/config/vectordbs/ChromaDbConfig.py

@@ -14,6 +14,20 @@ class ChromaDbConfig(BaseVectorDbConfig):
         port: Optional[str] = None,
         chroma_settings: Optional[dict] = None,
     ):
+        """
+        Initializes a configuration class instance for ChromaDB.
+
+        :param collection_name: Default name for the collection, defaults to None
+        :type collection_name: Optional[str], optional
+        :param dir: Path to the database directory, where the database is stored, defaults to None
+        :type dir: Optional[str], optional
+        :param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None
+        :type host: Optional[str], optional
+        :param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
+        :type port: Optional[str], optional
+        :param chroma_settings: Chroma settings dict, defaults to None
+        :type chroma_settings: Optional[dict], optional
+        """
         """
         :param chroma_settings: Optional. Chroma settings for connection.
         """

+ 9 - 2
embedchain/config/vectordbs/ElasticsearchDBConfig.py

@@ -14,9 +14,16 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
         **ES_EXTRA_PARAMS: Dict[str, any],
     ):
         """
-        Config to initialize an elasticsearch client.
-        :param es_url. elasticsearch url or list of nodes url to be used for connection
+        Initializes a configuration class instance for an Elasticsearch client.
+
+        :param collection_name: Default name for the collection, defaults to None
+        :type collection_name: Optional[str], optional
+        :param dir: Path to the database directory, where the database is stored, defaults to None
+        :type dir: Optional[str], optional
+        :param es_url: Elasticsearch URL or list of node URLs to be used for the connection, defaults to None
+        :type es_url: Union[str, List[str]], optional
         :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
+        :type ES_EXTRA_PARAMS: Dict[str, Any], optional
         """
         # self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
         self.ES_URL = es_url
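
A sketch, assuming `ElasticsearchDBConfig` is exported from `embedchain.config` like the other configs; extra keyword arguments are forwarded to the Elasticsearch client:

    from embedchain.config import ElasticsearchDBConfig

    # verify_certs is passed through via ES_EXTRA_PARAMS to the elasticsearch client.
    es_config = ElasticsearchDBConfig(es_url="https://localhost:9200", verify_certs=False)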

+ 30 - 12
embedchain/data_formatter/data_formatter.py

@@ -1,3 +1,4 @@
+from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.chunkers.docs_site import DocsSiteChunker
 from embedchain.chunkers.docx_file import DocxFileChunker
 from embedchain.chunkers.notion import NotionChunker
@@ -8,7 +9,9 @@ from embedchain.chunkers.text import TextChunker
 from embedchain.chunkers.web_page import WebPageChunker
 from embedchain.chunkers.youtube_video import YoutubeVideoChunker
 from embedchain.config import AddConfig
+from embedchain.config.AddConfig import ChunkerConfig, LoaderConfig
 from embedchain.helper_classes.json_serializable import JSONSerializable
+from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.csv import CsvLoader
 from embedchain.loaders.docs_site_loader import DocsSiteLoader
 from embedchain.loaders.docx_file import DocxFileLoader
@@ -29,16 +32,28 @@ class DataFormatter(JSONSerializable):
     """
 
     def __init__(self, data_type: DataType, config: AddConfig):
-        self.loader = self._get_loader(data_type, config.loader)
-        self.chunker = self._get_chunker(data_type, config.chunker)
+        """
+        Initialize a DataFormatter; set the loader and chunker based on the data type.
+
+        :param data_type: The type of the data to load and chunk.
+        :type data_type: DataType
+        :param config: AddConfig instance with nested loader and chunker config attributes.
+        :type config: AddConfig
+        """
+        self.loader = self._get_loader(data_type=data_type, config=config.loader)
+        self.chunker = self._get_chunker(data_type=data_type, config=config.chunker)
 
-    def _get_loader(self, data_type: DataType, config):
+    def _get_loader(self, data_type: DataType, config: LoaderConfig) -> BaseLoader:
         """
         Returns the appropriate data loader for the given data type.
 
         :param data_type: The type of the data to load.
-        :return: The loader for the given data type.
+        :type data_type: DataType
+        :param config: Config to initialize the loader with.
+        :type config: LoaderConfig
         :raises ValueError: If an unsupported data type is provided.
+        :return: The loader for the given data type.
+        :rtype: BaseLoader
         """
         loaders = {
             DataType.YOUTUBE_VIDEO: YoutubeVideoLoader,
@@ -53,8 +68,8 @@ class DataFormatter(JSONSerializable):
         }
         lazy_loaders = {DataType.NOTION}
         if data_type in loaders:
-            loader_class = loaders[data_type]
-            loader = loader_class()
+            loader_class: type = loaders[data_type]
+            loader: BaseLoader = loader_class()
             return loader
         elif data_type in lazy_loaders:
             if data_type == DataType.NOTION:
@@ -66,13 +81,16 @@ class DataFormatter(JSONSerializable):
         else:
             raise ValueError(f"Unsupported data type: {data_type}")
 
-    def _get_chunker(self, data_type: DataType, config):
-        """
-        Returns the appropriate chunker for the given data type.
+    def _get_chunker(self, data_type: DataType, config: ChunkerConfig) -> BaseChunker:
+        """Returns the appropriate chunker for the given data type.
 
         :param data_type: The type of the data to chunk.
-        :return: The chunker for the given data type.
+        :type data_type: DataType
+        :param config: Config to initialize the chunker with.
+        :type config: ChunkerConfig
         :raises ValueError: If an unsupported data type is provided.
+        :return: The chunker for the given data type.
+        :rtype: BaseChunker
         """
         chunker_classes = {
             DataType.YOUTUBE_VIDEO: YoutubeVideoChunker,
@@ -87,8 +105,8 @@ class DataFormatter(JSONSerializable):
             DataType.CSV: TableChunker,
         }
         if data_type in chunker_classes:
-            chunker_class = chunker_classes[data_type]
-            chunker = chunker_class(config)
+            chunker_class: type = chunker_classes[data_type]
+            chunker: BaseChunker = chunker_class(config)
             chunker.set_data_type(data_type)
             return chunker
         else:

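Usage sketch for the formatter (the `DataType` import path is assumed, as it is not shown in this diff):

    from embedchain.config import AddConfig
    from embedchain.data_formatter.data_formatter import DataFormatter
    from embedchain.models.data_type import DataType   # assumed enum location

    formatter = DataFormatter(data_type=DataType.WEB_PAGE, config=AddConfig())
    print(type(formatter.loader), type(formatter.chunker))   # WebPageLoader, WebPageChunker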
+ 121 - 72
embedchain/embedchain.py

@@ -6,11 +6,10 @@ import os
 import threading
 import uuid
 from pathlib import Path
-from typing import Dict, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import requests
 from dotenv import load_dotenv
-from langchain.docstore.document import Document
 from tenacity import retry, stop_after_attempt, wait_fixed
 
 from embedchain.chunkers.base_chunker import BaseChunker
@@ -46,8 +45,17 @@ class EmbedChain(JSONSerializable):
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.
 
-        :param config: BaseAppConfig instance to load as configuration.
-        :param system_prompt: Optional. System prompt string.
+        :param config: Configuration just for the app, not the db or llm or embedder.
+        :type config: BaseAppConfig
+        :param llm: Instance of the LLM you want to use.
+        :type llm: BaseLlm
+        :param db: Instance of the Database to use, defaults to None
+        :type db: BaseVectorDB, optional
+        :param embedder: Instance of the embedder to use, defaults to None
+        :type embedder: BaseEmbedder, optional
+        :param system_prompt: System prompt to use in the llm query, defaults to None
+        :type system_prompt: Optional[str], optional
+        :raises ValueError: No database or embedder provided.
         """
 
         self.config = config
@@ -88,10 +96,13 @@ class EmbedChain(JSONSerializable):
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
         thread_telemetry.start()
 
-    def _load_or_generate_user_id(self):
+    def _load_or_generate_user_id(self) -> str:
         """
         Loads the user id from the config file if it exists, otherwise generates a new
         one and saves it to the config file.
+
+        :return: user id
+        :rtype: str
         """
         if not os.path.exists(CONFIG_DIR):
             os.makedirs(CONFIG_DIR)
@@ -110,9 +121,9 @@ class EmbedChain(JSONSerializable):
 
     def add(
         self,
-        source,
+        source: Any,
         data_type: Optional[DataType] = None,
-        metadata: Optional[Dict] = None,
+        metadata: Optional[Dict[str, Any]] = None,
         config: Optional[AddConfig] = None,
     ):
         """
@@ -121,12 +132,17 @@ class EmbedChain(JSONSerializable):
         and then stores the embedding to vector database.
 
         :param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
-        :param data_type: Optional. Automatically detected, but can be forced with this argument.
-        The type of the data to add.
-        :param metadata: Optional. Metadata associated with the data source.
-        :param config: Optional. The `AddConfig` instance to use as configuration
-        options.
+        :type source: Any
+        :param data_type: Automatically detected, but can be forced with this argument. The type of the data to add,
+        defaults to None
+        :type data_type: Optional[DataType], optional
+        :param metadata: Metadata associated with the data source, defaults to None
+        :type metadata: Optional[Dict[str, Any]], optional
+        :param config: The `AddConfig` instance to use as configuration options, defaults to None
+        :type config: Optional[AddConfig], optional
+        :raises ValueError: Invalid data type
+        :return: source_id, an md5-hash of the source, in hexadecimal representation.
+        :rtype: str
         """
         if config is None:
             config = AddConfig()
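
Call-site sketch of the documented signature, reusing the `app` from the sketches above and forcing a data type (the `DataType` import path is assumed, as before):

    from embedchain.models.data_type import DataType

    source_id = app.add(
        "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
        data_type=DataType.YOUTUBE_VIDEO,    # normally auto-detected
        metadata={"topic": "music"},
    )
    print(source_id)                         # md5 hash of the source, in hex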
@@ -177,39 +193,62 @@ class EmbedChain(JSONSerializable):
 
         return source_id
 
-    def add_local(self, source, data_type=None, metadata=None, config: AddConfig = None):
+    def add_local(
+        self,
+        source: Any,
+        data_type: Optional[DataType] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        config: Optional[AddConfig] = None,
+    ):
         """
-        Warning:
-            This method is deprecated and will be removed in future versions. Use `add` instead.
-
         Adds the data from the given URL to the vector db.
         Loads the data, chunks it, create embedding for each chunk
         and then stores the embedding to vector database.
 
+        Warning:
+            This method is deprecated and will be removed in future versions. Use `add` instead.
+
         :param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
-        :param data_type: Optional. Automatically detected, but can be forced with this argument.
-        The type of the data to add.
-        :param metadata: Optional. Metadata associated with the data source.
-        :param config: Optional. The `AddConfig` instance to use as configuration
-        options.
-        :return: md5-hash of the source, in hexadecimal representation.
+        :type source: Any
+        :param data_type: Automatically detected, but can be forced with this argument. The type of the data to add,
+        defaults to None
+        :type data_type: Optional[DataType], optional
+        :param metadata: Metadata associated with the data source, defaults to None
+        :type metadata: Optional[Dict[str, Any]], optional
+        :param config: The `AddConfig` instance to use as configuration options, defaults to None
+        :type config: Optional[AddConfig], optional
+        :raises ValueError: Invalid data type
+        :return: source_id, an md5-hash of the source, in hexadecimal representation.
+        :rtype: str
         """
         logging.warning(
             "The `add_local` method is deprecated and will be removed in future versions. Please use the `add` method for both local and remote files."  # noqa: E501
         )
         return self.add(source=source, data_type=data_type, metadata=metadata, config=config)
 
-    def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None, source_id=None):
-        """
-        Loads the data from the given URL, chunks it, and adds it to database.
+    def load_and_embed(
+        self,
+        loader: BaseLoader,
+        chunker: BaseChunker,
+        src: Any,
+        metadata: Optional[Dict[str, Any]] = None,
+        source_id: Optional[str] = None,
+    ) -> Tuple[List[str], List[Dict[str, Any]], List[str], int]:
+        """Loads the data from the given source, chunks it, and adds it to the database.
 
         :param loader: The loader to use to load the data.
+        :type loader: BaseLoader
         :param chunker: The chunker to use to chunk the data.
-        :param src: The data to be handled by the loader. Can be a URL for
-        remote sources or local content for local loaders.
-        :param metadata: Optional. Metadata associated with the data source.
-        :param source_id: Hexadecimal hash of the source.
+        :type chunker: BaseChunker
+        :param src: The data to be handled by the loader.
+        Can be a URL for remote sources or local content for local loaders.
+        :type src: Any
+        :param metadata: Metadata associated with the data source, defaults to None
+        :type metadata: Dict[str, Any], optional
+        :param source_id: Hexadecimal hash of the source, defaults to None
+        :type source_id: str, optional
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
+        :rtype: Tuple[List[str], List[Dict[str, Any]], List[str], int]
         """
         embeddings_data = chunker.create_chunks(loader, src)
 
@@ -264,25 +303,19 @@ class EmbedChain(JSONSerializable):
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
 
-    def _format_result(self, results):
-        return [
-            (Document(page_content=result[0], metadata=result[1] or {}), result[2])
-            for result in zip(
-                results["documents"][0],
-                results["metadatas"][0],
-                results["distances"][0],
-            )
-        ]
-
-    def retrieve_from_database(self, input_query, config: Optional[BaseLlmConfig] = None, where=None):
+    def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query
 
         :param input_query: The query to use.
-        :param config: The query configuration.
-        :param where: Optional. A dictionary of key-value pairs to filter the database results.
-        :return: The content of the document that matched your query.
+        :type input_query: str
+        :param config: The query configuration, defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        :param where: A dictionary of key-value pairs to filter the database results, defaults to None
+        :type where: Optional[Dict[str, Any]], optional
+        :return: List of contents of the document that matched your query
+        :rtype: List[str]
         """
         query_config = config or self.llm.config
 
@@ -304,23 +337,24 @@ class EmbedChain(JSONSerializable):
 
         return contents
 
-    def query(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
+    def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.
 
         :param input_query: The query to use.
-        :param config: Optional. The `LlmConfig` instance to use as configuration options.
-        This is used for one method call. To persistently use a config, declare it during app init.
-        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
-        the LLM. The purpose is to test the prompt, not the response.
-        You can use it to test your prompt, including the context provided
-        by the vector database's doc retrieval.
-        The only thing the dry run does not consider is the cut-off due to
-        the `max_tokens` parameter.
-        :param where: Optional. A dictionary of key-value pairs to filter the database results.
-        :return: The answer to the query.
+        :type input_query: str
+        :param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
+        To persistently use a config, declare it during app init; defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        :param dry_run: A dry run does everything except send the resulting prompt to
+        the LLM. The purpose is to test the prompt, not the response., defaults to False
+        :type dry_run: bool, optional
+        :param where: A dictionary of key-value pairs to filter the database results, defaults to None
+        :type where: Optional[Dict[str, str]], optional
+        :return: The answer to the query or the dry run result
+        :rtype: str
         """
         contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
         answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
@@ -331,24 +365,32 @@ class EmbedChain(JSONSerializable):
 
         return answer
 
-    def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
+    def chat(
+        self,
+        input_query: str,
+        config: Optional[BaseLlmConfig] = None,
+        dry_run=False,
+        where: Optional[Dict[str, str]] = None,
+    ) -> str:
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.
 
         Maintains the whole conversation in memory.
+
         :param input_query: The query to use.
-        :param config: Optional. The `LlmConfig` instance to use as configuration options.
-        This is used for one method call. To persistently use a config, declare it during app init.
-        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
-        the LLM. The purpose is to test the prompt, not the response.
-        You can use it to test your prompt, including the context provided
-        by the vector database's doc retrieval.
-        The only thing the dry run does not consider is the cut-off due to
-        the `max_tokens` parameter.
-        :param where: Optional. A dictionary of key-value pairs to filter the database results.
-        :return: The answer to the query.
+        :type input_query: str
+        :param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
+        To persistently use a config, declare it during app init; defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        :param dry_run: A dry run does everything except send the resulting prompt to
+        the LLM. The purpose is to test the prompt, not the response., defaults to False
+        :type dry_run: bool, optional
+        :param where: A dictionary of key-value pairs to filter the database results, defaults to None
+        :type where: Optional[Dict[str, str]], optional
+        :return: The answer to the query or the dry run result
+        :rtype: str
         """
         contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
         answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
@@ -359,15 +401,18 @@ class EmbedChain(JSONSerializable):
 
         return answer
 
-    def set_collection(self, collection_name):
+    def set_collection_name(self, name: str):
         """
-        Set the collection to use.
+        Set the name of the collection. A collection is an isolated space for vectors.
+
+        Using `app.db.set_collection_name` method is preferred to this.
 
-        :param collection_name: The name of the collection to use.
+        :param name: Name of the collection.
+        :type name: str
         """
-        self.db.set_collection_name(collection_name)
+        self.db.set_collection_name(name)
         # Create the collection if it does not exist
-        self.db._get_or_create_collection(collection_name)
+        self.db._get_or_create_collection(name)
         # TODO: Check whether it is necessary to assign to the `self.collection` attribute,
         # since the main purpose is the creation.
 
@@ -378,8 +423,9 @@ class EmbedChain(JSONSerializable):
         DEPRECATED IN FAVOR OF `db.count()`
 
         :return: The number of embeddings.
+        :rtype: int
         """
-        logging.warning("DEPRECATION WARNING: Please use `db.count()` instead of `count()`.")
+        logging.warning("DEPRECATION WARNING: Please use `app.db.count()` instead of `app.count()`.")
         return self.db.count()
 
     def reset(self):
@@ -393,11 +439,14 @@ class EmbedChain(JSONSerializable):
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
         thread_telemetry.start()
 
-        logging.warning("DEPRECATION WARNING: Please use `db.reset()` instead of `reset()`.")
+        logging.warning("DEPRECATION WARNING: Please use `app.db.reset()` instead of `App.reset()`.")
         self.db.reset()
 
     @retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
     def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):
+        """
+        Send telemetry event to the embedchain server. This is anonymous. It can be toggled off in `AppConfig`.
+        """
         if not self.config.collect_metrics:
             return
 

+ 25 - 1
embedchain/embedder/base_embedder.py

@@ -19,7 +19,13 @@ class BaseEmbedder:
     To manually overwrite you can use this classes `set_...` methods.
     """
 
-    def __init__(self, config: Optional[BaseEmbedderConfig] = FileNotFoundError):
+    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
+        """
+        Initialize the embedder class.
+
+        :param config: embedder configuration option class, defaults to None
+        :type config: Optional[BaseEmbedderConfig], optional
+        """
         if config is None:
             self.config = BaseEmbedderConfig()
         else:
@@ -27,17 +33,35 @@ class BaseEmbedder:
         self.vector_dimension: int
 
     def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
+        """
+        Set or overwrite the embedding function to be used by the database to store and retrieve documents.
+
+        :param embedding_fn: Function to be used to generate embeddings.
+        :type embedding_fn: Callable[[list[str]], list[str]]
+        :raises ValueError: Embedding function is not callable.
+        """
         if not hasattr(embedding_fn, "__call__"):
             raise ValueError("Embedding function is not a function")
         self.embedding_fn = embedding_fn
 
     def set_vector_dimension(self, vector_dimension: int):
+        """
+        Set or overwrite the vector dimension size
+
+        :param vector_dimension: vector dimension size
+        :type vector_dimension: int
+        """
         self.vector_dimension = vector_dimension
 
     @staticmethod
     def _langchain_default_concept(embeddings: Any):
         """
         Langchains default function layout for embeddings.
+
+        :param embeddings: Langchain embeddings
+        :type embeddings: Any
+        :return: embedding function
+        :rtype: Callable
         """
 
         def embed_function(texts: Documents) -> Embeddings:

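A toy sketch of the setter API above; the fixed-size zero vectors stand in for a real embedding model, and the `BaseEmbedderConfig` import path is assumed:

    from embedchain.config import BaseEmbedderConfig
    from embedchain.embedder.base_embedder import BaseEmbedder

    embedder = BaseEmbedder(config=BaseEmbedderConfig(model="all-MiniLM-L6-v2"))
    embedder.set_embedding_fn(lambda texts: [[0.0] * 384 for _ in texts])   # toy embedding function
    embedder.set_vector_dimension(384)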
+ 94 - 36
embedchain/llm/base_llm.py

@@ -1,5 +1,5 @@
 import logging
-from typing import List, Optional
+from typing import Any, Dict, Generator, List, Optional
 
 from langchain.memory import ConversationBufferMemory
 from langchain.schema import BaseMessage
@@ -13,6 +13,11 @@ from embedchain.helper_classes.json_serializable import JSONSerializable
 
 class BaseLlm(JSONSerializable):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """Initialize a base LLM class
+
+        :param config: LLM configuration option class, defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        """
         if config is None:
             self.config = BaseLlmConfig()
         else:
@@ -21,7 +26,7 @@ class BaseLlm(JSONSerializable):
         self.memory = ConversationBufferMemory()
         self.is_docs_site_instance = False
         self.online = False
-        self.history: any = None
+        self.history: Any = None
 
     def get_llm_model_answer(self):
         """
@@ -29,24 +34,33 @@ class BaseLlm(JSONSerializable):
         """
         raise NotImplementedError
 
-    def set_history(self, history: any):
+    def set_history(self, history: Any):
+        """
+        Provide your own history.
+        Especially interesting for the query method, which does not internally manage conversation history.
+
+        :param history: History to set
+        :type history: Any
+        """
         self.history = history
 
     def update_history(self):
+        """Update class history attribute with history in memory (for chat method)"""
         chat_history = self.memory.load_memory_variables({})["history"]
         if chat_history:
             self.set_history(chat_history)
 
-    def generate_prompt(self, input_query, contexts, **kwargs):
+    def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
         """
         Generates a prompt based on the given query and context, ready to be
         passed to an LLM
 
         :param input_query: The query to use.
+        :type input_query: str
         :param contexts: List of similar documents to the query used as context.
-        :param config: Optional. The `QueryConfig` instance to use as
-        configuration options.
+        :type contexts: List[str]
+        :param kwargs: Optional keyword arguments; a `web_search_result` entry is appended to the context.
         :return: The prompt
+        :rtype: str
         """
         context_string = (" | ").join(contexts)
         web_search_result = kwargs.get("web_search_result", "")
@@ -73,36 +87,67 @@ class BaseLlm(JSONSerializable):
                 )
         return prompt
 
-    def _append_search_and_context(self, context, web_search_result):
+    def _append_search_and_context(self, context: str, web_search_result: str) -> str:
+        """Append web search context to existing context
+
+        :param context: Existing context
+        :type context: str
+        :param web_search_result: Web search result
+        :type web_search_result: str
+        :return: Context with the web search result appended
+        :rtype: str
+        """
         return f"{context}\nWeb Search Result: {web_search_result}"
 
-    def get_answer_from_llm(self, prompt):
+    def get_answer_from_llm(self, prompt: str):
         """
         Gets an answer based on the given query and context by passing it
         to an LLM.
 
-        :param query: The query to use.
-        :param context: Similar documents to the query used as context.
+        :param prompt: Prompt to be passed to the LLM.
+        :type prompt: str
         :return: The answer.
+        :rtype: Any
         """
-
         return self.get_llm_model_answer(prompt)
 
-    def access_search_and_get_results(self, input_query):
+    def access_search_and_get_results(self, input_query: str):
+        """
+        Search the internet for additional context.
+
+        :param input_query: search query
+        :type input_query: str
+        :return: Search results
+        :rtype: str
+        """
         from langchain.tools import DuckDuckGoSearchRun
 
         search = DuckDuckGoSearchRun()
         logging.info(f"Access search to get answers for {input_query}")
         return search.run(input_query)
 
-    def _stream_query_response(self, answer):
+    def _stream_query_response(self, answer: Any) -> Generator[Any, Any, None]:
+        """Generator to be used as streaming response
+
+        :param answer: Answer chunk from llm
+        :type answer: Any
+        :yield: Answer chunk from llm
+        :rtype: Generator[Any, Any, None]
+        """
         streamed_answer = ""
         for chunk in answer:
             streamed_answer = streamed_answer + chunk
             yield chunk
         logging.info(f"Answer: {streamed_answer}")
 
-    def _stream_chat_response(self, answer):
+    def _stream_chat_response(self, answer: Any) -> Generator[Any, Any, None]:
+        """Generator to be used as streaming response
+
+        :param answer: Answer chunk from llm
+        :type answer: Any
+        :yield: Answer chunk from llm
+        :rtype: Generator[Any, Any, None]
+        """
         streamed_answer = ""
         for chunk in answer:
             streamed_answer = streamed_answer + chunk
@@ -110,23 +155,24 @@ class BaseLlm(JSONSerializable):
         self.memory.chat_memory.add_ai_message(streamed_answer)
         logging.info(f"Answer: {streamed_answer}")
 
-    def query(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None):
+    def query(self, input_query: str, contexts: List[str], config: Optional[BaseLlmConfig] = None, dry_run: bool = False):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.
 
         :param input_query: The query to use.
-        :param config: Optional. The `LlmConfig` instance to use as configuration options.
-        This is used for one method call. To persistently use a config, declare it during app init.
-        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
-        the LLM. The purpose is to test the prompt, not the response.
-        You can use it to test your prompt, including the context provided
-        by the vector database's doc retrieval.
-        The only thing the dry run does not consider is the cut-off due to
-        the `max_tokens` parameter.
-        :param where: Optional. A dictionary of key-value pairs to filter the database results.
-        :return: The answer to the query.
+        :type input_query: str
+        :param contexts: Embeddings retrieved from the database to be used as context.
+        :type contexts: List[str]
+        :param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
+        To persistently use a config, declare it during app init. Defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        :param dry_run: A dry run does everything except send the resulting prompt to
+        the LLM. The purpose is to test the prompt, not the response. Defaults to False
+        :type dry_run: bool, optional
+        :return: The answer to the query or the dry run result
+        :rtype: str
         """
         query_config = config or self.config
 
@@ -150,24 +196,26 @@ class BaseLlm(JSONSerializable):
         else:
             return self._stream_query_response(answer)
 
-    def chat(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None):
+    def chat(self, input_query: str, contexts: List[str], config: Optional[BaseLlmConfig] = None, dry_run: bool = False):
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.
 
         Maintains the whole conversation in memory.
+
         :param input_query: The query to use.
-        :param config: Optional. The `LlmConfig` instance to use as configuration options.
-        This is used for one method call. To persistently use a config, declare it during app init.
-        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
-        the LLM. The purpose is to test the prompt, not the response.
-        You can use it to test your prompt, including the context provided
-        by the vector database's doc retrieval.
-        The only thing the dry run does not consider is the cut-off due to
-        the `max_tokens` parameter.
-        :param where: Optional. A dictionary of key-value pairs to filter the database results.
-        :return: The answer to the query.
+        :type input_query: str
+        :param contexts: Embeddings retrieved from the database to be used as context.
+        :type contexts: List[str]
+        :param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
+        To persistently use a config, declare it during app init. Defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        :param dry_run: A dry run does everything except send the resulting prompt to
+        the LLM. The purpose is to test the prompt, not the response. Defaults to False
+        :type dry_run: bool, optional
+        :return: The answer to the query or the dry run result
+        :rtype: str
         """
         query_config = config or self.config
 
@@ -205,6 +253,16 @@ class BaseLlm(JSONSerializable):
 
     @staticmethod
     def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
+        """
+        Construct a list of Langchain messages
+
+        :param prompt: User prompt
+        :type prompt: str
+        :param system_prompt: System prompt, defaults to None
+        :type system_prompt: Optional[str], optional
+        :return: List of messages
+        :rtype: List[BaseMessage]
+        """
         from langchain.schema import HumanMessage, SystemMessage
 
         messages = []

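Pulling the `BaseLlm` interface together, here is a minimal sketch: a toy subclass overriding `get_llm_model_answer` (the only method the base class leaves unimplemented), used with `dry_run` to inspect the generated prompt. `EchoLlm` is hypothetical, and per the docstrings above `dry_run=True` returns the constructed prompt instead of sending it to a model.

    from embedchain.llm.base_llm import BaseLlm

    class EchoLlm(BaseLlm):
        """Toy LLM that returns the prompt unchanged; for illustration only."""

        def get_llm_model_answer(self, prompt: str) -> str:
            return prompt

    llm = EchoLlm()
    # dry_run=True builds the prompt from the contexts but never calls the model.
    prompt = llm.query(
        input_query="What is Embedchain?",
        contexts=["Embedchain is a RAG framework."],
        dry_run=True,
    )
    print(prompt)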
+ 31 - 1
embedchain/vectordb/base_vector_db.py

@@ -7,6 +7,11 @@ class BaseVectorDB(JSONSerializable):
     """Base class for vector database."""
 
     def __init__(self, config: BaseVectorDbConfig):
+        """Initialize the database. Save the config and client as an attribute.
+
+        :param config: Database configuration class instance.
+        :type config: BaseVectorDbConfig
+        """
         self.client = self._get_or_create_db()
         self.config: BaseVectorDbConfig = config
 
@@ -23,25 +28,50 @@ class BaseVectorDB(JSONSerializable):
         raise NotImplementedError
 
     def _get_or_create_collection(self):
+        """Get or create a named collection."""
         raise NotImplementedError
 
     def _set_embedder(self, embedder: BaseEmbedder):
+        """
+        The database sometimes needs to access the embedder; this method sets it persistently.
+
+        :param embedder: Embedder to be set as the embedder for this database.
+        :type embedder: BaseEmbedder
+        """
         self.embedder = embedder
 
     def get(self):
+        """Get database embeddings by id."""
         raise NotImplementedError
 
     def add(self):
+        """Add to database"""
         raise NotImplementedError
 
     def query(self):
+        """Query contents from vector data base based on vector similarity"""
         raise NotImplementedError
 
-    def count(self):
+    def count(self) -> int:
+        """
+        Count number of documents/chunks embedded in the database.
+
+        :return: number of documents
+        :rtype: int
+        """
         raise NotImplementedError
 
     def reset(self):
+        """
+        Resets the database. Deletes all embeddings irreversibly.
+        """
         raise NotImplementedError
 
     def set_collection_name(self, name: str):
+        """
+        Set the name of the collection. A collection is an isolated space for vectors.
+
+        :param name: Name of the collection.
+        :type name: str
+        """
         raise NotImplementedError

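Since every method of `BaseVectorDB` raises `NotImplementedError`, concrete databases subclass it, as the Chroma and Elasticsearch implementations below do. A minimal in-memory sketch follows; `DictDB`, the default-constructed config, and its import path are illustrative assumptions.

    from embedchain.config import BaseVectorDbConfig  # import path assumed
    from embedchain.vectordb.base_vector_db import BaseVectorDB

    class DictDB(BaseVectorDB):
        """Toy vector store backed by a dict; illustrative only."""

        def _get_or_create_db(self):
            # Called by BaseVectorDB.__init__ before the config is stored.
            self._store = {}
            return self._store

        def count(self) -> int:
            return len(self._store)

        def reset(self):
            self._store.clear()

    db = DictDB(config=BaseVectorDbConfig())  # assuming a default-constructible config
    print(db.count())  # -> 0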
+ 59 - 16
embedchain/vectordb/chroma_db.py

@@ -1,6 +1,7 @@
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Dict, List, Optional, Set
 
+from chromadb import Collection, QueryResult
 from langchain.docstore.document import Document
 
 from embedchain.config import ChromaDbConfig
@@ -25,6 +26,11 @@ class ChromaDB(BaseVectorDB):
     """Vector database using ChromaDB."""
 
     def __init__(self, config: Optional[ChromaDbConfig] = None):
+        """Initialize a new ChromaDB instance
+
+        :param config: Configuration options for Chroma, defaults to None
+        :type config: Optional[ChromaDbConfig], optional
+        """
         if config:
             self.config = config
         else:
@@ -60,11 +66,19 @@ class ChromaDB(BaseVectorDB):
         self._get_or_create_collection(self.config.collection_name)
 
     def _get_or_create_db(self):
-        """Get or create the database."""
+        """Called during initialization"""
         return self.client
 
-    def _get_or_create_collection(self, name):
-        """Get or create the collection."""
+    def _get_or_create_collection(self, name: str) -> Collection:
+        """
+        Get or create a named collection.
+
+        :param name: Name of the collection
+        :type name: str
+        :raises ValueError: No embedder configured.
+        :return: Created collection
+        :rtype: Collection
+        """
         if not hasattr(self, "embedder") or not self.embedder:
             raise ValueError("Cannot create a Chroma database collection without an embedder.")
         self.collection = self.client.get_or_create_collection(
@@ -76,8 +90,13 @@ class ChromaDB(BaseVectorDB):
     def get(self, ids: List[str], where: Dict[str, any]) -> Set[str]:
         """
         Get existing doc ids present in vector database
-        :param ids: list of doc ids to check for existance
+
+        :param ids: list of doc ids to check for existence
+        :type ids: List[str]
         :param where: to filter data
+        :type where: Dict[str, any]
+        :return: Existing document ids.
+        :rtype: Set[str]
         """
         existing_docs = self.collection.get(
             ids=ids,
@@ -86,16 +105,28 @@ class ChromaDB(BaseVectorDB):
 
         return set(existing_docs["ids"])
 
-    def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
+    def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
         """
-        add data in vector database
-        :param documents: list of texts to add
-        :param metadatas: list of metadata associated with docs
-        :param ids: ids of docs
+        Add vectors to the chroma database
+
+        :param documents: List of document texts to add
+        :type documents: List[str]
+        :param metadatas: List of metadata objects associated with the documents
+        :type metadatas: List[object]
+        :param ids: List of document ids
+        :type ids: List[str]
         """
         self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
 
-    def _format_result(self, results):
+    def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
+        """
+        Format Chroma results
+
+        :param results: ChromaDB query results to format.
+        :type results: QueryResult
+        :return: Formatted results
+        :rtype: list[tuple[Document, float]]
+        """
         return [
             (Document(page_content=result[0], metadata=result[1] or {}), result[2])
             for result in zip(
@@ -107,11 +138,17 @@ class ChromaDB(BaseVectorDB):
 
     def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
         """
-        query contents from vector data base based on vector similarity
+        Query contents from the vector database based on vector similarity
+
         :param input_query: list of query string
+        :type input_query: List[str]
         :param n_results: no of similar documents to fetch from database
-        :param where: Optional. to filter data
+        :type n_results: int
+        :param where: to filter data
+        :type where: Dict[str, any]
+        :raises InvalidDimensionException: Dimensions do not match.
         :return: The content of the document that matched your query.
+        :rtype: List[str]
         """
         try:
             result = self.collection.query(
@@ -132,21 +169,27 @@ class ChromaDB(BaseVectorDB):
         return contents
 
     def set_collection_name(self, name: str):
+        """
+        Set the name of the collection. A collection is an isolated space for vectors.
+
+        :param name: Name of the collection.
+        :type name: str
+        """
         self.config.collection_name = name
         self._get_or_create_collection(self.config.collection_name)
 
     def count(self) -> int:
         """
-        Count the number of embeddings.
+        Count number of documents/chunks embedded in the database.
 
-        :return: The number of embeddings.
+        :return: number of documents
+        :rtype: int
         """
         return self.collection.count()
 
     def reset(self):
         """
         Resets the database. Deletes all embeddings irreversibly.
-        `App` does not have to be reinitialized after using this method.
         """
         # Delete all data from the database
         try:

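A rough usage sketch of the `ChromaDB` interface documented above. An embedder must be attached before the collection is created, otherwise `_get_or_create_collection` raises the documented `ValueError`; normally the app wires this up via `_set_embedder`. The default-constructible config and the `embedder` object (such as the one sketched after the embedder section) are assumptions.

    from embedchain.config import ChromaDbConfig
    from embedchain.vectordb.chroma_db import ChromaDB

    db = ChromaDB(config=ChromaDbConfig())   # assuming a default-constructible config
    db._set_embedder(embedder)               # normally done by the app
    db.set_collection_name("my_collection")  # creates the collection via _get_or_create_collection

    db.add(documents=["hello world"], metadatas=[{"url": "local"}], ids=["doc-1"])
    print(db.count())  # 1 document embedded
    print(db.query(["hello"], n_results=1, where={"url": "local"}))  # -> ["hello world"]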
+ 58 - 16
embedchain/vectordb/elasticsearch_db.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Dict, List, Optional, Set
 
 try:
     from elasticsearch import Elasticsearch
@@ -15,16 +15,23 @@ from embedchain.vectordb.base_vector_db import BaseVectorDB
 
 @register_deserializable
 class ElasticsearchDB(BaseVectorDB):
+    """
+    Elasticsearch as vector database
+    """
+
     def __init__(
         self,
-        config: ElasticsearchDBConfig = None,
-        es_config: ElasticsearchDBConfig = None,  # Backwards compatibility
+        config: Optional[ElasticsearchDBConfig] = None,
+        es_config: Optional[ElasticsearchDBConfig] = None,  # Backwards compatibility
     ):
-        """
-        Elasticsearch as vector database
-        :param es_config. elasticsearch database config to be used for connection
-        :param embedding_fn: Function to generate embedding vectors.
-        :param vector_dim: Vector dimension generated by embedding fn
+        """Elasticsearch as vector database.
+
+        :param config: Elasticsearch database config, defaults to None
+        :type config: ElasticsearchDBConfig, optional
+        :param es_config: `es_config` is supported as an alias for `config` (for backwards compatibility),
+        defaults to None
+        :type es_config: ElasticsearchDBConfig, optional
+        :raises ValueError: No config provided
         """
         if config is None and es_config is None:
             raise ValueError("ElasticsearchDBConfig is required")
@@ -53,16 +60,22 @@ class ElasticsearchDB(BaseVectorDB):
             self.client.indices.create(index=es_index, body=index_settings)
 
     def _get_or_create_db(self):
+        """Called during initialization"""
         return self.client
 
     def _get_or_create_collection(self, name):
         """Note: nothing to return here. Discuss later"""
 
-    def get(self, ids: List[str], where: Dict[str, any]) -> List[str]:
+    def get(self, ids: List[str], where: Dict[str, any]) -> Set[str]:
         """
         Get existing doc ids present in vector database
-        :param ids: list of doc ids to check for existance
-        :param where: Optional. to filter data
+
+        :param ids: list of doc ids to check for existence
+        :type ids: List[str]
+        :param where: to filter data
+        :type where: Dict[str, any]
+        :return: Existing document ids
+        :rtype: Set[str]
         """
         query = {"bool": {"must": [{"ids": {"values": ids}}]}}
         if "app_id" in where:
@@ -73,13 +86,17 @@ class ElasticsearchDB(BaseVectorDB):
         ids = [doc["_id"] for doc in docs]
         return set(ids)
 
-    def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
-        """
-        add data in vector database
+    def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
+        """add data in vector database
+
         :param documents: list of texts to add
+        :type documents: List[str]
         :param metadatas: list of metadata associated with docs
+        :type metadatas: List[object]
         :param ids: ids of docs
+        :type ids: List[str]
         """
+
         docs = []
         embeddings = self.embedder.embedding_fn(documents)
         for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
@@ -92,14 +109,19 @@ class ElasticsearchDB(BaseVectorDB):
             )
         bulk(self.client, docs)
         self.client.indices.refresh(index=self._get_index())
-        return
 
     def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
         """
         query contents from the vector database based on vector similarity
+
         :param input_query: list of query string
+        :type input_query: List[str]
         :param n_results: no of similar documents to fetch from database
+        :type n_results: int
         :param where: Optional. to filter data
+        :type where: Dict[str, any]
+        :return: Database contents that are the result of the query
+        :rtype: List[str]
         """
         input_query_vector = self.embedder.embedding_fn(input_query)
         query_vector = input_query_vector[0]
@@ -122,21 +144,41 @@ class ElasticsearchDB(BaseVectorDB):
         return contents
 
     def set_collection_name(self, name: str):
+        """
+        Set the name of the collection. A collection is an isolated space for vectors.
+
+        :param name: Name of the collection.
+        :type name: str
+        """
         self.config.collection_name = name
 
     def count(self) -> int:
+        """
+        Count number of documents/chunks embedded in the database.
+
+        :return: number of documents
+        :rtype: int
+        """
         query = {"match_all": {}}
         response = self.client.count(index=self._get_index(), query=query)
         doc_count = response["count"]
         return doc_count
 
     def reset(self):
+        """
+        Resets the database. Deletes all embeddings irreversibly.
+        """
         # Delete all data from the database
         if self.client.indices.exists(index=self._get_index()):
             # delete index in Es
             self.client.indices.delete(index=self._get_index())
 
-    def _get_index(self):
+    def _get_index(self) -> str:
+        """Get the Elasticsearch index for a collection
+
+        :return: Elasticsearch index
+        :rtype: str
+        """
         # NOTE: The method is preferred to an attribute, because if collection name changes,
         # it's always up-to-date.
         return f"{self.config.collection_name}_{self.embedder.vector_dimension}"

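A construction sketch for `ElasticsearchDB`, showing the `es_config` alias kept for backwards compatibility. The `es_url` parameter name is an assumption, the URL is a placeholder, and a reachable Elasticsearch cluster is assumed for the first two constructions to succeed; the `ValueError` in the last case is raised before any connection is made.

    from embedchain.config import ElasticsearchDBConfig
    from embedchain.vectordb.elasticsearch_db import ElasticsearchDB

    config = ElasticsearchDBConfig(es_url="http://localhost:9200")  # parameter name assumed

    db = ElasticsearchDB(config=config)         # preferred spelling
    legacy = ElasticsearchDB(es_config=config)  # alias kept for backwards compatibility

    try:
        ElasticsearchDB()  # neither config nor es_config
    except ValueError as err:
        print(err)  # "ElasticsearchDBConfig is required"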
+ 19 - 19
tests/vectordb/test_chroma_db.py

@@ -121,9 +121,9 @@ class TestChromaDbDuplicateHandling:
         self.app_with_settings.reset()
 
         app = App(config=AppConfig(collect_metrics=False))
-        app.set_collection("test_collection_1")
+        app.set_collection_name("test_collection_1")
         app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
-        app.set_collection("test_collection_2")
+        app.set_collection_name("test_collection_2")
         app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
         assert "Insert of existing embedding ID: 0" not in caplog.text  # not
         assert "Add of existing embedding ID: 0" not in caplog.text  # not
@@ -149,16 +149,16 @@ class TestChromaDbCollection(unittest.TestCase):
         """
         config = AppConfig(collect_metrics=False)
         app = App(config=config)
-        app.set_collection(collection_name="test_collection")
+        app.set_collection_name(name="test_collection")
 
         self.assertEqual(app.db.collection.name, "test_collection")
 
-    def test_set_collection(self):
+    def test_set_collection_name(self):
         """
-        Test if the `App` collection is correctly switched using the `set_collection` method.
+        Test if the `App` collection is correctly switched using the `set_collection_name` method.
         """
         app = App(config=AppConfig(collect_metrics=False))
-        app.set_collection("test_collection")
+        app.set_collection_name("test_collection")
 
         self.assertEqual(app.db.collection.name, "test_collection")
 
@@ -170,7 +170,7 @@ class TestChromaDbCollection(unittest.TestCase):
         self.app_with_settings.reset()
 
         app = App(config=AppConfig(collect_metrics=False))
-        app.set_collection("test_collection_1")
+        app.set_collection_name("test_collection_1")
         # Collection should be empty when created
         self.assertEqual(app.count(), 0)
 
@@ -178,13 +178,13 @@ class TestChromaDbCollection(unittest.TestCase):
         # After adding, should contain one item
         self.assertEqual(app.count(), 1)
 
-        app.set_collection("test_collection_2")
+        app.set_collection_name("test_collection_2")
         # New collection is empty
         self.assertEqual(app.count(), 0)
 
         # Adding to the new collection should not affect the existing collection
         app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
-        app.set_collection("test_collection_1")
+        app.set_collection_name("test_collection_1")
         # Should still be 1, not 2.
         self.assertEqual(app.count(), 1)
 
@@ -196,12 +196,12 @@ class TestChromaDbCollection(unittest.TestCase):
         self.app_with_settings.reset()
 
         app = App(config=AppConfig(collect_metrics=False))
-        app.set_collection("test_collection_1")
+        app.set_collection_name("test_collection_1")
         app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
         del app
 
         app = App(config=AppConfig(collect_metrics=False))
-        app.set_collection("test_collection_1")
+        app.set_collection_name("test_collection_1")
         self.assertEqual(app.count(), 1)
 
     def test_parallel_collections(self):
@@ -227,9 +227,9 @@ class TestChromaDbCollection(unittest.TestCase):
         app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
 
         # Swap names and test
-        app1.set_collection("test_collection_2")
+        app1.set_collection_name("test_collection_2")
         self.assertEqual(app1.count(), 1)
-        app2.set_collection("test_collection_1")
+        app2.set_collection_name("test_collection_1")
         self.assertEqual(app2.count(), 3)
 
     def test_ids_share_collections(self):
@@ -241,9 +241,9 @@ class TestChromaDbCollection(unittest.TestCase):
 
         # Create two apps
         app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
-        app1.set_collection("one_collection")
+        app1.set_collection_name("one_collection")
         app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
-        app2.set_collection("one_collection")
+        app2.set_collection_name("one_collection")
 
         # Add data
         app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
@@ -263,13 +263,13 @@ class TestChromaDbCollection(unittest.TestCase):
         # Create four apps.
         # app1, which we are about to reset, shares a collection with one, an id with another, and nothing with the last.
         app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config)
-        app1.set_collection("one_collection")
+        app1.set_collection_name("one_collection")
         app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
-        app2.set_collection("one_collection")
+        app2.set_collection_name("one_collection")
         app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
-        app3.set_collection("three_collection")
+        app3.set_collection_name("three_collection")
         app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False))
-        app4.set_collection("four_collection")
+        app4.set_collection_name("four_collection")
 
         # Each one of them get data
         app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])