|
@@ -1,4 +1,4 @@
|
|
|
-from typing import Optional
|
|
|
+from typing import Optional, Union
|
|
|
|
|
|
import google.generativeai as genai
|
|
|
from chromadb import EmbeddingFunction, Embeddings
|
|
@@ -13,12 +13,19 @@ class GoogleAIEmbeddingFunction(EmbeddingFunction):
|
|
|
super().__init__()
|
|
|
self.config = config or GoogleAIEmbedderConfig()
|
|
|
|
|
|
- def __call__(self, input_: str) -> Embeddings:
|
|
|
+ def __call__(self, input: Union[list[str], str]) -> Embeddings:
|
|
|
model = self.config.model
|
|
|
title = self.config.title
|
|
|
task_type = self.config.task_type
|
|
|
- embeddings = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
|
|
|
- return embeddings["embedding"]
|
|
|
+ if isinstance(input, str):
|
|
|
+ input_ = [input]
|
|
|
+ else:
|
|
|
+ input_ = input
|
|
|
+ data = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
|
|
|
+ embeddings = data["embedding"]
|
|
|
+ if isinstance(input_, str):
|
|
|
+ embeddings = [embeddings]
|
|
|
+ return embeddings
|
|
|
|
|
|
|
|
|
class GoogleAIEmbedder(BaseEmbedder):
|