Browse Source

Load local text files of any kind - code, txts, json etc (#1076)

Sidharth Mohanty 1 year ago
parent
commit
a544b4d3ff

+ 2 - 0
embedchain/data_formatter/data_formatter.py

@@ -76,6 +76,7 @@ class DataFormatter(JSONSerializable):
             DataType.BEEHIIV: "embedchain.loaders.beehiiv.BeehiivLoader",
             DataType.DIRECTORY: "embedchain.loaders.directory_loader.DirectoryLoader",
             DataType.SLACK: "embedchain.loaders.slack.SlackLoader",
+            DataType.TEXT_FILE: "embedchain.loaders.text_file.TextFileLoader",
         }
 
         if data_type == DataType.CUSTOM or loader is not None:
@@ -120,6 +121,7 @@ class DataFormatter(JSONSerializable):
             DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker",
             DataType.DIRECTORY: "embedchain.chunkers.common_chunker.CommonChunker",
             DataType.SLACK: "embedchain.chunkers.common_chunker.CommonChunker",
+            DataType.TEXT_FILE: "embedchain.chunkers.common_chunker.CommonChunker",
         }
 
         if chunker is not None:

+ 2 - 2
embedchain/loaders/directory_loader.py

@@ -7,7 +7,7 @@ from embedchain.config import AddConfig
 from embedchain.data_formatter.data_formatter import DataFormatter
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.loaders.local_text import LocalTextLoader
+from embedchain.loaders.text_file import TextFileLoader
 from embedchain.utils import detect_datatype
 
 
@@ -58,4 +58,4 @@ class DirectoryLoader(BaseLoader):
             )
         except Exception as e:
             self.errors.append(f"Error processing {file_path}: {e}")
-            return LocalTextLoader()
+            return TextFileLoader()

+ 30 - 0
embedchain/loaders/text_file.py

@@ -0,0 +1,30 @@
+import hashlib
+import os
+
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.loaders.base_loader import BaseLoader
+
+
+@register_deserializable
+class TextFileLoader(BaseLoader):
+    def load_data(self, url: str):
+        """Load data from a text file located at a local path."""
+        if not os.path.exists(url):
+            raise FileNotFoundError(f"The file at {url} does not exist.")
+
+        with open(url, "r", encoding="utf-8") as file:
+            content = file.read()
+
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+
+        meta_data = {"url": url, "file_size": os.path.getsize(url), "file_type": url.split(".")[-1]}
+
+        return {
+            "doc_id": doc_id,
+            "data": [
+                {
+                    "content": content,
+                    "meta_data": meta_data,
+                }
+            ],
+        }

+ 2 - 0
embedchain/models/data_type.py

@@ -37,6 +37,7 @@ class IndirectDataType(Enum):
     BEEHIIV = "beehiiv"
     DIRECTORY = "directory"
     SLACK = "slack"
+    TEXT_FILE = "text_file"
 
 
 class SpecialDataType(Enum):
@@ -73,3 +74,4 @@ class DataType(Enum):
     BEEHIIV = IndirectDataType.BEEHIIV.value
     DIRECTORY = IndirectDataType.DIRECTORY.value
     SLACK = IndirectDataType.SLACK.value
+    TEXT_FILE = IndirectDataType.TEXT_FILE.value

+ 5 - 1
embedchain/utils.py

@@ -305,7 +305,7 @@ def detect_datatype(source: Any) -> DataType:
 
         if source.endswith(".txt"):
             logging.debug(f"Source of `{formatted_source}` detected as `text`.")
-            return DataType.TEXT
+            return DataType.TEXT_FILE
 
         if source.endswith(".pdf"):
             logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
@@ -331,6 +331,10 @@ def detect_datatype(source: Any) -> DataType:
             logging.debug(f"Source of `{formatted_source}` detected as `json`.")
             return DataType.JSON
 
+        if os.path.exists(source) and is_readable(open(source).read()):
+            logging.debug(f"Source of `{formatted_source}` detected as `text_file`.")
+            return DataType.TEXT_FILE
+
         # If the source is a valid file that is not detectable as a type, an error is raised.
         # It does not fall back to text.
         raise ValueError(

+ 2 - 2
tests/embedchain/test_utils.py

@@ -85,10 +85,10 @@ class TestApp(unittest.TestCase):
             detect_datatype(["foo", "bar"])
 
     @patch("os.path.isfile")
-    def test_detect_datatype_regular_filesystem_file_not_detected(self, mock_isfile):
+    def test_detect_datatype_regular_filesystem_file_txt(self, mock_isfile):
         with tempfile.NamedTemporaryFile(suffix=".txt", delete=True) as tmp:
             mock_isfile.return_value = True
-            self.assertEqual(detect_datatype(tmp.name), DataType.TEXT)
+            self.assertEqual(detect_datatype(tmp.name), DataType.TEXT_FILE)
 
     def test_detect_datatype_regular_filesystem_no_file(self):
         """Test that if a filepath is not actually an existing file, it is not handled as a file path."""