base.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. from typing import Any, Callable, Optional
  2. from embedchain.config.embedder.base import BaseEmbedderConfig
  3. try:
  4. from chromadb.api.types import Documents, Embeddings
  5. except RuntimeError:
  6. from embedchain.utils import use_pysqlite3
  7. use_pysqlite3()
  8. from chromadb.api.types import Documents, Embeddings
  class BaseEmbedder:
      """
      Base class for embedders.

      Holds the embedder configuration, the embedding function and the vector dimension.
      Embedding functions and vector dimensions are set based on the child class you choose.
      To manually overwrite them you can use this class's `set_...` methods.
      """

      def __init__(self, config: "Optional[BaseEmbedderConfig]" = None):
          """
          Initialize the embedder class.

          :param config: embedder configuration option class, defaults to None
          :type config: Optional[BaseEmbedderConfig], optional
          """
          # Fall back to a default configuration when none is supplied.
          self.config = BaseEmbedderConfig() if config is None else config
          # Declarations only (no value is assigned): both attributes stay unset
          # until `set_vector_dimension` / `set_embedding_fn` are called.
          self.vector_dimension: int
          self.embedding_fn: Callable[[Documents], Embeddings]

      def set_embedding_fn(self, embedding_fn: "Callable[[Documents], Embeddings]"):
          """
          Set or overwrite the embedding function to be used by the database to store and retrieve documents.

          :param embedding_fn: Function to be used to generate embeddings.
          :type embedding_fn: Callable[[Documents], Embeddings]
          :raises ValueError: Embedding function is not callable.
          """
          if not callable(embedding_fn):
              raise ValueError("Embedding function is not a function")
          self.embedding_fn = embedding_fn

      def set_vector_dimension(self, vector_dimension: int):
          """
          Set or overwrite the vector dimension size.

          :param vector_dimension: vector dimension size
          :type vector_dimension: int
          :raises TypeError: vector dimension is not an int (booleans are rejected as well).
          """
          # bool is a subclass of int, so it has to be excluded explicitly;
          # otherwise True/False would be stored as a dimension of 1/0.
          if isinstance(vector_dimension, bool) or not isinstance(vector_dimension, int):
              raise TypeError("vector dimension must be int")
          self.vector_dimension = vector_dimension

      @staticmethod
      def _langchain_default_concept(embeddings: Any):
          """
          Langchain's default function layout for embeddings.

          Wraps an object exposing `embed_documents` into a plain embedding function.

          :param embeddings: Langchain embeddings
          :type embeddings: Any
          :return: embedding function that forwards its texts to `embeddings.embed_documents`
          :rtype: Callable
          """

          def embed_function(texts: "Documents") -> "Embeddings":
              return embeddings.embed_documents(texts)

          return embed_function