|
@@ -3,12 +3,20 @@ from typing import Any, Callable, Optional
|
|
|
from embedchain.config.embedder.base import BaseEmbedderConfig
|
|
|
|
|
|
try:
|
|
|
- from chromadb.api.types import Documents, Embeddings
|
|
|
+ from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction
|
|
|
except RuntimeError:
|
|
|
from embedchain.utils import use_pysqlite3
|
|
|
|
|
|
use_pysqlite3()
|
|
|
- from chromadb.api.types import Documents, Embeddings
|
|
|
+ from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction
|
|
|
+
|
|
|
+
|
|
|
+class EmbeddingFunc(EmbeddingFunction):
|
|
|
+ def __init__(self, embedding_fn: Callable[[list[str]], list[str]]):
|
|
|
+ self.embedding_fn = embedding_fn
|
|
|
+
|
|
|
+ def __call__(self, input: Embeddable) -> Embeddings:
|
|
|
+ return self.embedding_fn(input)
|
|
|
|
|
|
|
|
|
class BaseEmbedder:
|
|
@@ -66,7 +74,4 @@ class BaseEmbedder:
|
|
|
:rtype: Callable
|
|
|
"""
|
|
|
|
|
|
- def embed_function(texts: Documents) -> Embeddings:
|
|
|
- return embeddings.embed_documents(texts)
|
|
|
-
|
|
|
- return embed_function
|
|
|
+ return EmbeddingFunc(embeddings.embed_documents)
|