data_formatter.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. from embedchain.chunkers.base_chunker import BaseChunker
  2. from embedchain.chunkers.docs_site import DocsSiteChunker
  3. from embedchain.chunkers.docx_file import DocxFileChunker
  4. from embedchain.chunkers.images import ImagesChunker
  5. from embedchain.chunkers.json import JSONChunker
  6. from embedchain.chunkers.mdx import MdxChunker
  7. from embedchain.chunkers.notion import NotionChunker
  8. from embedchain.chunkers.pdf_file import PdfFileChunker
  9. from embedchain.chunkers.unstructured_file import UnstructuredFileChunker
  10. from embedchain.chunkers.qna_pair import QnaPairChunker
  11. from embedchain.chunkers.sitemap import SitemapChunker
  12. from embedchain.chunkers.table import TableChunker
  13. from embedchain.chunkers.text import TextChunker
  14. from embedchain.chunkers.web_page import WebPageChunker
  15. from embedchain.chunkers.xml import XmlChunker
  16. from embedchain.chunkers.youtube_video import YoutubeVideoChunker
  17. from embedchain.config import AddConfig
  18. from embedchain.config.add_config import ChunkerConfig, LoaderConfig
  19. from embedchain.helper.json_serializable import JSONSerializable
  20. from embedchain.loaders.base_loader import BaseLoader
  21. from embedchain.loaders.csv import CsvLoader
  22. from embedchain.loaders.docs_site_loader import DocsSiteLoader
  23. from embedchain.loaders.docx_file import DocxFileLoader
  24. from embedchain.loaders.images import ImagesLoader
  25. from embedchain.loaders.json import JSONLoader
  26. from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
  27. from embedchain.loaders.local_text import LocalTextLoader
  28. from embedchain.loaders.mdx import MdxLoader
  29. from embedchain.loaders.pdf_file import PdfFileLoader
  30. from embedchain.loaders.sitemap import SitemapLoader
  31. from embedchain.loaders.web_page import WebPageLoader
  32. from embedchain.loaders.xml import XmlLoader
  33. from embedchain.loaders.youtube_video import YoutubeVideoLoader
  34. from embedchain.loaders.unstructured_file import UnstructuredLoader
  35. from embedchain.models.data_type import DataType
  36. class DataFormatter(JSONSerializable):
  37. """
  38. DataFormatter is an internal utility class which abstracts the mapping for
  39. loaders and chunkers to the data_type entered by the user in their
  40. .add or .add_local method call
  41. """
  42. def __init__(self, data_type: DataType, config: AddConfig):
  43. """
  44. Initialize a dataformatter, set data type and chunker based on datatype.
  45. :param data_type: The type of the data to load and chunk.
  46. :type data_type: DataType
  47. :param config: AddConfig instance with nested loader and chunker config attributes.
  48. :type config: AddConfig
  49. """
  50. self.loader = self._get_loader(data_type=data_type, config=config.loader)
  51. self.chunker = self._get_chunker(data_type=data_type, config=config.chunker)
  52. def _get_loader(self, data_type: DataType, config: LoaderConfig) -> BaseLoader:
  53. """
  54. Returns the appropriate data loader for the given data type.
  55. :param data_type: The type of the data to load.
  56. :type data_type: DataType
  57. :param config: Config to initialize the loader with.
  58. :type config: LoaderConfig
  59. :raises ValueError: If an unsupported data type is provided.
  60. :return: The loader for the given data type.
  61. :rtype: BaseLoader
  62. """
  63. loaders = {
  64. DataType.YOUTUBE_VIDEO: YoutubeVideoLoader,
  65. DataType.PDF_FILE: PdfFileLoader,
  66. DataType.WEB_PAGE: WebPageLoader,
  67. DataType.QNA_PAIR: LocalQnaPairLoader,
  68. DataType.TEXT: LocalTextLoader,
  69. DataType.DOCX: DocxFileLoader,
  70. DataType.SITEMAP: SitemapLoader,
  71. DataType.XML: XmlLoader,
  72. DataType.DOCS_SITE: DocsSiteLoader,
  73. DataType.CSV: CsvLoader,
  74. DataType.MDX: MdxLoader,
  75. DataType.IMAGES: ImagesLoader,
  76. DataType.UNSTRUCTURED: UnstructuredLoader,
  77. DataType.JSON: JSONLoader,
  78. }
  79. lazy_loaders = {DataType.NOTION}
  80. if data_type in loaders:
  81. loader_class: type = loaders[data_type]
  82. loader: BaseLoader = loader_class()
  83. return loader
  84. elif data_type in lazy_loaders:
  85. if data_type == DataType.NOTION:
  86. from embedchain.loaders.notion import NotionLoader
  87. return NotionLoader()
  88. else:
  89. raise ValueError(f"Unsupported data type: {data_type}")
  90. else:
  91. raise ValueError(f"Unsupported data type: {data_type}")
  92. def _get_chunker(self, data_type: DataType, config: ChunkerConfig) -> BaseChunker:
  93. """Returns the appropriate chunker for the given data type.
  94. :param data_type: The type of the data to chunk.
  95. :type data_type: DataType
  96. :param config: Config to initialize the chunker with.
  97. :type config: ChunkerConfig
  98. :raises ValueError: If an unsupported data type is provided.
  99. :return: The chunker for the given data type.
  100. :rtype: BaseChunker
  101. """
  102. chunker_classes = {
  103. DataType.YOUTUBE_VIDEO: YoutubeVideoChunker,
  104. DataType.PDF_FILE: PdfFileChunker,
  105. DataType.WEB_PAGE: WebPageChunker,
  106. DataType.QNA_PAIR: QnaPairChunker,
  107. DataType.TEXT: TextChunker,
  108. DataType.DOCX: DocxFileChunker,
  109. DataType.DOCS_SITE: DocsSiteChunker,
  110. DataType.SITEMAP: SitemapChunker,
  111. DataType.NOTION: NotionChunker,
  112. DataType.CSV: TableChunker,
  113. DataType.MDX: MdxChunker,
  114. DataType.IMAGES: ImagesChunker,
  115. DataType.XML: XmlChunker,
  116. DataType.UNSTRUCTURED: UnstructuredFileChunker,
  117. DataType.JSON: JSONChunker,
  118. }
  119. if data_type in chunker_classes:
  120. chunker_class: type = chunker_classes[data_type]
  121. chunker: BaseChunker = chunker_class(config)
  122. chunker.set_data_type(data_type)
  123. return chunker
  124. else:
  125. raise ValueError(f"Unsupported data type: {data_type}")