data_formatter.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. from embedchain.chunkers.docs_site import DocsSiteChunker
  2. from embedchain.chunkers.docx_file import DocxFileChunker
  3. from embedchain.chunkers.notion import NotionChunker
  4. from embedchain.chunkers.pdf_file import PdfFileChunker
  5. from embedchain.chunkers.qna_pair import QnaPairChunker
  6. from embedchain.chunkers.text import TextChunker
  7. from embedchain.chunkers.web_page import WebPageChunker
  8. from embedchain.chunkers.youtube_video import YoutubeVideoChunker
  9. from embedchain.config import AddConfig
  10. from embedchain.loaders.docs_site_loader import DocsSiteLoader
  11. from embedchain.loaders.docx_file import DocxFileLoader
  12. from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
  13. from embedchain.loaders.local_text import LocalTextLoader
  14. from embedchain.loaders.pdf_file import PdfFileLoader
  15. from embedchain.loaders.sitemap import SitemapLoader
  16. from embedchain.loaders.web_page import WebPageLoader
  17. from embedchain.loaders.youtube_video import YoutubeVideoLoader
  18. class DataFormatter:
  19. """
  20. DataFormatter is an internal utility class which abstracts the mapping for
  21. loaders and chunkers to the data_type entered by the user in their
  22. .add or .add_local method call
  23. """
  24. def __init__(self, data_type: str, config: AddConfig):
  25. self.loader = self._get_loader(data_type, config.loader)
  26. self.chunker = self._get_chunker(data_type, config.chunker)
  27. def _get_loader(self, data_type, config):
  28. """
  29. Returns the appropriate data loader for the given data type.
  30. :param data_type: The type of the data to load.
  31. :return: The loader for the given data type.
  32. :raises ValueError: If an unsupported data type is provided.
  33. """
  34. loaders = {
  35. "youtube_video": YoutubeVideoLoader,
  36. "pdf_file": PdfFileLoader,
  37. "web_page": WebPageLoader,
  38. "qna_pair": LocalQnaPairLoader,
  39. "text": LocalTextLoader,
  40. "docx": DocxFileLoader,
  41. "sitemap": SitemapLoader,
  42. "docs_site": DocsSiteLoader,
  43. }
  44. lazy_loaders = ("notion",)
  45. if data_type in loaders:
  46. loader_class = loaders[data_type]
  47. loader = loader_class()
  48. return loader
  49. elif data_type in lazy_loaders:
  50. if data_type == "notion":
  51. from embedchain.loaders.notion import NotionLoader
  52. return NotionLoader()
  53. else:
  54. raise ValueError(f"Unsupported data type: {data_type}")
  55. else:
  56. raise ValueError(f"Unsupported data type: {data_type}")
  57. def _get_chunker(self, data_type, config):
  58. """
  59. Returns the appropriate chunker for the given data type.
  60. :param data_type: The type of the data to chunk.
  61. :return: The chunker for the given data type.
  62. :raises ValueError: If an unsupported data type is provided.
  63. """
  64. chunker_classes = {
  65. "youtube_video": YoutubeVideoChunker,
  66. "pdf_file": PdfFileChunker,
  67. "web_page": WebPageChunker,
  68. "qna_pair": QnaPairChunker,
  69. "text": TextChunker,
  70. "docx": DocxFileChunker,
  71. "sitemap": WebPageChunker,
  72. "docs_site": DocsSiteChunker,
  73. "notion": NotionChunker,
  74. }
  75. if data_type in chunker_classes:
  76. chunker_class = chunker_classes[data_type]
  77. chunker = chunker_class(config)
  78. chunker.set_data_type(data_type)
  79. return chunker
  80. else:
  81. raise ValueError(f"Unsupported data type: {data_type}")