data_formatter.py 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from embedchain.chunkers.docs_site import DocsSiteChunker
  2. from embedchain.chunkers.docx_file import DocxFileChunker
  3. from embedchain.chunkers.notion import NotionChunker
  4. from embedchain.chunkers.pdf_file import PdfFileChunker
  5. from embedchain.chunkers.qna_pair import QnaPairChunker
  6. from embedchain.chunkers.text import TextChunker
  7. from embedchain.chunkers.web_page import WebPageChunker
  8. from embedchain.chunkers.youtube_video import YoutubeVideoChunker
  9. from embedchain.config import AddConfig
  10. from embedchain.helper_classes.json_serializable import JSONSerializable
  11. from embedchain.loaders.docs_site_loader import DocsSiteLoader
  12. from embedchain.loaders.docx_file import DocxFileLoader
  13. from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
  14. from embedchain.loaders.local_text import LocalTextLoader
  15. from embedchain.loaders.pdf_file import PdfFileLoader
  16. from embedchain.loaders.sitemap import SitemapLoader
  17. from embedchain.loaders.web_page import WebPageLoader
  18. from embedchain.loaders.youtube_video import YoutubeVideoLoader
  19. from embedchain.models.data_type import DataType
  20. class DataFormatter(JSONSerializable):
  21. """
  22. DataFormatter is an internal utility class which abstracts the mapping for
  23. loaders and chunkers to the data_type entered by the user in their
  24. .add or .add_local method call
  25. """
  26. def __init__(self, data_type: DataType, config: AddConfig):
  27. self.loader = self._get_loader(data_type, config.loader)
  28. self.chunker = self._get_chunker(data_type, config.chunker)
  29. def _get_loader(self, data_type: DataType, config):
  30. """
  31. Returns the appropriate data loader for the given data type.
  32. :param data_type: The type of the data to load.
  33. :return: The loader for the given data type.
  34. :raises ValueError: If an unsupported data type is provided.
  35. """
  36. loaders = {
  37. DataType.YOUTUBE_VIDEO: YoutubeVideoLoader,
  38. DataType.PDF_FILE: PdfFileLoader,
  39. DataType.WEB_PAGE: WebPageLoader,
  40. DataType.QNA_PAIR: LocalQnaPairLoader,
  41. DataType.TEXT: LocalTextLoader,
  42. DataType.DOCX: DocxFileLoader,
  43. DataType.SITEMAP: SitemapLoader,
  44. DataType.DOCS_SITE: DocsSiteLoader,
  45. }
  46. lazy_loaders = {DataType.NOTION}
  47. if data_type in loaders:
  48. loader_class = loaders[data_type]
  49. loader = loader_class()
  50. return loader
  51. elif data_type in lazy_loaders:
  52. if data_type == DataType.NOTION:
  53. from embedchain.loaders.notion import NotionLoader
  54. return NotionLoader()
  55. else:
  56. raise ValueError(f"Unsupported data type: {data_type}")
  57. else:
  58. raise ValueError(f"Unsupported data type: {data_type}")
  59. def _get_chunker(self, data_type: DataType, config):
  60. """
  61. Returns the appropriate chunker for the given data type.
  62. :param data_type: The type of the data to chunk.
  63. :return: The chunker for the given data type.
  64. :raises ValueError: If an unsupported data type is provided.
  65. """
  66. chunker_classes = {
  67. DataType.YOUTUBE_VIDEO: YoutubeVideoChunker,
  68. DataType.PDF_FILE: PdfFileChunker,
  69. DataType.WEB_PAGE: WebPageChunker,
  70. DataType.QNA_PAIR: QnaPairChunker,
  71. DataType.TEXT: TextChunker,
  72. DataType.DOCX: DocxFileChunker,
  73. DataType.WEB_PAGE: WebPageChunker,
  74. DataType.DOCS_SITE: DocsSiteChunker,
  75. DataType.NOTION: NotionChunker,
  76. }
  77. if data_type in chunker_classes:
  78. chunker_class = chunker_classes[data_type]
  79. chunker = chunker_class(config)
  80. chunker.set_data_type(data_type)
  81. return chunker
  82. else:
  83. raise ValueError(f"Unsupported data type: {data_type}")