Ver código fonte

[Feature] Improve GitHub loader (#962)

Deven Patel 1 ano atrás
pai
commit
023a61446f
2 arquivos alterados com 22 adições e 2 exclusões
  1. 14 2
      embedchain/loaders/github.py
  2. 8 0
      embedchain/utils.py

+ 14 - 2
embedchain/loaders/github.py

@@ -3,6 +3,8 @@ import hashlib
 import logging
 import os
 
+from tqdm import tqdm
+
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.json import JSONLoader
 from embedchain.loaders.mdx import MdxLoader
@@ -53,14 +55,24 @@ class GithubLoader(BaseLoader):
 
             return data.get("data", [])
 
+        def _is_file_empty(file_path):
+            return os.path.getsize(file_path) == 0
+
+        def _is_whitelisted(file_path):
+            whitelisted_extensions = ["md", "txt", "html", "json", "py", "js", "jsx", "ts", "tsx", "mdx", "rst"]
+            _, file_extension = os.path.splitext(file_path)
+            return file_extension[1:] in whitelisted_extensions
+
         def _add_repo_files(repo_path: str):
             with concurrent.futures.ThreadPoolExecutor() as executor:
                 future_to_file = {
                     executor.submit(_load_file, os.path.join(root, filename)): os.path.join(root, filename)
                     for root, _, files in os.walk(repo_path)
                     for filename in files
-                }  # noqa: E501
-                for future in concurrent.futures.as_completed(future_to_file):
+                    if _is_whitelisted(os.path.join(root, filename))
+                    and not _is_file_empty(os.path.join(root, filename))  # noqa:E501
+                }
+                for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(future_to_file)):
                     file = future_to_file[future]
                     try:
                         results = future.result()

+ 8 - 0
embedchain/utils.py

@@ -216,6 +216,10 @@ def detect_datatype(source: Any) -> DataType:
             logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
             return DataType.CSV
 
+        if url.path.endswith(".mdx") or url.path.endswith(".md"):
+            logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
+            return DataType.MDX
+
         if url.path.endswith(".docx"):
             logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
             return DataType.DOCX
@@ -292,6 +296,10 @@ def detect_datatype(source: Any) -> DataType:
             logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
             return DataType.XML
 
+        if source.endswith(".mdx") or source.endswith(".md"):
+            logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
+            return DataType.MDX
+
         if source.endswith(".yaml"):
             with open(source, "r") as file:
                 yaml_content = yaml.safe_load(file)