浏览代码

feat: add new custom app (#313)

cachho 2 年之前
父节点
当前提交
adb7206639

+ 25 - 0
docs/advanced/adding_data.mdx

@@ -0,0 +1,25 @@
+---
+title: '➕ Adding Data'
+---
+
+## Add Dataset
+
+- This step assumes that you have already created an `app` instance by either using `App`, `OpenSourceApp` or `CustomApp`. We are calling our app instance as `naval_chat_bot` 🤖
+
+- Now use `.add()` function to add any dataset.
+
+```python
+# naval_chat_bot = App() or
+# naval_chat_bot = OpenSourceApp()
+
+# Embed Online Resources
+naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44")
+naval_chat_bot.add("pdf_file", "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf")
+naval_chat_bot.add("web_page", "https://nav.al/feedback")
+naval_chat_bot.add("web_page", "https://nav.al/agi")
+
+# Embed Local Resources
+naval_chat_bot.add_local("qna_pair", ("Who is Naval Ravikant?", "Naval Ravikant is an Indian-American entrepreneur and investor."))
+```
+
+The possible formats to add data can be found on the [Supported Data Formats](/advanced/data_types) page.

+ 23 - 75
docs/advanced/app_types.mdx

@@ -2,12 +2,6 @@
 title: '📱 App types'
 ---
 
-Creating a chatbot involves 3 steps:
-
-- ⚙️ Import the App instance
-- 🗃️ Add Dataset
-- 💬 Query or Chat on the dataset and get answers (Interface Types)
-
 ## App Types
 
 We have three types of App.
@@ -16,13 +10,12 @@ We have three types of App.
 
 ```python
 from embedchain import App
-naval_chat_bot = App()
+app = App()
 ```
 
 - `App` uses OpenAI's model, so these are paid models. 💸 You will be charged for embedding model usage and LLM usage.
-
 - `App` uses OpenAI's embedding model to create embeddings for chunks and ChatGPT API as LLM to get answer given the relevant docs. Make sure that you have an OpenAI account and an API key. If you have don't have an API key, you can create one by visiting [this link](https://platform.openai.com/account/api-keys).
-
+- `App` is opinionated. It uses the best embedding model and LLM on the market.
 - Once you have the API key, set it in an environment variable called `OPENAI_API_KEY`
 
 ```python
@@ -34,12 +27,31 @@ os.environ["OPENAI_API_KEY"] = "sk-xxxx"
 
 ```python
 from embedchain import OpenSourceApp
-naval_chat_bot = OpenSourceApp()
+app = OpenSourceApp()
 ```
 
 - `OpenSourceApp` uses open source embedding and LLM model. It uses `all-MiniLM-L6-v2` from Sentence Transformers library as the embedding model and `gpt4all` as the LLM.
 - Here there is no need to setup any api keys. You just need to install embedchain package and these will get automatically installed. 📦
 - Once you have imported and instantiated the app, every functionality from here onwards is the same for either type of app. 📚
+- `OpenSourceApp` is opinionated. It uses the best open source embedding model and LLM on the market.
+
+### CustomApp
+
+```python
+from embedchain import CustomApp
+from embedchain.config import CustomAppConfig
+from embedchain.models import Providers, EmbeddingFunctions
+
+config = CustomAppConfig(embedding_fn=EmbeddingFunctions.OPENAI, provider=Providers.OPENAI)
+app = CustomApp()
+```
+
+- `CustomApp` is not opinionated.
+- Configuration required. It's for advanced users who want to mix and match different embedding models and LLMs. Configuration required.
+- while it's doing that, it's still providing abstractions through `Providers`.
+- paid and free/open source providers included.
+- Once you have imported and instantiated the app, every functionality from here onwards is the same for either type of app. 📚
+
 
 ### PersonApp
 
@@ -57,25 +69,7 @@ import os
 os.environ["OPENAI_API_KEY"] = "sk-xxxx"
 ```
 
-## Add Dataset
-
-- This step assumes that you have already created an `app` instance by either using `App` or `OpenSourceApp`. We are calling our app instance as `naval_chat_bot` 🤖
-
-- Now use `.add()` function to add any dataset.
-
-```python
-# naval_chat_bot = App() or
-# naval_chat_bot = OpenSourceApp()
-
-# Embed Online Resources
-naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44")
-naval_chat_bot.add("pdf_file", "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf")
-naval_chat_bot.add("web_page", "https://nav.al/feedback")
-naval_chat_bot.add("web_page", "https://nav.al/agi")
-
-# Embed Local Resources
-naval_chat_bot.add_local("qna_pair", ("Who is Naval Ravikant?", "Naval Ravikant is an Indian-American entrepreneur and investor."))
-```
+#### Compatibility with other apps
 
 - If there is any other app instance in your script or app, you can change the import as
 
@@ -90,49 +84,3 @@ from embedchain import App as ECApp
 from embedchain import OpenSourceApp as ECOSApp
 from embedchain import PersonApp as ECPApp
 ```
-
-## Interface Types
-
-### Query Interface
-
-- This interface is like a question answering bot. It takes a question and gets the answer. It does not maintain context about the previous chats.❓
-
-- To use this, call `.query()` function to get the answer for any query.
-
-```python
-print(naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?"))
-# answer: Naval argues that humans possess the unique capacity to understand explanations or concepts to the maximum extent possible in this physical reality.
-```
-
-### Chat Interface
-
-- This interface is chat interface where it remembers previous conversation. Right now it remembers 5 conversation by default. 💬
-
-- To use this, call `.chat` function to get the answer for any query.
-
-```python
-print(naval_chat_bot.chat("How to be happy in life?"))
-# answer: The most important trick to being happy is to realize happiness is a skill you develop and a choice you make. You choose to be happy, and then you work at it. It's just like building muscles or succeeding at your job. It's about recognizing the abundance and gifts around you at all times.
-
-print(naval_chat_bot.chat("who is naval ravikant?"))
-# answer: Naval Ravikant is an Indian-American entrepreneur and investor.
-
-print(naval_chat_bot.chat("what did the author say about happiness?"))
-# answer: The author, Naval Ravikant, believes that happiness is a choice you make and a skill you develop. He compares the mind to the body, stating that just as the body can be molded and changed, so can the mind. He emphasizes the importance of being present in the moment and not getting caught up in regrets of the past or worries about the future. By being present and grateful for where you are, you can experience true happiness.
-```
-
-### Stream Response
-
-- You can add config to your query method to stream responses like ChatGPT does. You would require a downstream handler to render the chunk in your desirable format. Supports both OpenAI model and OpenSourceApp. 📊
-
-- To use this, instantiate a `QueryConfig` or `ChatConfig` object with `stream=True`. Then pass it to the `.chat()` or `.query()` method. The following example iterates through the chunks and prints them as they appear.
-
-```python
-app = App()
-query_config = QueryConfig(stream = True)
-resp = app.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?", query_config)
-
-for chunk in resp:
-    print(chunk, end="", flush=True)
-# answer: Naval argues that humans possess the unique capacity to understand explanations or concepts to the maximum extent possible in this physical reality.
-```

+ 4 - 28
docs/advanced/configuration.mdx

@@ -6,7 +6,7 @@ Embedchain is made to work out of the box. However, for advanced users we're als
 
 ## Examples
 
-### Custom embedding function
+### General
 
 Here's the readme example with configuration options.
 
@@ -16,13 +16,8 @@ from embedchain import App
 from embedchain.config import AppConfig, AddConfig, QueryConfig, ChunkerConfig
 from chromadb.utils import embedding_functions
 
-# Example: use your own embedding function
-# Warning: We are currenty reworking the concept of custom apps, this might not be working.
-config = AppConfig(ef=embedding_functions.OpenAIEmbeddingFunction(
-                api_key=os.getenv("OPENAI_API_KEY"),
-                organization_id=os.getenv("OPENAI_ORGANIZATION"),
-                model_name="text-embedding-ada-002"
-            ))
+# Example: set the log level for debugging
+config = AppConfig(log_level="DEBUG")
 naval_chat_bot = App(config)
 
 # Example: define your own chunker config for `youtube_video`
@@ -36,7 +31,7 @@ naval_chat_bot.add("web_page", "https://nav.al/agi", add_config)
 
 naval_chat_bot.add_local("qna_pair", ("Who is Naval Ravikant?", "Naval Ravikant is an Indian-American entrepreneur and investor."), add_config)
 
-query_config = QueryConfig() # Currently no options
+query_config = QueryConfig()
 print(naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?", query_config))
 ```
 
@@ -88,22 +83,3 @@ for query in queries:
 # Query:  Why did you divorce your first wife?
 # Response:  We divorced due to living apart for five years.
 ```
-
-## Other methods
-
-### Reset
-
-Resets the database and deletes all embeddings. Irreversible. Requires reinitialization afterwards.
-
-```python
-app.reset()
-```
-
-### Count
-
-Counts the number of embeddings (chunks) in the database.
-
-```python
-print(app.count())
-# returns: 481
-```

+ 74 - 0
docs/advanced/interface_types.mdx

@@ -0,0 +1,74 @@
+---
+title: '🤝 Interface types'
+---
+
+## Interface Types
+
+The embedchain app exposes the following methods.
+
+### Query Interface
+
+- This interface is like a question answering bot. It takes a question and gets the answer. It does not maintain context about the previous chats.❓
+
+- To use this, call `.query()` function to get the answer for any query.
+
+```python
+print(naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?"))
+# answer: Naval argues that humans possess the unique capacity to understand explanations or concepts to the maximum extent possible in this physical reality.
+```
+
+### Chat Interface
+
+- This interface is chat interface where it remembers previous conversation. Right now it remembers 5 conversation by default. 💬
+
+- To use this, call `.chat` function to get the answer for any query.
+
+```python
+print(naval_chat_bot.chat("How to be happy in life?"))
+# answer: The most important trick to being happy is to realize happiness is a skill you develop and a choice you make. You choose to be happy, and then you work at it. It's just like building muscles or succeeding at your job. It's about recognizing the abundance and gifts around you at all times.
+
+print(naval_chat_bot.chat("who is naval ravikant?"))
+# answer: Naval Ravikant is an Indian-American entrepreneur and investor.
+
+print(naval_chat_bot.chat("what did the author say about happiness?"))
+# answer: The author, Naval Ravikant, believes that happiness is a choice you make and a skill you develop. He compares the mind to the body, stating that just as the body can be molded and changed, so can the mind. He emphasizes the importance of being present in the moment and not getting caught up in regrets of the past or worries about the future. By being present and grateful for where you are, you can experience true happiness.
+```
+
+### Stream Response
+
+- You can add config to your query method to stream responses like ChatGPT does. You would require a downstream handler to render the chunk in your desirable format. Supports both OpenAI model and OpenSourceApp. 📊
+
+- To use this, instantiate a `QueryConfig` or `ChatConfig` object with `stream=True`. Then pass it to the `.chat()` or `.query()` method. The following example iterates through the chunks and prints them as they appear.
+
+```python
+app = App()
+query_config = QueryConfig(stream = True)
+resp = app.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?", query_config)
+
+for chunk in resp:
+    print(chunk, end="", flush=True)
+# answer: Naval argues that humans possess the unique capacity to understand explanations or concepts to the maximum extent possible in this physical reality.
+```
+
+### Other Methods
+
+#### Dry Run
+
+Dry run has all the options that `query` has, it just doesn't send the prompt to the LLM, to save money. It's used for [testing](/advanced/testing#dry-run).
+
+#### Reset
+
+Resets the database and deletes all embeddings. Irreversible. Requires reinitialization afterwards.
+
+```python
+app.reset()
+```
+
+#### Count
+
+Counts the number of embeddings (chunks) in the database.
+
+```python
+print(app.count())
+# returns: 481
+```

+ 5 - 5
docs/advanced/query_configuration.mdx

@@ -4,11 +4,11 @@ title: '🔍 Query configurations'
 
 ## AppConfig
 
-| option    | description           | type                            | default                |
-|-----------|-----------------------|---------------------------------|------------------------|
-| log_level | log level             | string                          | WARNING                |
-| ef        | embedding function    | chromadb.utils.embedding_functions | \{text-embedding-ada-002\} |
-| db        | vector database (experimental) | BaseVectorDB               | ChromaDB               |
+| option      | description           | type                            | default                |
+|-------------|-----------------------|---------------------------------|------------------------|
+| log_level   | log level             | string                          | WARNING                |
+| embedding_fn| embedding function    | chromadb.utils.embedding_functions | \{text-embedding-ada-002\} |
+| db          | vector database (experimental) | BaseVectorDB               | ChromaDB               |
 
 
 ## AddConfig

+ 4 - 0
docs/advanced/testing.mdx

@@ -2,6 +2,10 @@
 title: '🧪 Testing'
 ---
 
+## Methods for testing
+
+### Dry Run
+
 Before you consume valueable tokens, you should make sure that the embedding you have done works and that it's receiving the correct document from the database.
 
 For this you can use the `dry_run` method.

+ 1 - 1
docs/mint.json

@@ -32,7 +32,7 @@
     },
     {
       "group": "Advanced",
-      "pages": ["advanced/app_types", "advanced/data_types", "advanced/query_configuration", "advanced/configuration", "advanced/testing", "advanced/showcase"]
+      "pages": ["advanced/app_types", "advanced/interface_types", "advanced/adding_data","advanced/data_types", "advanced/query_configuration", "advanced/configuration", "advanced/testing", "advanced/showcase"]
     },
     {
       "group": "Contribution Guidelines",

+ 6 - 0
docs/quickstart.mdx

@@ -9,6 +9,12 @@ Install embedchain python package:
 pip install embedchain
 ```
 
+Creating a chatbot involves 3 steps:
+
+- ⚙️ Import the App instance
+- 🗃️ Add Dataset
+- 💬 Query or Chat on the dataset and get answers (Interface Types)
+
 Run your first bot in python using the following code. Make sure to set the `OPENAI_API_KEY` 🔑 environment variable in the code.
 
 ```python

+ 2 - 2
embedchain/__init__.py

@@ -3,6 +3,6 @@ import importlib.metadata
 __version__ = importlib.metadata.version(__package__ or __name__)
 
 from embedchain.apps.App import App  # noqa: F401
+from embedchain.apps.CustomApp import CustomApp  # noqa: F401
 from embedchain.apps.OpenSourceApp import OpenSourceApp  # noqa: F401
-from embedchain.apps.PersonApp import (PersonApp,  # noqa: F401
-                                       PersonOpenSourceApp)
+from embedchain.apps.PersonApp import PersonApp, PersonOpenSourceApp  # noqa: F401

+ 1 - 1
embedchain/apps/App.py

@@ -27,7 +27,7 @@ class App(EmbedChain):
         messages = []
         messages.append({"role": "user", "content": prompt})
         response = openai.ChatCompletion.create(
-            model=config.model,
+            model=config.model or "gpt-3.5-turbo-0613",
             messages=messages,
             temperature=config.temperature,
             max_tokens=config.max_tokens,

+ 128 - 0
embedchain/apps/CustomApp.py

@@ -0,0 +1,128 @@
+import logging
+from typing import Iterable, List, Union
+
+from langchain.schema import BaseMessage
+
+from embedchain.config import ChatConfig, CustomAppConfig, OpenSourceAppConfig
+from embedchain.embedchain import EmbedChain
+from embedchain.models import Providers
+
+
+class CustomApp(EmbedChain):
+    """
+    The custom EmbedChain app.
+    Has two functions: add and query.
+
+    adds(data_type, url): adds the data from the given URL to the vector db.
+    query(query): finds answer to the given query using vector database and LLM.
+    dry_run(query): test your prompt without consuming tokens.
+    """
+
+    def __init__(self, config: CustomAppConfig = None):
+        """
+        :param config: Optional. `CustomAppConfig` instance to load as configuration.
+        :raises ValueError: Config must be provided for custom app
+        """
+        if config is None:
+            raise ValueError("Config must be provided for custom app")
+
+        self.provider = config.provider
+
+        if config.provider == Providers.GPT4ALL:
+            from embedchain import OpenSourceApp
+
+            # Because these models run locally, they should have an instance running when the custom app is created
+            self.open_source_app = OpenSourceApp(config=config.open_source_app_config)
+
+        super().__init__(config)
+
+    def set_llm_model(self, provider: Providers):
+        self.provider = provider
+        if provider == Providers.GPT4ALL:
+            raise ValueError(
+                "GPT4ALL needs to be instantiated with the model known, please create a new app instance instead"
+            )
+
+    def get_llm_model_answer(self, prompt, config: ChatConfig):
+        # TODO: Quitting the streaming response here for now.
+        # Idea: https://gist.github.com/jvelezmagic/03ddf4c452d011aae36b2a0f73d72f68
+        if config.stream:
+            raise NotImplementedError(
+                "Streaming responses have not been implemented for this model yet. Please disable."
+            )
+
+        try:
+            if self.provider == Providers.OPENAI:
+                return CustomApp._get_openai_answer(prompt, config)
+
+            if self.provider == Providers.ANTHROPHIC:
+                return CustomApp._get_athrophic_answer(prompt, config)
+
+            if self.provider == Providers.VERTEX_AI:
+                return CustomApp._get_vertex_answer(prompt, config)
+
+            if self.provider == Providers.GPT4ALL:
+                return self.open_source_app._get_gpt4all_answer(prompt, config)
+
+        except ImportError as e:
+            raise ImportError(e.msg) from None
+
+    @staticmethod
+    def _get_openai_answer(prompt: str, config: ChatConfig) -> str:
+        from langchain.chat_models import ChatOpenAI
+
+        logging.info(vars(config))
+
+        chat = ChatOpenAI(
+            temperature=config.temperature,
+            model=config.model or "gpt-3.5-turbo",
+            max_tokens=config.max_tokens,
+            streaming=config.stream,
+        )
+
+        if config.top_p and config.top_p != 1:
+            logging.warning("Config option `top_p` is not supported by this model.")
+
+        messages = CustomApp._get_messages(prompt)
+
+        return chat(messages).content
+
+    @staticmethod
+    def _get_athrophic_answer(prompt: str, config: ChatConfig) -> str:
+        from langchain.chat_models import ChatAnthropic
+
+        chat = ChatAnthropic(temperature=config.temperature, model=config.model)
+
+        if config.max_tokens and config.max_tokens != 1000:
+            logging.warning("Config option `max_tokens` is not supported by this model.")
+
+        messages = CustomApp._get_messages(prompt)
+
+        return chat(messages).content
+
+    @staticmethod
+    def _get_vertex_answer(prompt: str, config: ChatConfig) -> str:
+        from langchain.chat_models import ChatVertexAI
+
+        chat = ChatVertexAI(temperature=config.temperature, model=config.model, max_output_tokens=config.max_tokens)
+
+        if config.top_p and config.top_p != 1:
+            logging.warning("Config option `top_p` is not supported by this model.")
+
+        messages = CustomApp._get_messages(prompt)
+
+        return chat(messages).content
+
+    @staticmethod
+    def _get_messages(prompt: str) -> List[BaseMessage]:
+        from langchain.schema import HumanMessage, SystemMessage
+
+        return [SystemMessage(content="You are a helpful assistant."), HumanMessage(content=prompt)]
+
+    def _stream_llm_model_response(self, response):
+        """
+        This is a generator for streaming response from the OpenAI completions API
+        """
+        for line in response:
+            chunk = line["choices"][0].get("delta", {}).get("content", "")
+            yield chunk

+ 31 - 5
embedchain/apps/OpenSourceApp.py

@@ -1,4 +1,5 @@
 import logging
+from typing import Iterable, List, Union
 
 from embedchain.config import ChatConfig, OpenSourceAppConfig
 from embedchain.embedchain import EmbedChain
@@ -26,14 +27,39 @@ class OpenSourceApp(EmbedChain):
         if not config:
             config = OpenSourceAppConfig()
 
+        if not config.model:
+            raise ValueError("OpenSourceApp needs a model to be instantiated. Maybe you passed the wrong config type?")
+
+        self.instance = OpenSourceApp._get_instance(config.model)
+
         logging.info("Successfully loaded open source embedding model.")
         super().__init__(config)
 
     def get_llm_model_answer(self, prompt, config: ChatConfig):
-        from gpt4all import GPT4All
+        return self._get_gpt4all_answer(prompt=prompt, config=config)
+
+    @staticmethod
+    def _get_instance(model):
+        try:
+            from gpt4all import GPT4All
+        except ModuleNotFoundError:
+            raise ValueError(
+                "The GPT4All python package is not installed. Please install it with `pip install GPT4All`"
+            ) from None
+
+        return GPT4All(model)
+
+    def _get_gpt4all_answer(self, prompt: str, config: ChatConfig) -> Union[str, Iterable]:
+        if config.model and config.model != self.config.model:
+            raise RuntimeError(
+                "OpenSourceApp does not support switching models at runtime. Please create a new app instance."
+            )
 
-        global gpt4all_model
-        if gpt4all_model is None:
-            gpt4all_model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
-        response = gpt4all_model.generate(prompt=prompt, streaming=config.stream)
+        response = self.instance.generate(
+            prompt=prompt,
+            streaming=config.stream,
+            top_p=config.top_p,
+            max_tokens=config.max_tokens,
+            temp=config.temperature,
+        )
         return response

+ 1 - 2
embedchain/apps/PersonApp.py

@@ -4,8 +4,7 @@ from embedchain.apps.App import App
 from embedchain.apps.OpenSourceApp import OpenSourceApp
 from embedchain.config import ChatConfig, QueryConfig
 from embedchain.config.apps.BaseAppConfig import BaseAppConfig
-from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
-                                           DEFAULT_PROMPT_WITH_HISTORY)
+from embedchain.config.QueryConfig import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
 
 
 class EmbedChainPersonApp:

+ 1 - 1
embedchain/config/QueryConfig.py

@@ -104,7 +104,7 @@ class QueryConfig(BaseConfig):
 
         self.temperature = temperature if temperature else 0
         self.max_tokens = max_tokens if max_tokens else 1000
-        self.model = model if model else "gpt-3.5-turbo-0613"
+        self.model = model
         self.top_p = top_p if top_p else 1
 
         if self.validate_template(template):

+ 3 - 1
embedchain/config/apps/AppConfig.py

@@ -18,7 +18,9 @@ class AppConfig(BaseAppConfig):
         :param host: Optional. Hostname for the database server.
         :param port: Optional. Port for the database server.
         """
-        super().__init__(log_level=log_level, ef=AppConfig.default_embedding_function(), host=host, port=port, id=id)
+        super().__init__(
+            log_level=log_level, embedding_fn=AppConfig.default_embedding_function(), host=host, port=port, id=id
+        )
 
     @staticmethod
     def default_embedding_function():

+ 7 - 7
embedchain/config/apps/BaseAppConfig.py

@@ -8,11 +8,11 @@ class BaseAppConfig(BaseConfig):
     Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
     """
 
-    def __init__(self, log_level=None, ef=None, db=None, host=None, port=None, id=None):
+    def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=None, id=None):
         """
         :param log_level: Optional. (String) Debug level
         ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
-        :param ef: Embedding function to use.
+        :param embedding_fn: Embedding function to use.
         :param db: Optional. (Vector) database instance to use for embeddings.
         :param id: Optional. ID of the app. Document metadata will have this id.
         :param host: Optional. Hostname for the database server.
@@ -20,26 +20,26 @@ class BaseAppConfig(BaseConfig):
         """
         self._setup_logging(log_level)
 
-        self.db = db if db else BaseAppConfig.default_db(ef=ef, host=host, port=port)
+        self.db = db if db else BaseAppConfig.default_db(embedding_fn=embedding_fn, host=host, port=port)
         self.id = id
         return
 
     @staticmethod
-    def default_db(ef, host, port):
+    def default_db(embedding_fn, host, port):
         """
         Sets database to default (`ChromaDb`).
 
-        :param ef: Embedding function to use in database.
+        :param embedding_fn: Embedding function to use in database.
         :param host: Optional. Hostname for the database server.
         :param port: Optional. Port for the database server.
         :returns: Default database
         :raises ValueError: BaseAppConfig knows no default embedding function.
         """
-        if ef is None:
+        if embedding_fn is None:
             raise ValueError("ChromaDb cannot be instantiated without an embedding function")
         from embedchain.vectordb.chroma_db import ChromaDB
 
-        return ChromaDB(ef=ef, host=host, port=port)
+        return ChromaDB(embedding_fn=embedding_fn, host=host, port=port)
 
     def _setup_logging(self, debug_level):
         level = logging.WARNING  # Default level

+ 88 - 3
embedchain/config/apps/CustomAppConfig.py

@@ -1,4 +1,15 @@
+import logging
+from typing import Any
+
+from chromadb.api.types import Documents, Embeddings
+from dotenv import load_dotenv
+
+from embedchain.models import EmbeddingFunctions, Providers
+
 from .BaseAppConfig import BaseAppConfig
+from embedchain.models import Providers
+
+load_dotenv()
 
 
 class CustomAppConfig(BaseAppConfig):
@@ -6,14 +17,88 @@ class CustomAppConfig(BaseAppConfig):
     Config to initialize an embedchain custom `App` instance, with extra config options.
     """
 
-    def __init__(self, log_level=None, ef=None, db=None, host=None, port=None, id=None):
+    def __init__(
+        self,
+        log_level=None,
+        embedding_fn: EmbeddingFunctions = None,
+        embedding_fn_model=None,
+        db=None,
+        host=None,
+        port=None,
+        id=None,
+        provider: Providers = None,
+        model=None,
+        open_source_app_config=None,
+    ):
         """
         :param log_level: Optional. (String) Debug level
         ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
-        :param ef: Optional. Embedding function to use.
+        :param embedding_fn: Optional. Embedding function to use.
+        :param embedding_fn_model: Optional. Model name to use for embedding function.
         :param db: Optional. (Vector) database to use for embeddings.
         :param id: Optional. ID of the app. Document metadata will have this id.
         :param host: Optional. Hostname for the database server.
         :param port: Optional. Port for the database server.
+        :param provider: Optional. (Providers): LLM Provider to use.
+        :param open_source_app_config: Optional. Config instance needed for open source apps.
+        """
+        if provider:
+            self.provider = provider
+        else:
+            raise ValueError("CustomApp must have a provider assigned.")
+
+        self.open_source_app_config = open_source_app_config
+
+        super().__init__(
+            log_level=log_level,
+            embedding_fn=CustomAppConfig.embedding_function(embedding_function=embedding_fn, model=embedding_fn_model),
+            db=db,
+            host=host,
+            port=port,
+            id=id,
+        )
+
+    @staticmethod
+    def langchain_default_concept(embeddings: Any):
         """
-        super().__init__(log_level=log_level, db=db, host=host, port=port, id=id)
+        Langchains default function layout for embeddings.
+        """
+
+        def embed_function(texts: Documents) -> Embeddings:
+            return embeddings.embed_documents(texts)
+
+        return embed_function
+
+    @staticmethod
+    def embedding_function(embedding_function: EmbeddingFunctions, model: str = None):
+        if not isinstance(embedding_function, EmbeddingFunctions):
+            raise ValueError(
+                f"Invalid option: '{embedding_function}'. Expecting one of the following options: {list(map(lambda x: x.value, EmbeddingFunctions))}"  # noqa: E501
+            )
+
+        if embedding_function == EmbeddingFunctions.OPENAI:
+            from langchain.embeddings import OpenAIEmbeddings
+
+            if model:
+                embeddings = OpenAIEmbeddings(model=model)
+            else:
+                embeddings = OpenAIEmbeddings()
+            return CustomAppConfig.langchain_default_concept(embeddings)
+
+        elif embedding_function == EmbeddingFunctions.HUGGING_FACE:
+            from langchain.embeddings import HuggingFaceEmbeddings
+
+            embeddings = HuggingFaceEmbeddings(model_name=model)
+            return CustomAppConfig.langchain_default_concept(embeddings)
+
+        elif embedding_function == EmbeddingFunctions.VERTEX_AI:
+            from langchain.embeddings import VertexAIEmbeddings
+
+            embeddings = VertexAIEmbeddings(model_name=model)
+            return CustomAppConfig.langchain_default_concept(embeddings)
+
+        elif embedding_function == EmbeddingFunctions.GPT4ALL:
+            # Note: We could use langchains GPT4ALL embedding, but it's not available in all versions.
+            from chromadb.utils import embedding_functions
+
+            return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model)

+ 9 - 2
embedchain/config/apps/OpenSourceAppConfig.py

@@ -8,16 +8,23 @@ class OpenSourceAppConfig(BaseAppConfig):
     Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options.
     """
 
-    def __init__(self, log_level=None, host=None, port=None, id=None):
+    def __init__(self, log_level=None, host=None, port=None, id=None, model=None):
         """
         :param log_level: Optional. (String) Debug level
         ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
         :param id: Optional. ID of the app. Document metadata will have this id.
         :param host: Optional. Hostname for the database server.
         :param port: Optional. Port for the database server.
+        :param model: Optional. GPT4ALL uses the model to instantiate the class. So unlike `App`, it has to be provided before querying.
         """
+        self.model = model or "orca-mini-3b.ggmlv3.q4_0.bin"
+
         super().__init__(
-            log_level=log_level, ef=OpenSourceAppConfig.default_embedding_function(), host=host, port=port, id=id
+            log_level=log_level,
+            embedding_fn=OpenSourceAppConfig.default_embedding_function(),
+            host=host,
+            port=port,
+            id=id,
         )
 
     @staticmethod

+ 17 - 10
embedchain/embedchain.py

@@ -10,7 +10,7 @@ from embedchain.config.apps.BaseAppConfig import BaseAppConfig
 from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.data_formatter import DataFormatter
 
-gpt4all_model = None
+from chromadb.errors import InvalidDimensionException
 
 load_dotenv()
 
@@ -26,7 +26,7 @@ class EmbedChain:
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.
 
-        :param config: InitConfig instance to load as configuration.
+        :param config: BaseAppConfig instance to load as configuration.
         """
 
         self.config = config
@@ -152,14 +152,21 @@ class EmbedChain:
         :param config: The query configuration.
         :return: The content of the document that matched your query.
         """
-        where = {"app_id": self.config.id} if self.config.id is not None else {}  # optional filter
-        result = self.collection.query(
-            query_texts=[
-                input_query,
-            ],
-            n_results=config.number_documents,
-            where=where,
-        )
+        try:
+            where = {"app_id": self.config.id} if self.config.id is not None else {}  # optional filter
+            result = self.collection.query(
+                query_texts=[
+                    input_query,
+                ],
+                n_results=config.number_documents,
+                where=where,
+            )
+        except InvalidDimensionException as e:
+            raise InvalidDimensionException(
+                e.message()
+                + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database."
+            ) from None
+
         results_formatted = self._format_result(result)
         contents = [result[0].page_content for result in results_formatted]
         return contents

+ 8 - 0
embedchain/models/EmbeddingFunctions.py

@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class EmbeddingFunctions(Enum):
+    OPENAI = "OPENAI"
+    HUGGING_FACE = "HUGGING_FACE"
+    VERTEX_AI = "VERTEX_AI"
+    GPT4ALL = "GPT4ALL"

+ 8 - 0
embedchain/models/Providers.py

@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class Providers(Enum):
+    OPENAI = "OPENAI"
+    ANTHROPHIC = "ANTHPROPIC"
+    VERTEX_AI = "VERTEX_AI"
+    GPT4ALL = "GPT4ALL"

+ 2 - 0
embedchain/models/__init__.py

@@ -0,0 +1,2 @@
+from .EmbeddingFunctions import EmbeddingFunctions  # noqa: F401
+from .Providers import Providers  # noqa: F401

+ 6 - 3
embedchain/vectordb/chroma_db.py

@@ -8,8 +8,11 @@ from embedchain.vectordb.base_vector_db import BaseVectorDB
 class ChromaDB(BaseVectorDB):
     """Vector database using ChromaDB."""
 
-    def __init__(self, db_dir=None, ef=None, host=None, port=None):
-        self.ef = ef
+    def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None):
+        self.embedding_fn = embedding_fn
+
+        if not hasattr(embedding_fn, "__call__"):
+            raise ValueError("Embedding function is not a function")
 
         if host and port:
             logging.info(f"Connecting to ChromaDB server: {host}:{port}")
@@ -36,5 +39,5 @@ class ChromaDB(BaseVectorDB):
         """Get or create the collection."""
         return self.client.get_or_create_collection(
             "embedchain_store",
-            embedding_function=self.ef,
+            embedding_function=self.embedding_fn,
         )

+ 1 - 1
tests/vectordb/test_chroma_db.py

@@ -17,7 +17,7 @@ class TestChromaDbHosts(unittest.TestCase):
         port = "1234"
 
         with patch.object(chromadb, "Client") as mock_client:
-            _db = ChromaDB(host=host, port=port)
+            _db = ChromaDB(host=host, port=port, embedding_fn=len)
 
         expected_settings = chromadb.config.Settings(
             chroma_api_impl="rest",