add_config.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import builtins
  2. import logging
  3. from collections.abc import Callable
  4. from importlib import import_module
  5. from typing import Optional
  6. from embedchain.config.base_config import BaseConfig
  7. from embedchain.helpers.json_serializable import register_deserializable
  8. @register_deserializable
  9. class ChunkerConfig(BaseConfig):
  10. """
  11. Config for the chunker used in `add` method
  12. """
  13. def __init__(
  14. self,
  15. chunk_size: Optional[int] = 2000,
  16. chunk_overlap: Optional[int] = 0,
  17. length_function: Optional[Callable[[str], int]] = None,
  18. min_chunk_size: Optional[int] = 0,
  19. ):
  20. self.chunk_size = chunk_size
  21. self.chunk_overlap = chunk_overlap
  22. self.min_chunk_size = min_chunk_size
  23. if self.min_chunk_size >= self.chunk_size:
  24. raise ValueError(f"min_chunk_size {min_chunk_size} should be less than chunk_size {chunk_size}")
  25. if self.min_chunk_size < self.chunk_overlap:
  26. logging.warning(
  27. f"min_chunk_size {min_chunk_size} should be greater than chunk_overlap {chunk_overlap}, otherwise it is redundant." # noqa:E501
  28. )
  29. if isinstance(length_function, str):
  30. self.length_function = self.load_func(length_function)
  31. else:
  32. self.length_function = length_function if length_function else len
  33. @staticmethod
  34. def load_func(dotpath: str):
  35. if "." not in dotpath:
  36. return getattr(builtins, dotpath)
  37. else:
  38. module_, func = dotpath.rsplit(".", maxsplit=1)
  39. m = import_module(module_)
  40. return getattr(m, func)
  41. @register_deserializable
  42. class LoaderConfig(BaseConfig):
  43. """
  44. Config for the loader used in `add` method
  45. """
  46. def __init__(self):
  47. pass
  48. @register_deserializable
  49. class AddConfig(BaseConfig):
  50. """
  51. Config for the `add` method.
  52. """
  53. def __init__(
  54. self,
  55. chunker: Optional[ChunkerConfig] = None,
  56. loader: Optional[LoaderConfig] = None,
  57. ):
  58. """
  59. Initializes a configuration class instance for the `add` method.
  60. :param chunker: Chunker config, defaults to None
  61. :type chunker: Optional[ChunkerConfig], optional
  62. :param loader: Loader config, defaults to None
  63. :type loader: Optional[LoaderConfig], optional
  64. """
  65. self.loader = loader
  66. self.chunker = chunker