Explorar o código

[Bugfix] fix google ai embedding function (#1195)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel hai 1 ano
pai
achega
0b5b12575a
Modificáronse 1 ficheiros con 11 adicións e 4 borrados
  1. 11 4
      embedchain/embedder/google.py

+ 11 - 4
embedchain/embedder/google.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
 
 import google.generativeai as genai
 from chromadb import EmbeddingFunction, Embeddings
@@ -13,12 +13,19 @@ class GoogleAIEmbeddingFunction(EmbeddingFunction):
         super().__init__()
         self.config = config or GoogleAIEmbedderConfig()
 
-    def __call__(self, input_: str) -> Embeddings:
+    def __call__(self, input: Union[list[str], str]) -> Embeddings:
         model = self.config.model
         title = self.config.title
         task_type = self.config.task_type
-        embeddings = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
-        return embeddings["embedding"]
+        if isinstance(input, str):
+            input_ = [input]
+        else:
+            input_ = input
+        data = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
+        embeddings = data["embedding"]
+        if isinstance(input_, str):
+            embeddings = [embeddings]
+        return embeddings
 
 
 class GoogleAIEmbedder(BaseEmbedder):