Pārlūkot izejas kodu

[Feature] Add support for Mistral API (#1194)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 gadu atpakaļ
vecāks
revīzija
cb0499407e

+ 43 - 0
docs/components/llms.mdx

@@ -20,6 +20,7 @@ Embedchain comes with built-in support for various popular large language models
   <Card title="Hugging Face" href="#hugging-face"></Card>
   <Card title="Llama2" href="#llama2"></Card>
   <Card title="Vertex AI" href="#vertex-ai"></Card>
+  <Card title="Mistral AI" href="#mistral-ai"></Card>
 </CardGroup>
 
 ## OpenAI
@@ -620,5 +621,47 @@ llm:
 ```
 </CodeGroup>
 
+
+## Mistral AI
+
+Obtain the Mistral AI API key from their [console](https://console.mistral.ai/).
+
+<CodeGroup>
+
+```python main.py
+import os
+from embedchain import App
+
+os.environ["MISTRAL_API_KEY"] = "xxx"
+
+app = App.from_config(config_path="config.yaml")
+
+app.add("https://www.forbes.com/profile/elon-musk")
+
+response = app.query("what is the net worth of Elon Musk?")
+# As of January 16, 2024, Elon Musk's net worth is $225.4 billion.
+
+response = app.chat("which companies does elon own?")
+# Elon Musk owns Tesla, SpaceX, Boring Company, Twitter, and X.
+
+response = app.chat("what question did I ask you already?")
+# You have asked me several times already which companies Elon Musk owns, specifically Tesla, SpaceX, Boring Company, Twitter, and X.
+```
+  
+```yaml config.yaml
+llm:
+  provider: mistralai
+  config:
+    model: mistral-tiny
+    temperature: 0.5
+    max_tokens: 1000
+    top_p: 1
+embedder:
+  provider: mistralai
+  config:
+    model: mistral-embed
+```
+</CodeGroup>
+
 <br/ >
 <Snippet file="missing-llm-tip.mdx" />

+ 46 - 0
embedchain/embedder/mistralai.py

@@ -0,0 +1,46 @@
+import os
+from typing import Optional, Union
+
+from chromadb import EmbeddingFunction, Embeddings
+
+from embedchain.config import BaseEmbedderConfig
+from embedchain.embedder.base import BaseEmbedder
+from embedchain.models import VectorDimensions
+
+
class MistralAIEmbeddingFunction(EmbeddingFunction):
    """Chroma-compatible embedding function backed by MistralAI embeddings.

    Requires the optional `mistralai` extra; the API key is taken from the
    config or, failing that, the MISTRAL_API_KEY environment variable.
    """

    def __init__(self, config: BaseEmbedderConfig) -> None:
        super().__init__()
        try:
            from langchain_mistralai import MistralAIEmbeddings
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                # Bug fix: the two concatenated literals previously produced
                # "installed.Please" with no separating space.
                "The required dependencies for MistralAI are not installed. "
                'Please install with `pip install --upgrade "embedchain[mistralai]"`'
            ) from None
        self.config = config
        # An explicit api_key in the config wins over the environment variable.
        api_key = self.config.api_key or os.getenv("MISTRAL_API_KEY")
        self.client = MistralAIEmbeddings(mistral_api_key=api_key)
        self.client.model = self.config.model

    def __call__(self, input: Union[list[str], str]) -> Embeddings:
        """Embed a single string or a list of strings.

        A bare string is wrapped in a one-element list, so the return value is
        always a list of embedding vectors.
        """
        input_ = [input] if isinstance(input, str) else input
        return self.client.embed_documents(input_)
+
+
class MistralAIEmbedder(BaseEmbedder):
    """Embedder provider wiring MistralAI embeddings into embedchain."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        # Default to Mistral's embedding model when none was configured.
        if self.config.model is None:
            self.config.model = "mistral-embed"

        self.set_embedding_fn(embedding_fn=MistralAIEmbeddingFunction(config=self.config))

        # Fall back to the known dimensionality of mistral-embed vectors.
        dimension = self.config.vector_dimension or VectorDimensions.MISTRAL_AI.value
        self.set_vector_dimension(vector_dimension=dimension)

+ 2 - 0
embedchain/factory.py

@@ -21,6 +21,7 @@ class LlmFactory:
         "openai": "embedchain.llm.openai.OpenAILlm",
         "vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
         "google": "embedchain.llm.google.GoogleLlm",
+        "mistralai": "embedchain.llm.mistralai.MistralAILlm",
     }
     provider_to_config_class = {
         "embedchain": "embedchain.config.llm.base.BaseLlmConfig",
@@ -50,6 +51,7 @@ class EmbedderFactory:
         "openai": "embedchain.embedder.openai.OpenAIEmbedder",
         "vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
         "google": "embedchain.embedder.google.GoogleAIEmbedder",
+        "mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
     }
     provider_to_config_class = {
         "azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",

+ 52 - 0
embedchain/llm/mistralai.py

@@ -0,0 +1,52 @@
+import os
+from typing import Optional
+
+from embedchain.config import BaseLlmConfig
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
+
+
@register_deserializable
class MistralAILlm(BaseLlm):
    """Embedchain LLM provider backed by MistralAI chat models.

    Requires the optional `mistralai` extra and an API key, supplied either
    via the config or the MISTRAL_API_KEY environment variable.
    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config)
        # Fail fast: surface a missing API key at construction time rather
        # than on the first query.
        if not self.config.api_key and "MISTRAL_API_KEY" not in os.environ:
            raise ValueError("Please set the MISTRAL_API_KEY environment variable or pass it in the config.")

    def get_llm_model_answer(self, prompt):
        """Return the model's answer for `prompt`, using this instance's config."""
        return MistralAILlm._get_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_answer(prompt: str, config: BaseLlmConfig):
        """Call the MistralAI chat model and return the answer text.

        Honors `config.system_prompt`, the model/sampling settings, and
        `config.stream`. Raises ModuleNotFoundError when the optional
        dependencies are not installed.
        """
        try:
            from langchain_core.messages import HumanMessage, SystemMessage
            from langchain_mistralai.chat_models import ChatMistralAI
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                # Bug fix: the two concatenated literals previously produced
                # "installed.Please" with no separating space.
                "The required dependencies for MistralAI are not installed. "
                'Please install with `pip install --upgrade "embedchain[mistralai]"`'
            ) from None

        # An explicit api_key in the config wins over the environment variable.
        api_key = config.api_key or os.getenv("MISTRAL_API_KEY")
        client = ChatMistralAI(mistral_api_key=api_key)

        messages = []
        if config.system_prompt:
            messages.append(SystemMessage(content=config.system_prompt))
        messages.append(HumanMessage(content=prompt))

        kwargs = {
            "model": config.model or "mistral-tiny",
            "temperature": config.temperature,
            "max_tokens": config.max_tokens,
            "top_p": config.top_p,
        }

        # Streaming IS supported here (the original TODO comment was stale):
        # accumulate streamed chunks into a single answer string.
        if config.stream:
            answer = ""
            for chunk in client.stream(**kwargs, input=messages):
                answer += chunk.content
            return answer
        return client.invoke(**kwargs, input=messages).content

+ 1 - 0
embedchain/models/vector_dimensions.py

@@ -8,3 +8,4 @@ class VectorDimensions(Enum):
     VERTEX_AI = 768
     HUGGING_FACE = 384
     GOOGLE_AI = 768
+    MISTRAL_AI = 1024

+ 19 - 2
embedchain/utils/misc.py

@@ -406,6 +406,7 @@ def validate_config(config_data):
                     "llama2",
                     "vertexai",
                     "google",
+                    "mistralai",
                 ),
                 Optional("config"): {
                     Optional("model"): str,
@@ -431,7 +432,15 @@ def validate_config(config_data):
                 Optional("config"): object,  # TODO: add particular config schema for each provider
             },
             Optional("embedder"): {
-                Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
+                Optional("provider"): Or(
+                    "openai",
+                    "gpt4all",
+                    "huggingface",
+                    "vertexai",
+                    "azure_openai",
+                    "google",
+                    "mistralai",
+                ),
                 Optional("config"): {
                     Optional("model"): Optional(str),
                     Optional("deployment_name"): Optional(str),
@@ -442,7 +451,15 @@ def validate_config(config_data):
                 },
             },
             Optional("embedding_model"): {
-                Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
+                Optional("provider"): Or(
+                    "openai",
+                    "gpt4all",
+                    "huggingface",
+                    "vertexai",
+                    "azure_openai",
+                    "google",
+                    "mistralai",
+                ),
                 Optional("config"): {
                     Optional("model"): str,
                     Optional("deployment_name"): str,

+ 126 - 10
poetry.lock

@@ -2487,24 +2487,24 @@ files = [
 
 [[package]]
 name = "httpcore"
-version = "0.18.0"
+version = "1.0.2"
 description = "A minimal low-level HTTP client."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "httpcore-0.18.0-py3-none-any.whl", hash = "sha256:adc5398ee0a476567bf87467063ee63584a8bce86078bf748e48754f60202ced"},
-    {file = "httpcore-0.18.0.tar.gz", hash = "sha256:13b5e5cd1dca1a6636a6aaea212b19f4f85cd88c366a2b82304181b769aab3c9"},
+    {file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"},
+    {file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"},
 ]
 
 [package.dependencies]
-anyio = ">=3.0,<5.0"
 certifi = "*"
 h11 = ">=0.13,<0.15"
-sniffio = "==1.*"
 
 [package.extras]
+asyncio = ["anyio (>=4.0,<5.0)"]
 http2 = ["h2 (>=3,<5)"]
 socks = ["socksio (==1.*)"]
+trio = ["trio (>=0.22.0,<0.23.0)"]
 
 [[package]]
 name = "httplib2"
@@ -2569,21 +2569,22 @@ test = ["Cython (>=0.29.24,<0.30.0)"]
 
 [[package]]
 name = "httpx"
-version = "0.25.0"
+version = "0.25.2"
 description = "The next generation HTTP client."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "httpx-0.25.0-py3-none-any.whl", hash = "sha256:181ea7f8ba3a82578be86ef4171554dd45fec26a02556a744db029a0a27b7100"},
-    {file = "httpx-0.25.0.tar.gz", hash = "sha256:47ecda285389cb32bb2691cc6e069e3ab0205956f681c5b2ad2325719751d875"},
+    {file = "httpx-0.25.2-py3-none-any.whl", hash = "sha256:a05d3d052d9b2dfce0e3896636467f8a5342fb2b902c819428e1ac65413ca118"},
+    {file = "httpx-0.25.2.tar.gz", hash = "sha256:8b8fcaa0c8ea7b05edd69a094e63a2094c4efcb48129fb757361bc423c0ad9e8"},
 ]
 
 [package.dependencies]
+anyio = "*"
 brotli = {version = "*", optional = true, markers = "platform_python_implementation == \"CPython\" and extra == \"brotli\""}
 brotlicffi = {version = "*", optional = true, markers = "platform_python_implementation != \"CPython\" and extra == \"brotli\""}
 certifi = "*"
 h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""}
-httpcore = ">=0.18.0,<0.19.0"
+httpcore = "==1.*"
 idna = "*"
 sniffio = "*"
 socksio = {version = "==1.*", optional = true, markers = "extra == \"socks\""}
@@ -3024,6 +3025,45 @@ openai = ["openai (<2)", "tiktoken (>=0.3.2,<0.6.0)"]
 qdrant = ["qdrant-client (>=1.3.1,<2.0.0)"]
 text-helpers = ["chardet (>=5.1.0,<6.0.0)"]
 
+[[package]]
+name = "langchain-core"
+version = "0.1.12"
+description = "Building applications with LLMs through composability"
+optional = true
+python-versions = ">=3.8.1,<4.0"
+files = [
+    {file = "langchain_core-0.1.12-py3-none-any.whl", hash = "sha256:d11c6262f7a9deff7de8fdf14498b8a951020dfed3a80f2358ab731ad04abef0"},
+    {file = "langchain_core-0.1.12.tar.gz", hash = "sha256:f18e9300e9a07589b3e280e51befbc5a4513f535949406e55eb7a2dc40c3ce66"},
+]
+
+[package.dependencies]
+anyio = ">=3,<5"
+jsonpatch = ">=1.33,<2.0"
+langsmith = ">=0.0.63,<0.1.0"
+packaging = ">=23.2,<24.0"
+pydantic = ">=1,<3"
+PyYAML = ">=5.3"
+requests = ">=2,<3"
+tenacity = ">=8.1.0,<9.0.0"
+
+[package.extras]
+extended-testing = ["jinja2 (>=3,<4)"]
+
+[[package]]
+name = "langchain-mistralai"
+version = "0.0.3"
+description = "An integration package connecting Mistral and LangChain"
+optional = true
+python-versions = ">=3.8.1,<4.0"
+files = [
+    {file = "langchain_mistralai-0.0.3-py3-none-any.whl", hash = "sha256:ebb8ba3d7978b5ee16f7e09512ffa434e00bc9863f1537f1a5f5203882d99619"},
+    {file = "langchain_mistralai-0.0.3.tar.gz", hash = "sha256:2e45ee0118df8e4b5577ce8c4f89743059801e473f40a8b7c89cb99dd715f423"},
+]
+
+[package.dependencies]
+langchain-core = ">=0.1,<0.2"
+mistralai = ">=0.0.11,<0.0.12"
+
 [[package]]
 name = "langdetect"
 version = "1.0.9"
@@ -3458,6 +3498,22 @@ files = [
 certifi = "*"
 urllib3 = "*"
 
+[[package]]
+name = "mistralai"
+version = "0.0.11"
+description = ""
+optional = true
+python-versions = ">=3.8,<4.0"
+files = [
+    {file = "mistralai-0.0.11-py3-none-any.whl", hash = "sha256:fb2a240a3985420c4e7db48eb5077d6d6dbc5e83cac0dd948c20342fb48087ee"},
+    {file = "mistralai-0.0.11.tar.gz", hash = "sha256:383072715531198305dab829ab3749b64933bbc2549354f3c9ebc43c17b912cf"},
+]
+
+[package.dependencies]
+httpx = ">=0.25.2,<0.26.0"
+orjson = ">=3.9.10,<4.0.0"
+pydantic = ">=2.5.2,<3.0.0"
+
 [[package]]
 name = "mock"
 version = "5.1.0"
@@ -4294,6 +4350,65 @@ files = [
     {file = "opentelemetry_semantic_conventions-0.42b0.tar.gz", hash = "sha256:44ae67a0a3252a05072877857e5cc1242c98d4cf12870159f1a94bec800d38ec"},
 ]
 
+[[package]]
+name = "orjson"
+version = "3.9.12"
+description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "orjson-3.9.12-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:6b4e2bed7d00753c438e83b613923afdd067564ff7ed696bfe3a7b073a236e07"},
+    {file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd1b8ec63f0bf54a50b498eedeccdca23bd7b658f81c524d18e410c203189365"},
+    {file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ab8add018a53665042a5ae68200f1ad14c7953fa12110d12d41166f111724656"},
+    {file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12756a108875526b76e505afe6d6ba34960ac6b8c5ec2f35faf73ef161e97e07"},
+    {file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:890e7519c0c70296253660455f77e3a194554a3c45e42aa193cdebc76a02d82b"},
+    {file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d664880d7f016efbae97c725b243b33c2cbb4851ddc77f683fd1eec4a7894146"},
+    {file = "orjson-3.9.12-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:cfdaede0fa5b500314ec7b1249c7e30e871504a57004acd116be6acdda3b8ab3"},
+    {file = "orjson-3.9.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6492ff5953011e1ba9ed1bf086835fd574bd0a3cbe252db8e15ed72a30479081"},
+    {file = "orjson-3.9.12-cp310-none-win32.whl", hash = "sha256:29bf08e2eadb2c480fdc2e2daae58f2f013dff5d3b506edd1e02963b9ce9f8a9"},
+    {file = "orjson-3.9.12-cp310-none-win_amd64.whl", hash = "sha256:0fc156fba60d6b50743337ba09f052d8afc8b64595112996d22f5fce01ab57da"},
+    {file = "orjson-3.9.12-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:2849f88a0a12b8d94579b67486cbd8f3a49e36a4cb3d3f0ab352c596078c730c"},
+    {file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3186b18754befa660b31c649a108a915493ea69b4fc33f624ed854ad3563ac65"},
+    {file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cbbf313c9fb9d4f6cf9c22ced4b6682230457741daeb3d7060c5d06c2e73884a"},
+    {file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99e8cd005b3926c3db9b63d264bd05e1bf4451787cc79a048f27f5190a9a0311"},
+    {file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59feb148392d9155f3bfed0a2a3209268e000c2c3c834fb8fe1a6af9392efcbf"},
+    {file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4ae815a172a1f073b05b9e04273e3b23e608a0858c4e76f606d2d75fcabde0c"},
+    {file = "orjson-3.9.12-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ed398f9a9d5a1bf55b6e362ffc80ac846af2122d14a8243a1e6510a4eabcb71e"},
+    {file = "orjson-3.9.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d3cfb76600c5a1e6be91326b8f3b83035a370e727854a96d801c1ea08b708073"},
+    {file = "orjson-3.9.12-cp311-none-win32.whl", hash = "sha256:a2b6f5252c92bcab3b742ddb3ac195c0fa74bed4319acd74f5d54d79ef4715dc"},
+    {file = "orjson-3.9.12-cp311-none-win_amd64.whl", hash = "sha256:c95488e4aa1d078ff5776b58f66bd29d628fa59adcb2047f4efd3ecb2bd41a71"},
+    {file = "orjson-3.9.12-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:d6ce2062c4af43b92b0221ed4f445632c6bf4213f8a7da5396a122931377acd9"},
+    {file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:950951799967558c214cd6cceb7ceceed6f81d2c3c4135ee4a2c9c69f58aa225"},
+    {file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2dfaf71499d6fd4153f5c86eebb68e3ec1bf95851b030a4b55c7637a37bbdee4"},
+    {file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:659a8d7279e46c97661839035a1a218b61957316bf0202674e944ac5cfe7ed83"},
+    {file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af17fa87bccad0b7f6fd8ac8f9cbc9ee656b4552783b10b97a071337616db3e4"},
+    {file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd52dec9eddf4c8c74392f3fd52fa137b5f2e2bed1d9ae958d879de5f7d7cded"},
+    {file = "orjson-3.9.12-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:640e2b5d8e36b970202cfd0799d11a9a4ab46cf9212332cd642101ec952df7c8"},
+    {file = "orjson-3.9.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:daa438bd8024e03bcea2c5a92cd719a663a58e223fba967296b6ab9992259dbf"},
+    {file = "orjson-3.9.12-cp312-none-win_amd64.whl", hash = "sha256:1bb8f657c39ecdb924d02e809f992c9aafeb1ad70127d53fb573a6a6ab59d549"},
+    {file = "orjson-3.9.12-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:f4098c7674901402c86ba6045a551a2ee345f9f7ed54eeffc7d86d155c8427e5"},
+    {file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5586a533998267458fad3a457d6f3cdbddbcce696c916599fa8e2a10a89b24d3"},
+    {file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:54071b7398cd3f90e4bb61df46705ee96cb5e33e53fc0b2f47dbd9b000e238e1"},
+    {file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:67426651faa671b40443ea6f03065f9c8e22272b62fa23238b3efdacd301df31"},
+    {file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4a0cd56e8ee56b203abae7d482ac0d233dbfb436bb2e2d5cbcb539fe1200a312"},
+    {file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a84a0c3d4841a42e2571b1c1ead20a83e2792644c5827a606c50fc8af7ca4bee"},
+    {file = "orjson-3.9.12-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:09d60450cda3fa6c8ed17770c3a88473a16460cd0ff2ba74ef0df663b6fd3bb8"},
+    {file = "orjson-3.9.12-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bc82a4db9934a78ade211cf2e07161e4f068a461c1796465d10069cb50b32a80"},
+    {file = "orjson-3.9.12-cp38-none-win32.whl", hash = "sha256:61563d5d3b0019804d782137a4f32c72dc44c84e7d078b89d2d2a1adbaa47b52"},
+    {file = "orjson-3.9.12-cp38-none-win_amd64.whl", hash = "sha256:410f24309fbbaa2fab776e3212a81b96a1ec6037259359a32ea79fbccfcf76aa"},
+    {file = "orjson-3.9.12-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e773f251258dd82795fd5daeac081d00b97bacf1548e44e71245543374874bcf"},
+    {file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b159baecfda51c840a619948c25817d37733a4d9877fea96590ef8606468b362"},
+    {file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:975e72e81a249174840d5a8df977d067b0183ef1560a32998be340f7e195c730"},
+    {file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:06e42e899dde61eb1851a9fad7f1a21b8e4be063438399b63c07839b57668f6c"},
+    {file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c157e999e5694475a5515942aebeed6e43f7a1ed52267c1c93dcfde7d78d421"},
+    {file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dde1bc7c035f2d03aa49dc8642d9c6c9b1a81f2470e02055e76ed8853cfae0c3"},
+    {file = "orjson-3.9.12-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b0e9d73cdbdad76a53a48f563447e0e1ce34bcecef4614eb4b146383e6e7d8c9"},
+    {file = "orjson-3.9.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:96e44b21fe407b8ed48afbb3721f3c8c8ce17e345fbe232bd4651ace7317782d"},
+    {file = "orjson-3.9.12-cp39-none-win32.whl", hash = "sha256:cbd0f3555205bf2a60f8812133f2452d498dbefa14423ba90fe89f32276f7abf"},
+    {file = "orjson-3.9.12-cp39-none-win_amd64.whl", hash = "sha256:03ea7ee7e992532c2f4a06edd7ee1553f0644790553a118e003e3c405add41fa"},
+    {file = "orjson-3.9.12.tar.gz", hash = "sha256:da908d23a3b3243632b523344403b128722a5f45e278a8343c2bb67538dff0e4"},
+]
+
 [[package]]
 name = "overrides"
 version = "7.4.0"
@@ -8110,6 +8225,7 @@ googledrive = ["google-api-python-client", "google-auth-httplib2", "google-auth-
 huggingface-hub = ["huggingface_hub"]
 llama2 = ["replicate"]
 milvus = ["pymilvus"]
+mistralai = ["langchain-mistralai"]
 modal = ["modal"]
 mysql = ["mysql-connector-python"]
 opensearch = ["opensearch-py"]
@@ -8130,4 +8246,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.12"
-content-hash = "02bd85e14374a9dc9b59523b8fb4baea7068251976ba7f87722cac94a9974ccc"
+content-hash = "cb0da55af7c61300bb321770ed319c900b6b3ba3865421d63eb9120beb73d06c"

+ 2 - 0
pyproject.toml

@@ -149,6 +149,7 @@ google-auth-oauthlib = { version = "^1.2.0", optional = true }
 google-auth = { version = "^2.25.2", optional = true }
 google-auth-httplib2 = { version = "^0.2.0", optional = true }
 google-api-core = { version = "^2.15.0", optional = true }
+langchain-mistralai = { version = "^0.0.3", optional = true }
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"
@@ -214,6 +215,7 @@ rss_feed = [
 google = ["google-generativeai"]
 modal = ["modal"]
 dropbox = ["dropbox"]
+mistralai = ["langchain-mistralai"]
 
 [tool.poetry.group.docs.dependencies]
 

+ 60 - 0
tests/llm/test_mistralai.py

@@ -0,0 +1,60 @@
+import pytest
+
+from embedchain.config import BaseLlmConfig
+from embedchain.llm.mistralai import MistralAILlm
+
+
@pytest.fixture
def mistralai_llm_config(monkeypatch):
    """Yield a BaseLlmConfig with a fake MISTRAL_API_KEY present in the env.

    monkeypatch automatically restores the environment at teardown, so the
    explicit post-yield delenv in the original was redundant and is removed.
    """
    monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
    yield BaseLlmConfig(model="mistral-tiny", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
+
+
def test_mistralai_llm_init_missing_api_key(monkeypatch):
    """Constructing without an API key (env or config) must raise ValueError."""
    monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
    expected = "Please set the MISTRAL_API_KEY environment variable."
    with pytest.raises(ValueError, match=expected):
        MistralAILlm()
+
+
def test_mistralai_llm_init(monkeypatch):
    """With the env var present, construction succeeds."""
    monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
    assert MistralAILlm() is not None
+
+
def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
    """get_llm_model_answer delegates to _get_answer and returns its result."""
    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
    llm = MistralAILlm(config=mistralai_llm_config)
    assert llm.get_llm_model_answer("test prompt") == "Generated Text"
+
+
def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_config):
    """A configured system prompt does not change the delegated result."""
    mistralai_llm_config.system_prompt = "Test system prompt"
    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
    llm = MistralAILlm(config=mistralai_llm_config)
    assert llm.get_llm_model_answer("test prompt") == "Generated Text"
+
+
def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
    """An empty prompt is still passed through to _get_answer."""
    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
    llm = MistralAILlm(config=mistralai_llm_config)
    assert llm.get_llm_model_answer("") == "Generated Text"
+
+
def test_get_llm_model_answer_without_system_prompt(monkeypatch, mistralai_llm_config):
    """With system_prompt cleared, delegation still returns the mocked answer."""
    mistralai_llm_config.system_prompt = None
    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
    llm = MistralAILlm(config=mistralai_llm_config)
    assert llm.get_llm_model_answer("test prompt") == "Generated Text"