mistralai.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import os
  2. from typing import Optional, Union
  3. from chromadb import EmbeddingFunction, Embeddings
  4. from embedchain.config import BaseEmbedderConfig
  5. from embedchain.embedder.base import BaseEmbedder
  6. from embedchain.models import VectorDimensions
  7. class MistralAIEmbeddingFunction(EmbeddingFunction):
  8. def __init__(self, config: BaseEmbedderConfig) -> None:
  9. super().__init__()
  10. try:
  11. from langchain_mistralai import MistralAIEmbeddings
  12. except ModuleNotFoundError:
  13. raise ModuleNotFoundError(
  14. "The required dependencies for MistralAI are not installed."
  15. 'Please install with `pip install --upgrade "embedchain[mistralai]"`'
  16. ) from None
  17. self.config = config
  18. api_key = self.config.api_key or os.getenv("MISTRAL_API_KEY")
  19. self.client = MistralAIEmbeddings(mistral_api_key=api_key)
  20. self.client.model = self.config.model
  21. def __call__(self, input: Union[list[str], str]) -> Embeddings:
  22. if isinstance(input, str):
  23. input_ = [input]
  24. else:
  25. input_ = input
  26. response = self.client.embed_documents(input_)
  27. return response
  28. class MistralAIEmbedder(BaseEmbedder):
  29. def __init__(self, config: Optional[BaseEmbedderConfig] = None):
  30. super().__init__(config)
  31. if self.config.model is None:
  32. self.config.model = "mistral-embed"
  33. embedding_fn = MistralAIEmbeddingFunction(config=self.config)
  34. self.set_embedding_fn(embedding_fn=embedding_fn)
  35. vector_dimension = self.config.vector_dimension or VectorDimensions.MISTRAL_AI.value
  36. self.set_vector_dimension(vector_dimension=vector_dimension)