base_embedder.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. from typing import Any, Callable, Optional
  2. from embedchain.config.embedder.BaseEmbedderConfig import BaseEmbedderConfig
  3. try:
  4. from chromadb.api.types import Documents, Embeddings
  5. except RuntimeError:
  6. from embedchain.utils import use_pysqlite3
  7. use_pysqlite3()
  8. from chromadb.api.types import Documents, Embeddings
  9. class BaseEmbedder:
  10. """
  11. Class that manages everything regarding embeddings. Including embedding function, loaders and chunkers.
  12. Embedding functions and vector dimensions are set based on the child class you choose.
  13. To manually overwrite you can use this classes `set_...` methods.
  14. """
  15. def __init__(self, config: Optional[BaseEmbedderConfig] = FileNotFoundError):
  16. if config is None:
  17. self.config = BaseEmbedderConfig()
  18. else:
  19. self.config = config
  20. self.vector_dimension: int
  21. def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
  22. if not hasattr(embedding_fn, "__call__"):
  23. raise ValueError("Embedding function is not a function")
  24. self.embedding_fn = embedding_fn
  25. def set_vector_dimension(self, vector_dimension: int):
  26. self.vector_dimension = vector_dimension
  27. @staticmethod
  28. def _langchain_default_concept(embeddings: Any):
  29. """
  30. Langchains default function layout for embeddings.
  31. """
  32. def embed_function(texts: Documents) -> Embeddings:
  33. return embeddings.embed_documents(texts)
  34. return embed_function