mistralai.py

import os
from typing import Optional

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class MistralAILlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config)
        if not self.config.api_key and "MISTRAL_API_KEY" not in os.environ:
            raise ValueError("Please set the MISTRAL_API_KEY environment variable or pass it in the config.")

    def get_llm_model_answer(self, prompt):
        return MistralAILlm._get_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_answer(prompt: str, config: BaseLlmConfig):
        try:
            from langchain_core.messages import HumanMessage, SystemMessage
            from langchain_mistralai.chat_models import ChatMistralAI
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for MistralAI are not installed. "
                'Please install with `pip install --upgrade "embedchain[mistralai]"`'
            ) from None

        api_key = config.api_key or os.getenv("MISTRAL_API_KEY")
        client = ChatMistralAI(mistral_api_key=api_key)

        # Build the message list: an optional system prompt followed by the user prompt.
        messages = []
        if config.system_prompt:
            messages.append(SystemMessage(content=config.system_prompt))
        messages.append(HumanMessage(content=prompt))

        kwargs = {
            "model": config.model or "mistral-tiny",
            "temperature": config.temperature,
            "max_tokens": config.max_tokens,
            "top_p": config.top_p,
        }

        if config.stream:
            # Stream the response and concatenate the chunks into a single answer.
            answer = ""
            for chunk in client.stream(**kwargs, input=messages):
                answer += chunk.content
            return answer
        else:
            response = client.invoke(**kwargs, input=messages)
            return response.content
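

# Usage sketch (an assumption, not part of this file): instantiating the class
# directly with a BaseLlmConfig whose fields mirror the attributes the code above
# reads (api_key, model, temperature, max_tokens, top_p, stream, system_prompt).
# The import path assumes this module lives at embedchain/llm/mistralai.py, and
# MISTRAL_API_KEY must be set in the environment or passed via the config.
#
#     from embedchain.config import BaseLlmConfig
#     from embedchain.llm.mistralai import MistralAILlm
#
#     config = BaseLlmConfig(
#         model="mistral-tiny",  # same default _get_answer falls back to
#         temperature=0.5,
#         max_tokens=256,
#         top_p=1.0,
#         stream=False,
#         system_prompt="You are a concise assistant.",
#     )
#     llm = MistralAILlm(config=config)
#     print(llm.get_llm_model_answer("What is retrieval-augmented generation?"))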