|
@@ -0,0 +1,55 @@
|
|
|
+from pathlib import Path
|
|
|
+import hashlib
|
|
|
+import logging
|
|
|
+from typing import Optional, Dict, Any
|
|
|
+
|
|
|
+from embedchain.utils import detect_datatype
|
|
|
+from embedchain.helpers.json_serializable import register_deserializable
|
|
|
+from embedchain.loaders.base_loader import BaseLoader
|
|
|
+from embedchain.loaders.local_text import LocalTextLoader
|
|
|
+from embedchain.data_formatter.data_formatter import DataFormatter
|
|
|
+from embedchain.config import AddConfig
|
|
|
+
|
|
|
+
|
|
|
@register_deserializable
class DirectoryLoader(BaseLoader):
    """Load data from the files contained in a directory.

    Each matching file is handed to a loader chosen from its detected data
    type, and the ``"data"`` entries returned by those loaders are
    concatenated into a single list.

    Supported ``config`` keys:
        recursive: Descend into sub-directories when true (default ``True``).
        extensions: Iterable of file suffixes to include, e.g.
            ``[".txt", "md"]``. A leading dot is optional and a single string
            is treated as one extension. ``None`` or an empty value includes
            every file.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Store the loader options found in ``config`` (all optional)."""
        super().__init__()
        config = config or {}
        self.recursive = config.get("recursive", True)
        self.extensions = config.get("extensions", None)
        # Messages collected while choosing loaders during the latest load_data call.
        self.errors = []

    def load_data(self, path: str) -> Dict[str, Any]:
        """Load every matching file below ``path``.

        Args:
            path: Path of the directory to read.

        Returns:
            A dict with ``"doc_id"`` (SHA-256 hex digest of the loaded data
            plus the directory path) and ``"data"`` (the combined list of
            entries produced by the per-file loaders).

        Raises:
            ValueError: If ``path`` is not an existing directory.
        """
        directory_path = Path(path)
        if not directory_path.is_dir():
            raise ValueError(f"Invalid path: {path}")

        # Start from a clean slate so errors from an earlier call on this
        # instance are not logged again.
        self.errors = []

        data_list = self._process_directory(directory_path)
        doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()

        for error in self.errors:
            # logging.warn is a deprecated alias that newer Python versions removed.
            logging.warning(error)

        return {"doc_id": doc_id, "data": data_list}

    def _allowed_suffixes(self) -> Optional[list]:
        """Return the configured extensions as dotted suffixes, or None for "all files"."""
        extensions = self.extensions
        if not extensions:
            return None
        if isinstance(extensions, str):
            # A bare string would otherwise be iterated character by character.
            extensions = [extensions]
        suffixes = []
        for ext in extensions:
            if isinstance(ext, str) and ext and not ext.startswith("."):
                # Path.suffix always carries the leading dot, so add it for comparison.
                ext = f".{ext}"
            suffixes.append(ext)
        return suffixes

    def _process_directory(self, directory_path: Path) -> list:
        """Collect the ``"data"`` entries of every matching file in ``directory_path``."""
        allowed = self._allowed_suffixes()
        candidates = directory_path.rglob("*") if self.recursive else directory_path.glob("*")

        data_list = []
        # Directory listing order depends on the filesystem; sorting keeps the
        # resulting data order, and therefore doc_id, reproducible.
        for file_path in sorted(candidates):
            if not file_path.is_file():
                continue
            if allowed is not None and file_path.suffix not in allowed:
                continue
            loader = self._predict_loader(file_path)
            data_list.extend(loader.load_data(str(file_path))["data"])
        return data_list

    def _predict_loader(self, file_path: Path) -> BaseLoader:
        """Pick a loader for ``file_path`` from its detected data type.

        If detection or loader construction raises, the failure is recorded in
        ``self.errors`` and a ``LocalTextLoader`` is returned instead.
        """
        try:
            data_type = detect_datatype(str(file_path))
            config = AddConfig()
            return DataFormatter(data_type=data_type, config=config)._get_loader(
                data_type=data_type, config=config.loader, loader=None
            )
        except Exception as e:
            # Deliberate best-effort: an undetectable file is still read as plain text.
            self.errors.append(f"Error processing {file_path}: {e}")
            return LocalTextLoader()
|