123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
- from collections.abc import Callable
- from typing import Any, Optional
- from embedchain.config.embedder.base import BaseEmbedderConfig
- try:
- from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
- except RuntimeError:
- from embedchain.utils.misc import use_pysqlite3
- use_pysqlite3()
- from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
- class EmbeddingFunc(EmbeddingFunction):
- def __init__(self, embedding_fn: Callable[[list[str]], list[str]]):
- self.embedding_fn = embedding_fn
- def __call__(self, input: Embeddable) -> Embeddings:
- return self.embedding_fn(input)
- class BaseEmbedder:
- """
- Class that manages everything regarding embeddings. Including embedding function, loaders and chunkers.
- Embedding functions and vector dimensions are set based on the child class you choose.
- To manually overwrite you can use this classes `set_...` methods.
- """
- def __init__(self, config: Optional[BaseEmbedderConfig] = None):
- """
- Initialize the embedder class.
- :param config: embedder configuration option class, defaults to None
- :type config: Optional[BaseEmbedderConfig], optional
- """
- if config is None:
- self.config = BaseEmbedderConfig()
- else:
- self.config = config
- self.vector_dimension: int
- def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
- """
- Set or overwrite the embedding function to be used by the database to store and retrieve documents.
- :param embedding_fn: Function to be used to generate embeddings.
- :type embedding_fn: Callable[[list[str]], list[str]]
- :raises ValueError: Embedding function is not callable.
- """
- if not hasattr(embedding_fn, "__call__"):
- raise ValueError("Embedding function is not a function")
- self.embedding_fn = embedding_fn
- def set_vector_dimension(self, vector_dimension: int):
- """
- Set or overwrite the vector dimension size
- :param vector_dimension: vector dimension size
- :type vector_dimension: int
- """
- if not isinstance(vector_dimension, int):
- raise TypeError("vector dimension must be int")
- self.vector_dimension = vector_dimension
- @staticmethod
- def _langchain_default_concept(embeddings: Any):
- """
- Langchains default function layout for embeddings.
- :param embeddings: Langchain embeddings
- :type embeddings: Any
- :return: embedding function
- :rtype: Callable
- """
- return EmbeddingFunc(embeddings.embed_documents)
- def to_embeddings(self, data: str, **_):
- """
- Convert data to embeddings
- :param data: data to convert to embeddings
- :type data: str
- :return: embeddings
- :rtype: list[float]
- """
- embeddings = self.embedding_fn([data])
- return embeddings[0]
|