Browse Source

[Feature] Add support for AWS Bedrock LLM (#1189)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 năm trước cách đây
mục cha
commit
069d265338

+ 2 - 1
docs/api-reference/advanced/configuration.mdx

@@ -200,9 +200,10 @@ Alright, let's dive into what each key means in the yaml config above:
         - `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
         - `prompt` (String): A prompt for the model to follow when generating responses, requires `$context` and `$query` variables.
         - `system_prompt` (String): A system prompt for the model to follow when generating responses, in this case, it's set to the style of William Shakespeare.
-        -  `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
+        - `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
         - `number_documents` (Integer): Number of documents to pull from the vectordb as context, defaults to 1
         - `api_key` (String): The API key for the language model.
+        - `model_kwargs` (Dict): Keyword arguments to pass to the language model. Used for `aws_bedrock` provider, since it requires different arguments for each model.
 3. `vectordb` Section:
     - `provider` (String): The provider for the vector database, set to 'chroma'. You can find the full list of vector database providers in [our docs](/components/vector-databases).
     - `config`:

+ 43 - 5
docs/components/llms.mdx

@@ -21,6 +21,7 @@ Embedchain comes with built-in support for various popular large language models
   <Card title="Llama2" href="#llama2"></Card>
   <Card title="Vertex AI" href="#vertex-ai"></Card>
   <Card title="Mistral AI" href="#mistral-ai"></Card>
+  <Card title="AWS Bedrock" href="#aws-bedrock"></Card>
 </CardGroup>
 
 ## OpenAI
@@ -627,11 +628,8 @@ llm:
 Obtain the Mistral AI api key from their [console](https://console.mistral.ai/).
 
 <CodeGroup>
-
-```python main.py
-import os
-from embedchain import App
-
+ 
+ ```python main.py
 os.environ["MISTRAL_API_KEY"] = "xxx"
 
 app = App.from_config(config_path="config.yaml")
@@ -663,5 +661,45 @@ embedder:
 ```
 </CodeGroup>
 
+
+## AWS Bedrock
+
+### Setup
+- Before using the AWS Bedrock LLM, make sure you have the appropriate model access from [Bedrock Console](https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess).
+- You will also need `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` to authenticate the API with AWS. You can find these in your [AWS Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users).
+
+
+### Usage
+
+<CodeGroup>
+
+```python main.py
+import os
+from embedchain import App
+
+os.environ["AWS_ACCESS_KEY_ID"] = "xxx"
+os.environ["AWS_SECRET_ACCESS_KEY"] = "xxx"
+
+app = App.from_config(config_path="config.yaml")
+```
+
+```yaml config.yaml
+llm:
+  provider: aws_bedrock
+  config:
+    model: amazon.titan-text-express-v1
+    # check notes below for model_kwargs
+    model_kwargs:
+      temperature: 0.5
+      topP: 1
+      maxTokenCount: 1000
+```
+</CodeGroup>
+
+<br />
+<Note>
+  The model arguments are different for each providers. Please refer to the [AWS Bedrock Documentation](https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers) to find the appropriate arguments for your model.
+</Note>
+
 <br/ >
 <Snippet file="missing-llm-tip.mdx" />

+ 1 - 0
embedchain/factory.py

@@ -21,6 +21,7 @@ class LlmFactory:
         "openai": "embedchain.llm.openai.OpenAILlm",
         "vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
         "google": "embedchain.llm.google.GoogleLlm",
+        "aws_bedrock": "embedchain.llm.aws_bedrock.AWSBedrockLlm",
         "mistralai": "embedchain.llm.mistralai.MistralAILlm",
     }
     provider_to_config_class = {

+ 48 - 0
embedchain/llm/aws_bedrock.py

@@ -0,0 +1,48 @@
+from typing import Optional
+
+from langchain.llms import Bedrock
+
+from embedchain.config import BaseLlmConfig
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
+
+
+@register_deserializable
+class AWSBedrockLlm(BaseLlm):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+    def get_llm_model_answer(self, prompt) -> str:
+        response = self._get_answer(prompt, self.config)
+        return response
+
+    def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
+        try:
+            import boto3
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "The required dependencies for AWSBedrock are not installed."
+                'Please install with `pip install --upgrade "embedchain[aws-bedrock]"`'
+            ) from None
+
+        self.boto_client = boto3.client("bedrock-runtime", "us-west-2")
+
+        kwargs = {
+            "model_id": config.model or "amazon.titan-text-express-v1",
+            "client": self.boto_client,
+            "model_kwargs": config.model_kwargs
+            or {
+                "temperature": config.temperature,
+            },
+        }
+
+        if config.stream:
+            from langchain.callbacks.streaming_stdout import \
+                StreamingStdOutCallbackHandler
+
+            callbacks = [StreamingStdOutCallbackHandler()]
+            llm = Bedrock(**kwargs, streaming=config.stream, callbacks=callbacks)
+        else:
+            llm = Bedrock(**kwargs)
+
+        return llm(prompt)

