mistralai.py

import os
from typing import Any, Optional, Union

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class MistralAILlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config)
        # Fail fast if no API key is available from either the config or the environment.
        if not self.config.api_key and "MISTRAL_API_KEY" not in os.environ:
            raise ValueError("Please set the MISTRAL_API_KEY environment variable or pass it in the config.")

    def get_llm_model_answer(self, prompt: str) -> Union[str, tuple[str, dict[str, Any]]]:
        if self.config.token_usage:
            response, token_info = self._get_answer(prompt, self.config)
            model_name = "mistralai/" + self.config.model
            if model_name not in self.config.model_pricing_map:
                raise ValueError(
                    f"Model {model_name} not found in `model_prices_and_context_window.json`. "
                    "You can disable token usage by setting `token_usage` to False."
                )
            # Cost = input price per token * prompt tokens + output price per token * completion tokens.
            total_cost = (
                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_tokens"]
                + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["completion_tokens"]
            )
            response_token_info = {
                "prompt_tokens": token_info["prompt_tokens"],
                "completion_tokens": token_info["completion_tokens"],
                "total_tokens": token_info["prompt_tokens"] + token_info["completion_tokens"],
                "total_cost": round(total_cost, 10),
                "cost_currency": "USD",
            }
            return response, response_token_info
        return self._get_answer(prompt, self.config)

    @staticmethod
    def _get_answer(prompt: str, config: BaseLlmConfig):
        try:
            from langchain_core.messages import HumanMessage, SystemMessage
            from langchain_mistralai.chat_models import ChatMistralAI
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for MistralAI are not installed. "
                'Please install with `pip install --upgrade "embedchain[mistralai]"`'
            ) from None
        api_key = config.api_key or os.getenv("MISTRAL_API_KEY")
        client = ChatMistralAI(mistral_api_key=api_key)
        messages = []
        if config.system_prompt:
            messages.append(SystemMessage(content=config.system_prompt))
        messages.append(HumanMessage(content=prompt))
        kwargs = {
            "model": config.model or "mistral-tiny",
            "temperature": config.temperature,
            "max_tokens": config.max_tokens,
            "top_p": config.top_p,
        }
        if config.stream:
            # TODO: Stream chunks to the caller instead of accumulating the full answer here.
            # Note: token usage metadata is not collected on the streaming path.
            answer = ""
            for chunk in client.stream(**kwargs, input=messages):
                answer += chunk.content
            return answer
        chat_response = client.invoke(**kwargs, input=messages)
        if config.token_usage:
            return chat_response.content, chat_response.response_metadata["token_usage"]
        return chat_response.content
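

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving this class directly. It assumes BaseLlmConfig
# accepts keyword arguments matching the attributes read above (model,
# temperature, max_tokens, top_p, stream, token_usage); check
# embedchain.config for the exact constructor signature.
if __name__ == "__main__":
    config = BaseLlmConfig(
        model="mistral-tiny",
        temperature=0.2,
        max_tokens=256,
        top_p=1.0,
        stream=False,
        token_usage=False,  # set True to also receive token counts and cost
    )
    # Requires MISTRAL_API_KEY in the environment (or api_key in the config).
    llm = MistralAILlm(config=config)
    print(llm.get_llm_model_answer("What does embedchain do?"))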