data_formatter.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from embedchain.config import AddConfig
  2. from embedchain.loaders.youtube_video import YoutubeVideoLoader
  3. from embedchain.loaders.pdf_file import PdfFileLoader
  4. from embedchain.loaders.web_page import WebPageLoader
  5. from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
  6. from embedchain.loaders.local_text import LocalTextLoader
  7. from embedchain.loaders.docx_file import DocxFileLoader
  8. from embedchain.chunkers.youtube_video import YoutubeVideoChunker
  9. from embedchain.chunkers.pdf_file import PdfFileChunker
  10. from embedchain.chunkers.web_page import WebPageChunker
  11. from embedchain.chunkers.qna_pair import QnaPairChunker
  12. from embedchain.chunkers.text import TextChunker
  13. from embedchain.chunkers.docx_file import DocxFileChunker
  14. class DataFormatter:
  15. """
  16. DataFormatter is an internal utility class which abstracts the mapping for
  17. loaders and chunkers to the data_type entered by the user in their
  18. .add or .add_local method call
  19. """
  20. def __init__(self, data_type: str, config: AddConfig):
  21. self.loader = self._get_loader(data_type, config.loader)
  22. self.chunker = self._get_chunker(data_type, config.chunker)
  23. def _get_loader(self, data_type, config):
  24. """
  25. Returns the appropriate data loader for the given data type.
  26. :param data_type: The type of the data to load.
  27. :return: The loader for the given data type.
  28. :raises ValueError: If an unsupported data type is provided.
  29. """
  30. loaders = {
  31. 'youtube_video': YoutubeVideoLoader(),
  32. 'pdf_file': PdfFileLoader(),
  33. 'web_page': WebPageLoader(),
  34. 'qna_pair': LocalQnaPairLoader(),
  35. 'text': LocalTextLoader(),
  36. 'docx': DocxFileLoader(),
  37. }
  38. if data_type in loaders:
  39. return loaders[data_type]
  40. else:
  41. raise ValueError(f"Unsupported data type: {data_type}")
  42. def _get_chunker(self, data_type, config):
  43. """
  44. Returns the appropriate chunker for the given data type.
  45. :param data_type: The type of the data to chunk.
  46. :return: The chunker for the given data type.
  47. :raises ValueError: If an unsupported data type is provided.
  48. """
  49. chunkers = {
  50. 'youtube_video': YoutubeVideoChunker(config),
  51. 'pdf_file': PdfFileChunker(config),
  52. 'web_page': WebPageChunker(config),
  53. 'qna_pair': QnaPairChunker(config),
  54. 'text': TextChunker(config),
  55. 'docx': DocxFileChunker(config),
  56. }
  57. if data_type in chunkers:
  58. return chunkers[data_type]
  59. else:
  60. raise ValueError(f"Unsupported data type: {data_type}")