12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
import logging
import os
from typing import Optional, Union

from chromadb import EmbeddingFunction, Embeddings

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
logger = logging.getLogger(__name__)


class ClarifaiEmbeddingFunction(EmbeddingFunction):
    """Chroma-compatible embedding function backed by a Clarifai model.

    Texts are sent to the Clarifai model identified by ``config.model`` in
    batches. For each prediction output, the first embedding vector is
    converted to a plain Python list.
    """

    def __init__(self, config: BaseEmbedderConfig, batch_size: int = 32) -> None:
        """Create the Clarifai model and input clients.

        Args:
            config: Embedder configuration. ``config.model`` is passed as the
                Clarifai model URL. ``config.api_key`` is used as the Clarifai
                PAT and falls back to the ``CLARIFAI_PAT`` environment variable
                when it is falsy.
            batch_size: Maximum number of texts sent in a single ``predict``
                call. Defaults to 32, the previously hard-coded value.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            ModuleNotFoundError: If the optional ``clarifai`` dependency is
                not installed.
        """
        super().__init__()
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        try:
            from clarifai.client.input import Inputs
            from clarifai.client.model import Model
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for ClarifaiEmbeddingFunction are not installed. "
                'Please install with `pip install --upgrade "embedchain[clarifai]"`'
            ) from None

        self.config = config
        self.batch_size = batch_size
        # An explicitly configured key wins; otherwise read the PAT from the environment.
        self.api_key = config.api_key or os.getenv("CLARIFAI_PAT")
        self.model = config.model
        self.model_obj = Model(url=self.model, pat=self.api_key)
        self.input_obj = Inputs(pat=self.api_key)

    def __call__(self, input: Union[str, list[str]]) -> Embeddings:
        """Embed a single text or a list of texts.

        Args:
            input: One string or a list of strings. The parameter name is kept
                as ``input`` to match Chroma's ``EmbeddingFunction`` call
                signature, even though it shadows the builtin.

        Returns:
            One embedding (a list of floats) per input text. This assumes
            Clarifai returns ``response.outputs`` in request order, which
            should be confirmed against the SDK.

        Raises:
            Exception: Any error raised by the Clarifai client or while reading
                its response. The error is logged and then re-raised. It is no
                longer swallowed, because swallowing it returned a partial list
                that no longer matched the inputs.
        """
        texts = [input] if isinstance(input, str) else input
        embeddings = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start : start + self.batch_size]
            try:
                # Ids restart at 0 for every request. Presumably they only need
                # to be unique within one predict call (confirm with Clarifai docs).
                input_batch = [
                    self.input_obj.get_text_input(input_id=str(idx), raw_text=text)
                    for idx, text in enumerate(batch)
                ]
                response = self.model_obj.predict(input_batch)
                embeddings.extend(list(output.data.embeddings[0].vector) for output in response.outputs)
            except Exception:
                logger.exception("Predict failed for batch starting at input index %d", start)
                raise
        return embeddings
class ClarifaiEmbedder(BaseEmbedder):
    """Embedder whose embedding function is a ``ClarifaiEmbeddingFunction``."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """Initialise the base embedder, then install the Clarifai embedding function.

        Args:
            config: Optional embedder configuration, forwarded to
                ``BaseEmbedder.__init__``. The function is built from
                ``self.config`` after that call. Presumably the base class
                substitutes a default config when ``None`` is passed; confirm
                in ``BaseEmbedder``.
        """
        super().__init__(config)
        self.set_embedding_fn(embedding_fn=ClarifaiEmbeddingFunction(config=self.config))
|