# base.py (2.9 KB)

  1. from collections.abc import Callable
  2. from typing import Any, Optional
  3. from embedchain.config.embedder.base import BaseEmbedderConfig
  4. try:
  5. from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
  6. except RuntimeError:
  7. from embedchain.utils.misc import use_pysqlite3
  8. use_pysqlite3()
  9. from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
  10. class EmbeddingFunc(EmbeddingFunction):
  11. def __init__(self, embedding_fn: Callable[[list[str]], list[str]]):
  12. self.embedding_fn = embedding_fn
  13. def __call__(self, input: Embeddable) -> Embeddings:
  14. return self.embedding_fn(input)
  15. class BaseEmbedder:
  16. """
  17. Class that manages everything regarding embeddings. Including embedding function, loaders and chunkers.
  18. Embedding functions and vector dimensions are set based on the child class you choose.
  19. To manually overwrite you can use this classes `set_...` methods.
  20. """
  21. def __init__(self, config: Optional[BaseEmbedderConfig] = None):
  22. """
  23. Initialize the embedder class.
  24. :param config: embedder configuration option class, defaults to None
  25. :type config: Optional[BaseEmbedderConfig], optional
  26. """
  27. if config is None:
  28. self.config = BaseEmbedderConfig()
  29. else:
  30. self.config = config
  31. self.vector_dimension: int
  32. def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
  33. """
  34. Set or overwrite the embedding function to be used by the database to store and retrieve documents.
  35. :param embedding_fn: Function to be used to generate embeddings.
  36. :type embedding_fn: Callable[[list[str]], list[str]]
  37. :raises ValueError: Embedding function is not callable.
  38. """
  39. if not hasattr(embedding_fn, "__call__"):
  40. raise ValueError("Embedding function is not a function")
  41. self.embedding_fn = embedding_fn
  42. def set_vector_dimension(self, vector_dimension: int):
  43. """
  44. Set or overwrite the vector dimension size
  45. :param vector_dimension: vector dimension size
  46. :type vector_dimension: int
  47. """
  48. if not isinstance(vector_dimension, int):
  49. raise TypeError("vector dimension must be int")
  50. self.vector_dimension = vector_dimension
  51. @staticmethod
  52. def _langchain_default_concept(embeddings: Any):
  53. """
  54. Langchains default function layout for embeddings.
  55. :param embeddings: Langchain embeddings
  56. :type embeddings: Any
  57. :return: embedding function
  58. :rtype: Callable
  59. """
  60. return EmbeddingFunc(embeddings.embed_documents)
  61. def to_embeddings(self, data: str, **_):
  62. """
  63. Convert data to embeddings
  64. :param data: data to convert to embeddings
  65. :type data: str
  66. :return: embeddings
  67. :rtype: list[float]
  68. """
  69. embeddings = self.embedding_fn([data])
  70. return embeddings[0]