소스 검색

[Bugfix] fix google ai embedding function (#1195)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 년 전
부모
커밋
0b5b12575a
1개의 변경된 파일11개의 추가작업 그리고 4개의 파일을 삭제
  1. 11 4
      embedchain/embedder/google.py

+ 11 - 4
embedchain/embedder/google.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
 
 import google.generativeai as genai
 from chromadb import EmbeddingFunction, Embeddings
@@ -13,12 +13,19 @@ class GoogleAIEmbeddingFunction(EmbeddingFunction):
         super().__init__()
         self.config = config or GoogleAIEmbedderConfig()
 
-    def __call__(self, input_: str) -> Embeddings:
+    def __call__(self, input: Union[list[str], str]) -> Embeddings:
         model = self.config.model
         title = self.config.title
         task_type = self.config.task_type
-        embeddings = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
-        return embeddings["embedding"]
+        if isinstance(input, str):
+            input_ = [input]
+        else:
+            input_ = input
+        data = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
+        embeddings = data["embedding"]
+        if isinstance(input_, str):
+            embeddings = [embeddings]
+        return embeddings
 
 
 class GoogleAIEmbedder(BaseEmbedder):