google.py 1.2 KB

12345678910111213141516171819202122232425262728293031
from typing import List, Optional, Union

import google.generativeai as genai
from chromadb import EmbeddingFunction, Embeddings

from embedchain.config.embedder.google import GoogleAIEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions
  7. class GoogleAIEmbeddingFunction(EmbeddingFunction):
  8. def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None) -> None:
  9. super().__init__()
  10. self.config = config or GoogleAIEmbedderConfig()
  11. def __call__(self, input: str) -> Embeddings:
  12. model = self.config.model
  13. title = self.config.title
  14. task_type = self.config.task_type
  15. embeddings = genai.embed_content(model=model, content=input, task_type=task_type, title=title)
  16. return embeddings["embedding"]
  17. class GoogleAIEmbedder(BaseEmbedder):
  18. def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None):
  19. super().__init__(config)
  20. embedding_fn = GoogleAIEmbeddingFunction(config=config)
  21. self.set_embedding_fn(embedding_fn=embedding_fn)
  22. vector_dimension = VectorDimensions.GOOGLE_AI.value
  23. self.set_vector_dimension(vector_dimension=vector_dimension)