data_formatter.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. from embedchain.chunkers.docx_file import DocxFileChunker
  2. from embedchain.chunkers.pdf_file import PdfFileChunker
  3. from embedchain.chunkers.qna_pair import QnaPairChunker
  4. from embedchain.chunkers.text import TextChunker
  5. from embedchain.chunkers.web_page import WebPageChunker
  6. from embedchain.chunkers.youtube_video import YoutubeVideoChunker
  7. from embedchain.config import AddConfig
  8. from embedchain.loaders.docx_file import DocxFileLoader
  9. from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
  10. from embedchain.loaders.local_text import LocalTextLoader
  11. from embedchain.loaders.pdf_file import PdfFileLoader
  12. from embedchain.loaders.web_page import WebPageLoader
  13. from embedchain.loaders.youtube_video import YoutubeVideoLoader
  14. class DataFormatter:
  15. """
  16. DataFormatter is an internal utility class which abstracts the mapping for
  17. loaders and chunkers to the data_type entered by the user in their
  18. .add or .add_local method call
  19. """
  20. def __init__(self, data_type: str, config: AddConfig):
  21. self.loader = self._get_loader(data_type, config.loader)
  22. self.chunker = self._get_chunker(data_type, config.chunker)
  23. def _get_loader(self, data_type, config):
  24. """
  25. Returns the appropriate data loader for the given data type.
  26. :param data_type: The type of the data to load.
  27. :return: The loader for the given data type.
  28. :raises ValueError: If an unsupported data type is provided.
  29. """
  30. loaders = {
  31. "youtube_video": YoutubeVideoLoader(),
  32. "pdf_file": PdfFileLoader(),
  33. "web_page": WebPageLoader(),
  34. "qna_pair": LocalQnaPairLoader(),
  35. "text": LocalTextLoader(),
  36. "docx": DocxFileLoader(),
  37. }
  38. if data_type in loaders:
  39. return loaders[data_type]
  40. else:
  41. raise ValueError(f"Unsupported data type: {data_type}")
  42. def _get_chunker(self, data_type, config):
  43. """
  44. Returns the appropriate chunker for the given data type.
  45. :param data_type: The type of the data to chunk.
  46. :return: The chunker for the given data type.
  47. :raises ValueError: If an unsupported data type is provided.
  48. """
  49. chunkers = {
  50. "youtube_video": YoutubeVideoChunker(config),
  51. "pdf_file": PdfFileChunker(config),
  52. "web_page": WebPageChunker(config),
  53. "qna_pair": QnaPairChunker(config),
  54. "text": TextChunker(config),
  55. "docx": DocxFileChunker(config),
  56. }
  57. if data_type in chunkers:
  58. return chunkers[data_type]
  59. else:
  60. raise ValueError(f"Unsupported data type: {data_type}")