directory_loader.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import hashlib
  2. import logging
  3. from pathlib import Path
  4. from typing import Any, Optional
  5. from embedchain.config import AddConfig
  6. from embedchain.data_formatter.data_formatter import DataFormatter
  7. from embedchain.helpers.json_serializable import register_deserializable
  8. from embedchain.loaders.base_loader import BaseLoader
  9. from embedchain.loaders.text_file import TextFileLoader
  10. from embedchain.utils.misc import detect_datatype
  11. logger = logging.getLogger(__name__)
  12. @register_deserializable
  13. class DirectoryLoader(BaseLoader):
  14. """Load data from a directory."""
  15. def __init__(self, config: Optional[dict[str, Any]] = None):
  16. super().__init__()
  17. config = config or {}
  18. self.recursive = config.get("recursive", True)
  19. self.extensions = config.get("extensions", None)
  20. self.errors = []
  21. def load_data(self, path: str):
  22. directory_path = Path(path)
  23. if not directory_path.is_dir():
  24. raise ValueError(f"Invalid path: {path}")
  25. logger.info(f"Loading data from directory: {path}")
  26. data_list = self._process_directory(directory_path)
  27. doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()
  28. for error in self.errors:
  29. logger.warning(error)
  30. return {"doc_id": doc_id, "data": data_list}
  31. def _process_directory(self, directory_path: Path):
  32. data_list = []
  33. for file_path in directory_path.rglob("*") if self.recursive else directory_path.glob("*"):
  34. # don't include dotfiles
  35. if file_path.name.startswith("."):
  36. continue
  37. if file_path.is_file() and (not self.extensions or any(file_path.suffix == ext for ext in self.extensions)):
  38. loader = self._predict_loader(file_path)
  39. data_list.extend(loader.load_data(str(file_path))["data"])
  40. elif file_path.is_dir():
  41. logger.info(f"Loading data from directory: {file_path}")
  42. return data_list
  43. def _predict_loader(self, file_path: Path) -> BaseLoader:
  44. try:
  45. data_type = detect_datatype(str(file_path))
  46. config = AddConfig()
  47. return DataFormatter(data_type=data_type, config=config)._get_loader(
  48. data_type=data_type, config=config.loader, loader=None
  49. )
  50. except Exception as e:
  51. self.errors.append(f"Error processing {file_path}: {e}")
  52. return TextFileLoader()