|
@@ -3,6 +3,8 @@ import hashlib
|
|
|
import logging
|
|
|
import os
|
|
|
|
|
|
+from tqdm import tqdm
|
|
|
+
|
|
|
from embedchain.loaders.base_loader import BaseLoader
|
|
|
from embedchain.loaders.json import JSONLoader
|
|
|
from embedchain.loaders.mdx import MdxLoader
|
|
@@ -53,14 +55,24 @@ class GithubLoader(BaseLoader):
|
|
|
|
|
|
return data.get("data", [])
|
|
|
|
|
|
+ def _is_file_empty(file_path):
|
|
|
+ return os.path.getsize(file_path) == 0
|
|
|
+
|
|
|
+ def _is_whitelisted(file_path):
|
|
|
+ whitelisted_extensions = ["md", "txt", "html", "json", "py", "js", "jsx", "ts", "tsx", "mdx", "rst"]
|
|
|
+ _, file_extension = os.path.splitext(file_path)
|
|
|
+ return file_extension[1:] in whitelisted_extensions
|
|
|
+
|
|
|
def _add_repo_files(repo_path: str):
|
|
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
|
future_to_file = {
|
|
|
executor.submit(_load_file, os.path.join(root, filename)): os.path.join(root, filename)
|
|
|
for root, _, files in os.walk(repo_path)
|
|
|
for filename in files
|
|
|
- } # noqa: E501
|
|
|
- for future in concurrent.futures.as_completed(future_to_file):
|
|
|
+ if _is_whitelisted(os.path.join(root, filename))
|
|
|
+ and not _is_file_empty(os.path.join(root, filename)) # noqa:E501
|
|
|
+ }
|
|
|
+ for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(future_to_file)):
|
|
|
file = future_to_file[future]
|
|
|
try:
|
|
|
results = future.result()
|