clarifai.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import os
  2. from typing import Optional, Union
  3. from embedchain.config import BaseEmbedderConfig
  4. from embedchain.embedder.base import BaseEmbedder
  5. from chromadb import EmbeddingFunction, Embeddings
  class ClarifaiEmbeddingFunction(EmbeddingFunction):
      """Embedding function that obtains text embeddings from a Clarifai model.

      The Clarifai client classes are imported inside ``__init__`` so the
      dependency is only needed when this class is actually instantiated.
      """

      # Number of texts sent to the model in a single predict call.
      _BATCH_SIZE = 32

      def __init__(self, config: BaseEmbedderConfig) -> None:
          """Build the Clarifai model and input helpers from ``config``.

          Args:
              config: Embedder configuration. ``config.model`` is passed as the
                  model URL; ``config.api_key`` is used as the PAT, falling back
                  to the ``CLARIFAI_PAT`` environment variable when it is falsy.

          Raises:
              ModuleNotFoundError: If the Clarifai client modules cannot be
                  imported.
          """
          super().__init__()

          try:
              from clarifai.client.input import Inputs
              from clarifai.client.model import Model
          except ModuleNotFoundError:
              # Adjacent literals are concatenated verbatim, so the separating
              # space after the first sentence has to be written explicitly.
              raise ModuleNotFoundError(
                  "The required dependencies for ClarifaiEmbeddingFunction are not installed. "
                  'Please install with `pip install --upgrade "embedchain[clarifai]"`'
              ) from None

          self.config = config
          self.api_key = config.api_key or os.getenv("CLARIFAI_PAT")
          self.model = config.model
          self.model_obj = Model(url=self.model, pat=self.api_key)
          self.input_obj = Inputs(pat=self.api_key)

      def __call__(self, input: Union[str, list[str]]) -> Embeddings:
          """Embed a single text or a list of texts.

          Args:
              input: One string or a list of strings. The parameter name is part
                  of the public signature and is kept even though it shadows the
                  builtin.

          Returns:
              One embedding vector (as a list) per processed text, in input
              order. If a predict call raises, the error is printed and only
              the embeddings collected from the batches completed before the
              failure are returned.
          """
          texts = [input] if isinstance(input, str) else input
          step = self._BATCH_SIZE
          embeddings = []

          try:
              for start in range(0, len(texts), step):
                  batch = texts[start : start + step]
                  # Input ids are the positions within the current batch, so
                  # they restart at "0" for every predict call.
                  input_batch = [
                      self.input_obj.get_text_input(input_id=str(position), raw_text=text)
                      for position, text in enumerate(batch)
                  ]
                  response = self.model_obj.predict(input_batch)
                  # Build the full list first so a failure while reading one
                  # output does not leave a half-added batch behind.
                  embeddings.extend([list(output.data.embeddings[0].vector) for output in response.outputs])
          except Exception as e:
              # NOTE(review): best-effort by design — the caller receives fewer
              # vectors than inputs when a batch fails.
              print(f"Predict failed, exception: {e}")

          return embeddings
  class ClarifaiEmbedder(BaseEmbedder):
      """Embedder whose embedding function is a ``ClarifaiEmbeddingFunction``."""

      def __init__(self, config: Optional[BaseEmbedderConfig] = None):
          """Initialise the base embedder and register the Clarifai function.

          Args:
              config: Optional embedder configuration, forwarded unchanged to
                  the base class initialiser.
          """
          super().__init__(config)

          # The embedding function is built from ``self.config`` (read after the
          # base initialiser has run), not from the raw ``config`` argument —
          # presumably the base class stores or defaults it; confirm there.
          self.set_embedding_fn(embedding_fn=ClarifaiEmbeddingFunction(config=self.config))