google.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. from typing import Optional, Union
  2. import google.generativeai as genai
  3. from chromadb import EmbeddingFunction, Embeddings
  4. from embedchain.config.embedder.google import GoogleAIEmbedderConfig
  5. from embedchain.embedder.base import BaseEmbedder
  6. from embedchain.models import VectorDimensions
  7. class GoogleAIEmbeddingFunction(EmbeddingFunction):
  8. def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None) -> None:
  9. super().__init__()
  10. self.config = config or GoogleAIEmbedderConfig()
  11. def __call__(self, input: Union[list[str], str]) -> Embeddings:
  12. model = self.config.model
  13. title = self.config.title
  14. task_type = self.config.task_type
  15. if isinstance(input, str):
  16. input_ = [input]
  17. else:
  18. input_ = input
  19. data = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
  20. embeddings = data["embedding"]
  21. if isinstance(input_, str):
  22. embeddings = [embeddings]
  23. return embeddings
  24. class GoogleAIEmbedder(BaseEmbedder):
  25. def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None):
  26. super().__init__(config)
  27. embedding_fn = GoogleAIEmbeddingFunction(config=config)
  28. self.set_embedding_fn(embedding_fn=embedding_fn)
  29. vector_dimension = self.config.vector_dimension or VectorDimensions.GOOGLE_AI.value
  30. self.set_vector_dimension(vector_dimension=vector_dimension)