# github.py
import concurrent.futures
import hashlib
import logging
import os

from tqdm import tqdm

from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.json import JSONLoader
from embedchain.loaders.mdx import MdxLoader
from embedchain.loaders.unstructured_file import UnstructuredLoader
from embedchain.utils import detect_datatype


class GithubLoader(BaseLoader):
    def load_data(self, repo_url):
        """Load data from a git repo."""
        try:
            from git import Repo
        except ImportError as e:
            raise ValueError(
                "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[git]'`"
            ) from e

        mdx_loader = MdxLoader()
        json_loader = JSONLoader()
        unstructured_loader = UnstructuredLoader()

        # Accumulators filled by the nested helpers below.
        data = []
        data_urls = []

        def _fetch_or_clone_repo(repo_url: str, local_path: str):
            if os.path.exists(local_path):
                logging.info("Repository already exists. Fetching updates...")
                repo = Repo(local_path)
                origin = repo.remotes.origin
                origin.fetch()
                logging.info("Fetch completed.")
            else:
                logging.info("Cloning repository...")
                Repo.clone_from(repo_url, local_path)
                logging.info("Clone completed.")

        def _load_file(file_path: str):
            # Fall back to the unstructured loader if type detection fails.
            try:
                data_type = detect_datatype(file_path).value
            except Exception:
                data_type = "unstructured"

            if data_type == "mdx":
                loaded = mdx_loader.load_data(file_path)
            elif data_type == "json":
                loaded = json_loader.load_data(file_path)
            else:
                loaded = unstructured_loader.load_data(file_path)
            return loaded.get("data", [])

        def _is_file_empty(file_path):
            return os.path.getsize(file_path) == 0

        def _is_whitelisted(file_path):
            whitelisted_extensions = ["md", "txt", "html", "json", "py", "js", "jsx", "ts", "tsx", "mdx", "rst"]
            _, file_extension = os.path.splitext(file_path)
            return file_extension[1:] in whitelisted_extensions

        def _add_repo_files(repo_path: str):
            # Load whitelisted, non-empty files concurrently.
            with concurrent.futures.ThreadPoolExecutor() as executor:
                future_to_file = {
                    executor.submit(_load_file, os.path.join(root, filename)): os.path.join(root, filename)
                    for root, _, files in os.walk(repo_path)
                    for filename in files
                    if _is_whitelisted(os.path.join(root, filename))
                    and not _is_file_empty(os.path.join(root, filename))  # noqa:E501
                }
                for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(future_to_file)):
                    file = future_to_file[future]
                    try:
                        results = future.result()
                        if results:
                            data.extend(results)
                            data_urls.extend([result.get("meta_data").get("url") for result in results])
                    except Exception as e:
                        logging.warning(f"Failed to process {file}: {e}")

        # Check the repo out under a stable, URL-derived path so repeat loads reuse it.
        source_hash = hashlib.sha256(repo_url.encode()).hexdigest()
        repo_path = f"/tmp/{source_hash}"
        _fetch_or_clone_repo(repo_url=repo_url, local_path=repo_path)
        _add_repo_files(repo_path)

        doc_id = hashlib.sha256((repo_url + ", ".join(data_urls)).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": data,
        }
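

# A minimal usage sketch, not part of the original module: it assumes the git
# extra is installed (`pip install --upgrade 'embedchain[git]'`); the repo URL
# below is purely illustrative and any clonable repository would do.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    loader = GithubLoader()
    result = loader.load_data("https://github.com/embedchain/embedchain.git")
    print(f"doc_id={result['doc_id']}, {len(result['data'])} documents loaded")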