base.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. from typing import Any, Callable, Optional
  2. from embedchain.config.embedder.base import BaseEmbedderConfig
  3. try:
  4. from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
  5. except RuntimeError:
  6. from embedchain.utils import use_pysqlite3
  7. use_pysqlite3()
  8. from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
  9. class EmbeddingFunc(EmbeddingFunction):
  10. def __init__(self, embedding_fn: Callable[[list[str]], list[str]]):
  11. self.embedding_fn = embedding_fn
  12. def __call__(self, input: Embeddable) -> Embeddings:
  13. return self.embedding_fn(input)
  14. class BaseEmbedder:
  15. """
  16. Class that manages everything regarding embeddings. Including embedding function, loaders and chunkers.
  17. Embedding functions and vector dimensions are set based on the child class you choose.
  18. To manually overwrite you can use this classes `set_...` methods.
  19. """
  20. def __init__(self, config: Optional[BaseEmbedderConfig] = None):
  21. """
  22. Intialize the embedder class.
  23. :param config: embedder configuration option class, defaults to None
  24. :type config: Optional[BaseEmbedderConfig], optional
  25. """
  26. if config is None:
  27. self.config = BaseEmbedderConfig()
  28. else:
  29. self.config = config
  30. self.vector_dimension: int
  31. def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
  32. """
  33. Set or overwrite the embedding function to be used by the database to store and retrieve documents.
  34. :param embedding_fn: Function to be used to generate embeddings.
  35. :type embedding_fn: Callable[[list[str]], list[str]]
  36. :raises ValueError: Embedding function is not callable.
  37. """
  38. if not hasattr(embedding_fn, "__call__"):
  39. raise ValueError("Embedding function is not a function")
  40. self.embedding_fn = embedding_fn
  41. def set_vector_dimension(self, vector_dimension: int):
  42. """
  43. Set or overwrite the vector dimension size
  44. :param vector_dimension: vector dimension size
  45. :type vector_dimension: int
  46. """
  47. if not isinstance(vector_dimension, int):
  48. raise TypeError("vector dimension must be int")
  49. self.vector_dimension = vector_dimension
  50. @staticmethod
  51. def _langchain_default_concept(embeddings: Any):
  52. """
  53. Langchains default function layout for embeddings.
  54. :param embeddings: Langchain embeddings
  55. :type embeddings: Any
  56. :return: embedding function
  57. :rtype: Callable
  58. """
  59. return EmbeddingFunc(embeddings.embed_documents)
  60. def to_embeddings(self, data: str, **_):
  61. """
  62. Convert data to embeddings
  63. :param data: data to convert to embeddings
  64. :type data: str
  65. :return: embeddings
  66. :rtype: list[float]
  67. """
  68. embeddings = self.embedding_fn([data])
  69. return embeddings[0]