data_formatter.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from embedchain.chunkers.docs_site import DocsSiteChunker
  2. from embedchain.chunkers.docx_file import DocxFileChunker
  3. from embedchain.chunkers.notion import NotionChunker
  4. from embedchain.chunkers.pdf_file import PdfFileChunker
  5. from embedchain.chunkers.qna_pair import QnaPairChunker
  6. from embedchain.chunkers.text import TextChunker
  7. from embedchain.chunkers.web_page import WebPageChunker
  8. from embedchain.chunkers.youtube_video import YoutubeVideoChunker
  9. from embedchain.config import AddConfig
  10. from embedchain.loaders.docs_site_loader import DocsSiteLoader
  11. from embedchain.loaders.docx_file import DocxFileLoader
  12. from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
  13. from embedchain.loaders.local_text import LocalTextLoader
  14. from embedchain.loaders.pdf_file import PdfFileLoader
  15. from embedchain.loaders.sitemap import SitemapLoader
  16. from embedchain.loaders.web_page import WebPageLoader
  17. from embedchain.loaders.youtube_video import YoutubeVideoLoader
  class DataFormatter:
      """
      DataFormatter is an internal utility class which maps the data_type
      entered by the user in their .add or .add_local method call to the
      matching loader and chunker.

      After construction, ``self.loader`` holds the loader instance and
      ``self.chunker`` holds the chunker instance for the given data type.
      """

      def __init__(self, data_type: str, config: AddConfig):
          """
          Resolve the loader and the chunker for ``data_type``.

          :param data_type: The type of the data being added.
          :param config: Add configuration; its ``loader`` and ``chunker``
              attributes are forwarded to the respective factory methods.
          :raises ValueError: If an unsupported data type is provided.
          """
          self.loader = self._get_loader(data_type, config.loader)
          self.chunker = self._get_chunker(data_type, config.chunker)

      def _get_loader(self, data_type, config):
          """
          Returns the appropriate data loader for the given data type.

          Only the loader that was asked for is instantiated; the mapping
          holds classes rather than instances so that unrelated loaders are
          never constructed as a side effect of a lookup.

          :param data_type: The type of the data to load.
          :param config: Loader configuration. Accepted for interface
              symmetry with ``_get_chunker``; currently not used.
          :return: The loader for the given data type.
          :raises ValueError: If an unsupported data type is provided.
          """
          loader_classes = {
              "youtube_video": YoutubeVideoLoader,
              "pdf_file": PdfFileLoader,
              "web_page": WebPageLoader,
              "qna_pair": LocalQnaPairLoader,
              "text": LocalTextLoader,
              "docx": DocxFileLoader,
              "sitemap": SitemapLoader,
              "docs_site": DocsSiteLoader,
          }

          loader_class = loader_classes.get(data_type)
          if loader_class is not None:
              return loader_class()

          if data_type == "notion":
              # Imported lazily so the notion loader module is only loaded
              # when a notion source is actually requested (presumably it
              # pulls in an optional dependency -- confirm before moving
              # this import to module level).
              from embedchain.loaders.notion import NotionLoader

              return NotionLoader()

          raise ValueError(f"Unsupported data type: {data_type}")

      def _get_chunker(self, data_type, config):
          """
          Returns the appropriate chunker for the given data type.

          The chunker is constructed with ``config`` and is told which data
          type it is handling via ``set_data_type`` before being returned.

          :param data_type: The type of the data to chunk.
          :param config: Chunker configuration passed to the chunker's
              constructor.
          :return: The chunker for the given data type.
          :raises ValueError: If an unsupported data type is provided.
          """
          chunker_classes = {
              "youtube_video": YoutubeVideoChunker,
              "pdf_file": PdfFileChunker,
              "web_page": WebPageChunker,
              "qna_pair": QnaPairChunker,
              "text": TextChunker,
              "docx": DocxFileChunker,
              # Sitemaps reuse the web page chunker.
              "sitemap": WebPageChunker,
              "docs_site": DocsSiteChunker,
              "notion": NotionChunker,
          }

          chunker_class = chunker_classes.get(data_type)
          if chunker_class is None:
              raise ValueError(f"Unsupported data type: {data_type}")

          chunker = chunker_class(config)
          chunker.set_data_type(data_type)
          return chunker