Browse Source

Update embedding_fn signature to newest chroma db's (#969)

Sidharth Mohanty 1 year ago
parent
commit
85f3ac428b
1 changed files with 11 additions and 6 deletions
  1. 11 6
      embedchain/embedder/base.py

+ 11 - 6
embedchain/embedder/base.py

@@ -3,12 +3,20 @@ from typing import Any, Callable, Optional
 from embedchain.config.embedder.base import BaseEmbedderConfig
 
 try:
-    from chromadb.api.types import Documents, Embeddings
+    from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction
 except RuntimeError:
     from embedchain.utils import use_pysqlite3
 
     use_pysqlite3()
-    from chromadb.api.types import Documents, Embeddings
+    from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction
+
+
+class EmbeddingFunc(EmbeddingFunction):
+    def __init__(self, embedding_fn: Callable[[list[str]], list[str]]):
+        self.embedding_fn = embedding_fn
+
+    def __call__(self, input: Embeddable) -> Embeddings:
+        return self.embedding_fn(input)
 
 
 class BaseEmbedder:
@@ -66,7 +74,4 @@ class BaseEmbedder:
         :rtype: Callable
         """
 
-        def embed_function(texts: Documents) -> Embeddings:
-            return embeddings.embed_documents(texts)
-
-        return embed_function
+        return EmbeddingFunc(embeddings.embed_documents)