antrophic_llm.py

import logging
from typing import Optional

from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class AntrophicLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        return AntrophicLlm._get_athrophic_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_athrophic_answer(prompt: str, config: BaseLlmConfig) -> str:
        # Imported lazily so the anthropic dependency is only required
        # when this LLM backend is actually used.
        from langchain.chat_models import ChatAnthropic

        chat = ChatAnthropic(temperature=config.temperature, model=config.model)

        # This backend does not honor `max_tokens`; warn if the user set it
        # to anything other than the config default (1000).
        if config.max_tokens and config.max_tokens != 1000:
            logging.warning("Config option `max_tokens` is not supported by this model.")

        # Build the chat message list (optional system prompt + user prompt).
        messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content
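

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how this class might be wired up. It assumes that
# BaseLlmConfig accepts `temperature`, `model`, and `system_prompt` keyword
# arguments, that "claude-2" is a valid model name for langchain's
# ChatAnthropic, and that ANTHROPIC_API_KEY is set in the environment.
if __name__ == "__main__":
    config = BaseLlmConfig(temperature=0.1, model="claude-2", system_prompt=None)
    llm = AntrophicLlm(config=config)
    print(llm.get_llm_model_answer("Summarize what embedchain does in one sentence."))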