1234567891011121314151617181920212223242526272829303132333435363738 |
- from typing import Optional, Union
- import google.generativeai as genai
- from chromadb import EmbeddingFunction, Embeddings
- from embedchain.config.embedder.google import GoogleAIEmbedderConfig
- from embedchain.embedder.base import BaseEmbedder
- from embedchain.models import VectorDimensions
class GoogleAIEmbeddingFunction(EmbeddingFunction):
    """Embedding function that delegates to ``genai.embed_content``.

    The model, title and task type are read from the supplied
    ``GoogleAIEmbedderConfig`` on every call.
    """

    def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None) -> None:
        """Store *config*, falling back to a default ``GoogleAIEmbedderConfig``."""
        super().__init__()
        self.config = config or GoogleAIEmbedderConfig()

    def __call__(self, input: Union[list[str], str]) -> Embeddings:
        """Embed one text or a collection of texts.

        Args:
            input: A single string or a list of strings. The name shadows the
                builtin but is kept because it is part of the callable's
                public signature.

        Returns:
            The value stored under the ``"embedding"`` key of the
            ``genai.embed_content`` response.
        """
        # A bare string is wrapped so the API always receives a list; the
        # response is therefore never the single-string form and needs no
        # re-wrapping afterwards.
        # NOTE(review): assumes embed_content returns one vector per list
        # item when given a list -- confirm against the google-generativeai docs.
        texts = [input] if isinstance(input, str) else input

        response = genai.embed_content(
            model=self.config.model,
            content=texts,
            task_type=self.config.task_type,
            title=self.config.title,
        )
        return response["embedding"]
class GoogleAIEmbedder(BaseEmbedder):
    """Embedder that wires a ``GoogleAIEmbeddingFunction`` into ``BaseEmbedder``."""

    def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None):
        """Register the embedding function and the vector dimension.

        Args:
            config: Optional embedder configuration. It is forwarded both to
                the base class and to the embedding function.
        """
        super().__init__(config)

        # The embedding function receives the caller's config unchanged
        # (possibly None), not ``self.config``.
        self.set_embedding_fn(embedding_fn=GoogleAIEmbeddingFunction(config=config))

        # Use the configured dimension when it is truthy; otherwise fall back
        # to the GOOGLE_AI entry of VectorDimensions.
        configured_dimension = self.config.vector_dimension
        if configured_dimension:
            dimension = configured_dimension
        else:
            dimension = VectorDimensions.GOOGLE_AI.value
        self.set_vector_dimension(vector_dimension=dimension)
|