فهرست منبع

[Feature] Add support for Groq LLMs (#1284)

Deshraj Yadav 1 سال پیش
والد
کامیت
92dd7edb57
6 فایلهای تغییر یافته به همراه 105 افزوده شده و 3 حذف شده
  1. 57 0
      docs/components/llms.mdx
  2. 1 0
      embedchain/factory.py
  3. 43 0
      embedchain/llm/groq.py
  4. 1 2
      embedchain/llm/openai.py
  5. 2 0
      embedchain/utils/misc.py
  6. 1 1
      pyproject.toml

+ 57 - 0
docs/components/llms.mdx

@@ -22,6 +22,7 @@ Embedchain comes with built-in support for various popular large language models
   <Card title="Vertex AI" href="#vertex-ai"></Card>
   <Card title="Mistral AI" href="#mistral-ai"></Card>
   <Card title="AWS Bedrock" href="#aws-bedrock"></Card>
+  <Card title="Groq" href="#groq"></Card>
 </CardGroup>
 
 ## OpenAI
@@ -654,4 +655,60 @@ llm:
 </Note>
 
 <br/ >
+
+## Groq
+
+[Groq](https://groq.com/) is the creator of the world's first Language Processing Unit (LPU), providing exceptional speed performance for AI workloads running on their LPU Inference Engine.
+
+
+### Usage
+
+To use LLMs from Groq, go to their [platform](https://console.groq.com/keys) and get an API key.
+
+Set the API key as the `GROQ_API_KEY` environment variable, or pass it in your app configuration as shown in the example below.
+
+<CodeGroup>
+
+```python main.py
+import os
+from embedchain import App
+
+# Set your API key here or pass as the environment variable
+groq_api_key = "gsk_xxxx"
+
+config = {
+    "llm": {
+        "provider": "groq",
+        "config": {
+            "model": "mixtral-8x7b-32768",
+            "api_key": groq_api_key,
+            "stream": True
+        }
+    }
+}
+
+app = App.from_config(config=config)
+# Add your data source here
+app.add("https://docs.embedchain.ai/sitemap.xml", data_type="sitemap")
+app.query("Write a poem about Embedchain")
+
+# In the realm of data, vast and wide,
+# Embedchain stands with knowledge as its guide.
+# A platform open, for all to try,
+# Building bots that can truly fly.
+
+# With REST API, data in reach,
+# Deployment a breeze, as easy as a speech.
+# Updating data sources, anytime, anyday,
+# Embedchain's power, never sway.
+
+# A knowledge base, an assistant so grand,
+# Connecting to platforms, near and far.
+# Discord, WhatsApp, Slack, and more,
+# Embedchain's potential, never a bore.
+```
+</CodeGroup>
+
+<br/ >
+
 <Snippet file="missing-llm-tip.mdx" />

+ 1 - 0
embedchain/factory.py

@@ -23,6 +23,7 @@ class LlmFactory:
         "google": "embedchain.llm.google.GoogleLlm",
         "aws_bedrock": "embedchain.llm.aws_bedrock.AWSBedrockLlm",
         "mistralai": "embedchain.llm.mistralai.MistralAILlm",
+        "groq": "embedchain.llm.groq.GroqLlm",
     }
     provider_to_config_class = {
         "embedchain": "embedchain.config.llm.base.BaseLlmConfig",

+ 43 - 0
embedchain/llm/groq.py

@@ -0,0 +1,43 @@
+import os
+from typing import Optional
+
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.schema import HumanMessage, SystemMessage
+
+try:
+    from langchain_groq import ChatGroq
+except ImportError:
+    raise ImportError("Groq requires extra dependencies. Install with `pip install langchain-groq`") from None
+
+
+from embedchain.config import BaseLlmConfig
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
+
+
+@register_deserializable
+class GroqLlm(BaseLlm):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config=config)
+
+    def get_llm_model_answer(self, prompt) -> str:
+        response = self._get_answer(prompt, self.config)
+        return response
+
+    def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
+        messages = []
+        if config.system_prompt:
+            messages.append(SystemMessage(content=config.system_prompt))
+        messages.append(HumanMessage(content=prompt))
+        api_key = config.api_key or os.environ["GROQ_API_KEY"]
+        kwargs = {
+            "model_name": config.model or "mixtral-8x7b-32768",
+            "temperature": config.temperature,
+            "groq_api_key": api_key,
+        }
+        if config.stream:
+            callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
+            chat = ChatGroq(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
+        else:
+            chat = ChatGroq(**kwargs)
+        return chat.invoke(messages).content

+ 1 - 2
embedchain/llm/openai.py

@@ -58,8 +58,7 @@ class OpenAILlm(BaseLlm):
         messages: list[BaseMessage],
     ) -> str:
         from langchain.output_parsers.openai_tools import JsonOutputToolsParser
-        from langchain_core.utils.function_calling import \
-            convert_to_openai_tool
+        from langchain_core.utils.function_calling import convert_to_openai_tool
 
         openai_tools = [convert_to_openai_tool(tools)]
         chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())

+ 2 - 0
embedchain/utils/misc.py

@@ -406,9 +406,11 @@ def validate_config(config_data):
                     "aws_bedrock",
                     "mistralai",
                     "vllm",
+                    "groq",
                 ),
                 Optional("config"): {
                     Optional("model"): str,
+                    Optional("model_name"): str,
                     Optional("number_documents"): int,
                     Optional("temperature"): float,
                     Optional("max_tokens"): int,

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.85"
+version = "0.1.86"
 description = "Simplest open source retrieval(RAG) framework"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",