# data_formatter.py
  1. from embedchain.chunkers.docs_site import DocsSiteChunker
  2. from embedchain.chunkers.docx_file import DocxFileChunker
  3. from embedchain.chunkers.notion import NotionChunker
  4. from embedchain.chunkers.pdf_file import PdfFileChunker
  5. from embedchain.chunkers.qna_pair import QnaPairChunker
  6. from embedchain.chunkers.table import TableChunker
  7. from embedchain.chunkers.text import TextChunker
  8. from embedchain.chunkers.web_page import WebPageChunker
  9. from embedchain.chunkers.youtube_video import YoutubeVideoChunker
  10. from embedchain.config import AddConfig
  11. from embedchain.helper_classes.json_serializable import JSONSerializable
  12. from embedchain.loaders.csv import CsvLoader
  13. from embedchain.loaders.docs_site_loader import DocsSiteLoader
  14. from embedchain.loaders.docx_file import DocxFileLoader
  15. from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
  16. from embedchain.loaders.local_text import LocalTextLoader
  17. from embedchain.loaders.pdf_file import PdfFileLoader
  18. from embedchain.loaders.sitemap import SitemapLoader
  19. from embedchain.loaders.web_page import WebPageLoader
  20. from embedchain.loaders.youtube_video import YoutubeVideoLoader
  21. from embedchain.models.data_type import DataType
  22. class DataFormatter(JSONSerializable):
  23. """
  24. DataFormatter is an internal utility class which abstracts the mapping for
  25. loaders and chunkers to the data_type entered by the user in their
  26. .add or .add_local method call
  27. """
  28. def __init__(self, data_type: DataType, config: AddConfig):
  29. self.loader = self._get_loader(data_type, config.loader)
  30. self.chunker = self._get_chunker(data_type, config.chunker)
  31. def _get_loader(self, data_type: DataType, config):
  32. """
  33. Returns the appropriate data loader for the given data type.
  34. :param data_type: The type of the data to load.
  35. :return: The loader for the given data type.
  36. :raises ValueError: If an unsupported data type is provided.
  37. """
  38. loaders = {
  39. DataType.YOUTUBE_VIDEO: YoutubeVideoLoader,
  40. DataType.PDF_FILE: PdfFileLoader,
  41. DataType.WEB_PAGE: WebPageLoader,
  42. DataType.QNA_PAIR: LocalQnaPairLoader,
  43. DataType.TEXT: LocalTextLoader,
  44. DataType.DOCX: DocxFileLoader,
  45. DataType.SITEMAP: SitemapLoader,
  46. DataType.DOCS_SITE: DocsSiteLoader,
  47. DataType.CSV: CsvLoader,
  48. }
  49. lazy_loaders = {DataType.NOTION}
  50. if data_type in loaders:
  51. loader_class = loaders[data_type]
  52. loader = loader_class()
  53. return loader
  54. elif data_type in lazy_loaders:
  55. if data_type == DataType.NOTION:
  56. from embedchain.loaders.notion import NotionLoader
  57. return NotionLoader()
  58. else:
  59. raise ValueError(f"Unsupported data type: {data_type}")
  60. else:
  61. raise ValueError(f"Unsupported data type: {data_type}")
  62. def _get_chunker(self, data_type: DataType, config):
  63. """
  64. Returns the appropriate chunker for the given data type.
  65. :param data_type: The type of the data to chunk.
  66. :return: The chunker for the given data type.
  67. :raises ValueError: If an unsupported data type is provided.
  68. """
  69. chunker_classes = {
  70. DataType.YOUTUBE_VIDEO: YoutubeVideoChunker,
  71. DataType.PDF_FILE: PdfFileChunker,
  72. DataType.WEB_PAGE: WebPageChunker,
  73. DataType.QNA_PAIR: QnaPairChunker,
  74. DataType.TEXT: TextChunker,
  75. DataType.DOCX: DocxFileChunker,
  76. DataType.WEB_PAGE: WebPageChunker,
  77. DataType.DOCS_SITE: DocsSiteChunker,
  78. DataType.NOTION: NotionChunker,
  79. DataType.CSV: TableChunker,
  80. }
  81. if data_type in chunker_classes:
  82. chunker_class = chunker_classes[data_type]
  83. chunker = chunker_class(config)
  84. chunker.set_data_type(data_type)
  85. return chunker
  86. else:
  87. raise ValueError(f"Unsupported data type: {data_type}")