Browse Source

[Feature] Add support for directory loader as data source (#1008)

Sidharth Mohanty 1 năm trước cách đây
mục cha
commit
9303a1bf81

+ 2 - 0
embedchain/data_formatter/data_formatter.py

@@ -74,6 +74,7 @@ class DataFormatter(JSONSerializable):
             DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
             DataType.RSSFEED: "embedchain.loaders.rss_feed.RSSFeedLoader",
             DataType.BEEHIIV: "embedchain.loaders.beehiiv.BeehiivLoader",
+            DataType.DIRECTORY: "embedchain.loaders.directory_loader.DirectoryLoader",
         }
 
         if data_type == DataType.CUSTOM or loader is not None:
@@ -116,6 +117,7 @@ class DataFormatter(JSONSerializable):
             DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
             DataType.RSSFEED: "embedchain.chunkers.rss_feed.RSSFeedChunker",
             DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker",
+            DataType.DIRECTORY: "embedchain.chunkers.common_chunker.CommonChunker",
         }
 
         if chunker is not None:

+ 55 - 0
embedchain/loaders/directory_loader.py

@@ -0,0 +1,55 @@
+from pathlib import Path
+import hashlib
+import logging
+from typing import Optional, Dict, Any
+
+from embedchain.utils import detect_datatype
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.loaders.base_loader import BaseLoader
+from embedchain.loaders.local_text import LocalTextLoader
+from embedchain.data_formatter.data_formatter import DataFormatter
+from embedchain.config import AddConfig
+
+
+@register_deserializable
+class DirectoryLoader(BaseLoader):
+    """Load data from a directory."""
+
+    def __init__(self, config: Optional[Dict[str, Any]] = None):
+        super().__init__()
+        config = config or {}
+        self.recursive = config.get("recursive", True)
+        self.extensions = config.get("extensions", None)
+        self.errors = []
+
+    def load_data(self, path: str):
+        directory_path = Path(path)
+        if not directory_path.is_dir():
+            raise ValueError(f"Invalid path: {path}")
+
+        data_list = self._process_directory(directory_path)
+        doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()
+
+        for error in self.errors:
+            logging.warn(error)
+
+        return {"doc_id": doc_id, "data": data_list}
+
+    def _process_directory(self, directory_path: Path):
+        data_list = []
+        for file_path in directory_path.rglob("*") if self.recursive else directory_path.glob("*"):
+            if file_path.is_file() and (not self.extensions or any(file_path.suffix == ext for ext in self.extensions)):
+                loader = self._predict_loader(file_path)
+                data_list.extend(loader.load_data(str(file_path))["data"])
+        return data_list
+
+    def _predict_loader(self, file_path: Path) -> BaseLoader:
+        try:
+            data_type = detect_datatype(str(file_path))
+            config = AddConfig()
+            return DataFormatter(data_type=data_type, config=config)._get_loader(
+                data_type=data_type, config=config.loader, loader=None
+            )
+        except Exception as e:
+            self.errors.append(f"Error processing {file_path}: {e}")
+            return LocalTextLoader()

+ 2 - 0
embedchain/models/data_type.py

@@ -35,6 +35,7 @@ class IndirectDataType(Enum):
     CUSTOM = "custom"
     RSSFEED = "rss_feed"
     BEEHIIV = "beehiiv"
+    DIRECTORY = "directory"
 
 
 class SpecialDataType(Enum):
@@ -69,3 +70,4 @@ class DataType(Enum):
     CUSTOM = IndirectDataType.CUSTOM.value
     RSSFEED = IndirectDataType.RSSFEED.value
     BEEHIIV = IndirectDataType.BEEHIIV.value
+    DIRECTORY = IndirectDataType.DIRECTORY.value

+ 9 - 2
embedchain/utils.py

@@ -196,8 +196,7 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)
 
     if url:
-        from langchain.document_loaders.youtube import \
-            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
 
         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
@@ -303,6 +302,14 @@ def detect_datatype(source: Any) -> DataType:
             logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
             return DataType.MDX
 
+        if source.endswith(".txt"):
+            logging.debug(f"Source of `{formatted_source}` detected as `text`.")
+            return DataType.TEXT
+
+        if source.endswith(".pdf"):
+            logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
+            return DataType.PDF_FILE
+
         if source.endswith(".yaml"):
             with open(source, "r") as file:
                 yaml_content = yaml.safe_load(file)

+ 1 - 3
tests/embedchain/test_utils.py

@@ -86,11 +86,9 @@ class TestApp(unittest.TestCase):
 
     @patch("os.path.isfile")
     def test_detect_datatype_regular_filesystem_file_not_detected(self, mock_isfile):
-        """Test error if a valid file is referenced, but it isn't a valid data_type"""
         with tempfile.NamedTemporaryFile(suffix=".txt", delete=True) as tmp:
             mock_isfile.return_value = True
-            with self.assertRaises(ValueError):
-                detect_datatype(tmp.name)
+            self.assertEqual(detect_datatype(tmp.name), DataType.TEXT)
 
     def test_detect_datatype_regular_filesystem_no_file(self):
         """Test that if a filepath is not actually an existing file, it is not handled as a file path."""