base_embedder.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. from typing import Any, Callable, Optional
  2. from embedchain.config.embedder.BaseEmbedderConfig import BaseEmbedderConfig
  3. try:
  4. from chromadb.api.types import Documents, Embeddings
  5. except RuntimeError:
  6. from embedchain.utils import use_pysqlite3
  7. use_pysqlite3()
  8. from chromadb.api.types import Documents, Embeddings
  9. class BaseEmbedder:
  10. """
  11. Class that manages everything regarding embeddings. Including embedding function, loaders and chunkers.
  12. Embedding functions and vector dimensions are set based on the child class you choose.
  13. To manually overwrite you can use this classes `set_...` methods.
  14. """
  15. def __init__(self, config: Optional[BaseEmbedderConfig] = FileNotFoundError):
  16. if config is None:
  17. self.config = BaseEmbedderConfig()
  18. else:
  19. self.config = config
  20. def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
  21. if not hasattr(embedding_fn, "__call__"):
  22. raise ValueError("Embedding function is not a function")
  23. self.embedding_fn = embedding_fn
  24. def set_vector_dimension(self, vector_dimension: int):
  25. self.vector_dimension = vector_dimension
  26. @staticmethod
  27. def _langchain_default_concept(embeddings: Any):
  28. """
  29. Langchains default function layout for embeddings.
  30. """
  31. def embed_function(texts: Documents) -> Embeddings:
  32. return embeddings.embed_documents(texts)
  33. return embed_function