vertex_ai.py

import logging
from typing import Optional

from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class VertexAiLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        return VertexAiLlm._get_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
        # Imported lazily so the Vertex AI dependency is only required
        # when this LLM is actually used.
        from langchain.chat_models import ChatVertexAI

        chat = ChatVertexAI(temperature=config.temperature, model=config.model)

        # `top_p` is not forwarded to the model, so warn rather than
        # silently ignoring the setting.
        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
        return chat(messages).content
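
A minimal usage sketch for this class, assuming Google Cloud credentials are already configured for Vertex AI, that BaseLlmConfig accepts `model` and `temperature` keyword arguments, and that `chat-bison` is an available model name (all three are assumptions, not taken from this file):

# Hypothetical usage example; parameter names and the model id are assumptions.
from embedchain.config import BaseLlmConfig
from embedchain.llm.vertex_ai import VertexAiLlm

config = BaseLlmConfig(model="chat-bison", temperature=0.2)
llm = VertexAiLlm(config=config)

# Sends the prompt through ChatVertexAI and returns the response text.
print(llm.get_llm_model_answer("Summarize what Vertex AI offers."))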