aws_bedrock.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
import os
from typing import Optional

from langchain.llms import Bedrock

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
  6. @register_deserializable
  7. class AWSBedrockLlm(BaseLlm):
  8. def __init__(self, config: Optional[BaseLlmConfig] = None):
  9. super().__init__(config)
  10. def get_llm_model_answer(self, prompt) -> str:
  11. response = self._get_answer(prompt, self.config)
  12. return response
  13. def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
  14. try:
  15. import boto3
  16. except ModuleNotFoundError:
  17. raise ModuleNotFoundError(
  18. "The required dependencies for AWSBedrock are not installed."
  19. 'Please install with `pip install --upgrade "embedchain[aws-bedrock]"`'
  20. ) from None
  21. self.boto_client = boto3.client("bedrock-runtime", "us-west-2")
  22. kwargs = {
  23. "model_id": config.model or "amazon.titan-text-express-v1",
  24. "client": self.boto_client,
  25. "model_kwargs": config.model_kwargs
  26. or {
  27. "temperature": config.temperature,
  28. },
  29. }
  30. if config.stream:
  31. from langchain.callbacks.streaming_stdout import \
  32. StreamingStdOutCallbackHandler
  33. callbacks = [StreamingStdOutCallbackHandler()]
  34. llm = Bedrock(**kwargs, streaming=config.stream, callbacks=callbacks)
  35. else:
  36. llm = Bedrock(**kwargs)
  37. return llm(prompt)