data_formatter.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. from embedchain.chunkers.docs_site import DocsSiteChunker
  2. from embedchain.chunkers.docx_file import DocxFileChunker
  3. from embedchain.chunkers.notion import NotionChunker
  4. from embedchain.chunkers.pdf_file import PdfFileChunker
  5. from embedchain.chunkers.qna_pair import QnaPairChunker
  6. from embedchain.chunkers.text import TextChunker
  7. from embedchain.chunkers.web_page import WebPageChunker
  8. from embedchain.chunkers.youtube_video import YoutubeVideoChunker
  9. from embedchain.config import AddConfig
  10. from embedchain.loaders.docs_site_loader import DocsSiteLoader
  11. from embedchain.loaders.docx_file import DocxFileLoader
  12. from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
  13. from embedchain.loaders.local_text import LocalTextLoader
  14. from embedchain.loaders.pdf_file import PdfFileLoader
  15. from embedchain.loaders.sitemap import SitemapLoader
  16. from embedchain.loaders.web_page import WebPageLoader
  17. from embedchain.loaders.youtube_video import YoutubeVideoLoader
  18. from embedchain.models.data_type import DataType
  19. class DataFormatter:
  20. """
  21. DataFormatter is an internal utility class which abstracts the mapping for
  22. loaders and chunkers to the data_type entered by the user in their
  23. .add or .add_local method call
  24. """
  25. def __init__(self, data_type: DataType, config: AddConfig):
  26. self.loader = self._get_loader(data_type, config.loader)
  27. self.chunker = self._get_chunker(data_type, config.chunker)
  28. def _get_loader(self, data_type: DataType, config):
  29. """
  30. Returns the appropriate data loader for the given data type.
  31. :param data_type: The type of the data to load.
  32. :return: The loader for the given data type.
  33. :raises ValueError: If an unsupported data type is provided.
  34. """
  35. loaders = {
  36. DataType.YOUTUBE_VIDEO: YoutubeVideoLoader,
  37. DataType.PDF_FILE: PdfFileLoader,
  38. DataType.WEB_PAGE: WebPageLoader,
  39. DataType.QNA_PAIR: LocalQnaPairLoader,
  40. DataType.TEXT: LocalTextLoader,
  41. DataType.DOCX: DocxFileLoader,
  42. DataType.SITEMAP: SitemapLoader,
  43. DataType.DOCS_SITE: DocsSiteLoader,
  44. }
  45. lazy_loaders = {DataType.NOTION}
  46. if data_type in loaders:
  47. loader_class = loaders[data_type]
  48. loader = loader_class()
  49. return loader
  50. elif data_type in lazy_loaders:
  51. if data_type == DataType.NOTION:
  52. from embedchain.loaders.notion import NotionLoader
  53. return NotionLoader()
  54. else:
  55. raise ValueError(f"Unsupported data type: {data_type}")
  56. else:
  57. raise ValueError(f"Unsupported data type: {data_type}")
  58. def _get_chunker(self, data_type: DataType, config):
  59. """
  60. Returns the appropriate chunker for the given data type.
  61. :param data_type: The type of the data to chunk.
  62. :return: The chunker for the given data type.
  63. :raises ValueError: If an unsupported data type is provided.
  64. """
  65. chunker_classes = {
  66. DataType.YOUTUBE_VIDEO: YoutubeVideoChunker,
  67. DataType.PDF_FILE: PdfFileChunker,
  68. DataType.WEB_PAGE: WebPageChunker,
  69. DataType.QNA_PAIR: QnaPairChunker,
  70. DataType.TEXT: TextChunker,
  71. DataType.DOCX: DocxFileChunker,
  72. DataType.WEB_PAGE: WebPageChunker,
  73. DataType.DOCS_SITE: DocsSiteChunker,
  74. DataType.NOTION: NotionChunker,
  75. }
  76. if data_type in chunker_classes:
  77. chunker_class = chunker_classes[data_type]
  78. chunker = chunker_class(config)
  79. chunker.set_data_type(data_type)
  80. return chunker
  81. else:
  82. raise ValueError(f"Unsupported data type: {data_type}")