+ 2 - 0
embedchain/utils/misc.py

@@ -406,6 +406,7 @@ def validate_config(config_data):
                     "llama2",
                     "vertexai",
                     "google",
+                    "aws_bedrock",
                     "mistralai",
                 ),
                 Optional("config"): {
@@ -423,6 +424,7 @@ def validate_config(config_data):
                     Optional("query_type"): str,
                     Optional("api_key"): str,
                     Optional("endpoint"): str,
+                    Optional("model_kwargs"): dict,
                 },
             },
             Optional("vectordb"): {

+ 72 - 2
poetry.lock

@@ -383,6 +383,47 @@ files = [
     {file = "blinker-1.6.3.tar.gz", hash = "sha256:152090d27c1c5c722ee7e48504b02d76502811ce02e1523553b4cf8c8b3d3a8d"},
 ]
 
+[[package]]
+name = "boto3"
+version = "1.34.22"
+description = "The AWS SDK for Python"
+optional = true
+python-versions = ">= 3.8"
+files = [
+    {file = "boto3-1.34.22-py3-none-any.whl", hash = "sha256:5909cd1393143576265c692e908a9ae495492c04a0ffd4bae8578adc2e44729e"},
+    {file = "boto3-1.34.22.tar.gz", hash = "sha256:a98c0b86f6044ff8314cc2361e1ef574d674318313ab5606ccb4a6651c7a3f8c"},
+]
+
+[package.dependencies]
+botocore = ">=1.34.22,<1.35.0"
+jmespath = ">=0.7.1,<2.0.0"
+s3transfer = ">=0.10.0,<0.11.0"
+
+[package.extras]
+crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
+
+[[package]]
+name = "botocore"
+version = "1.34.22"
+description = "Low-level, data-driven core of boto 3."
+optional = true
+python-versions = ">= 3.8"
+files = [
+    {file = "botocore-1.34.22-py3-none-any.whl", hash = "sha256:e5f7775975b9213507fbcf846a96b7a2aec2a44fc12a44585197b014a4ab0889"},
+    {file = "botocore-1.34.22.tar.gz", hash = "sha256:c47ba4286c576150d1b6ca6df69a87b5deff3d23bd84da8bcf8431ebac3c40ba"},
+]
+
+[package.dependencies]
+jmespath = ">=0.7.1,<2.0.0"
+python-dateutil = ">=2.1,<3.0.0"
+urllib3 = [
+    {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""},
+    {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""},
+]
+
+[package.extras]
+crt = ["awscrt (==0.19.19)"]
+
 [[package]]
 name = "brotli"
 version = "1.1.0"
@@ -2810,6 +2851,17 @@ MarkupSafe = ">=2.0"
 [package.extras]
 i18n = ["Babel (>=2.7)"]
 
+[[package]]
+name = "jmespath"
+version = "1.0.1"
+description = "JSON Matching Expressions"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"},
+    {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"},
+]
+
 [[package]]
 name = "joblib"
 version = "1.3.2"
@@ -4211,9 +4263,9 @@ files = [
 [package.dependencies]
 numpy = [
     {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
+    {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
     {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
     {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
-    {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
     {version = ">=1.23.5", markers = "python_version >= \"3.11\""},
 ]
 
@@ -6091,6 +6143,23 @@ files = [
     {file = "ruff-0.1.11.tar.gz", hash = "sha256:f9d4d88cb6eeb4dfe20f9f0519bd2eaba8119bde87c3d5065c541dbae2b5a2cb"},
 ]
 
+[[package]]
+name = "s3transfer"
+version = "0.10.0"
+description = "An Amazon S3 Transfer Manager"
+optional = true
+python-versions = ">= 3.8"
+files = [
+    {file = "s3transfer-0.10.0-py3-none-any.whl", hash = "sha256:3cdb40f5cfa6966e812209d0994f2a4709b561c88e90cf00c2696d2df4e56b2e"},
+    {file = "s3transfer-0.10.0.tar.gz", hash = "sha256:d0c8bbf672d5eebbe4e57945e23b972d963f07d82f661cabf678a5c88831595b"},
+]
+
+[package.dependencies]
+botocore = ">=1.33.2,<2.0a.0"
+
+[package.extras]
+crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"]
+
 [[package]]
 name = "safetensors"
 version = "0.4.0"
@@ -8213,6 +8282,7 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
 testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
 
 [extras]
+aws-bedrock = ["boto3"]
 cohere = ["cohere"]
 dataloaders = ["docx2txt", "duckduckgo-search", "pytube", "sentence-transformers", "unstructured", "youtube-transcript-api"]
 discord = ["discord"]
@@ -8246,4 +8316,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.12"
-content-hash = "cb0da55af7c61300bb321770ed319c900b6b3ba3865421d63eb9120beb73d06c"
+content-hash = "bbcf32e87c0784d031fb6cf9bd89655375839da0660b8feb2026ffdd971623d7"

+ 2 - 0
pyproject.toml

@@ -149,6 +149,7 @@ google-auth-oauthlib = { version = "^1.2.0", optional = true }
 google-auth = { version = "^2.25.2", optional = true }
 google-auth-httplib2 = { version = "^0.2.0", optional = true }
 google-api-core = { version = "^2.15.0", optional = true }
+boto3 = { version = "^1.34.20", optional = true }
 langchain-mistralai = { version = "^0.0.3", optional = true }
 
 [tool.poetry.group.dev.dependencies]
@@ -215,6 +216,7 @@ rss_feed = [
 google = ["google-generativeai"]
 modal = ["modal"]
 dropbox = ["dropbox"]
+aws_bedrock = ["boto3"]
 mistralai = ["langchain-mistralai"]
 
 [tool.poetry.group.docs.dependencies]

+ 56 - 0
tests/llm/test_aws_bedrock.py

@@ -0,0 +1,56 @@
+import pytest
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
+from embedchain.config import BaseLlmConfig
+from embedchain.llm.aws_bedrock import AWSBedrockLlm
+
+
+@pytest.fixture
+def config(monkeypatch):
+    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test_access_key_id")
+    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test_secret_access_key")
+    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
+    config = BaseLlmConfig(
+        model="amazon.titan-text-express-v1",
+        model_kwargs={
+            "temperature": 0.5,
+            "topP": 1,
+            "maxTokenCount": 1000,
+        },
+    )
+    yield config
+    monkeypatch.delenv("AWS_ACCESS_KEY_ID")
+    monkeypatch.delenv("AWS_SECRET_ACCESS_KEY")
+    monkeypatch.delenv("OPENAI_API_KEY")
+
+
+def test_get_llm_model_answer(config, mocker):
+    mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")
+
+    llm = AWSBedrockLlm(config)
+    answer = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("Test query", config)
+
+
+def test_get_llm_model_answer_empty_prompt(config, mocker):
+    mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")
+
+    llm = AWSBedrockLlm(config)
+    answer = llm.get_llm_model_answer("")
+
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("", config)
+
+
+def test_get_llm_model_answer_with_streaming(config, mocker):
+    config.stream = True
+    mocked_bedrock_chat = mocker.patch("embedchain.llm.aws_bedrock.Bedrock")
+
+    llm = AWSBedrockLlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_bedrock_chat.assert_called_once()
+    callbacks = [callback[1]["callbacks"] for callback in mocked_bedrock_chat.call_args_list]
+    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)