github.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
import concurrent.futures
import hashlib
import logging
import os
import tempfile

from tqdm import tqdm

from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.json import JSONLoader
from embedchain.loaders.mdx import MdxLoader
from embedchain.utils import detect_datatype
  10. def _load_file_data(path):
  11. data = []
  12. data_content = []
  13. try:
  14. with open(path, "rb") as f:
  15. content = f.read().decode("utf-8")
  16. except Exception as e:
  17. print(f"Error reading file {path}: {e}")
  18. raise ValueError(f"Failed to read file {path}")
  19. meta_data = {}
  20. meta_data["url"] = path
  21. data.append(
  22. {
  23. "content": content,
  24. "meta_data": meta_data,
  25. }
  26. )
  27. data_content.append(content)
  28. doc_id = hashlib.sha256((" ".join(data_content) + path).encode()).hexdigest()
  29. return {
  30. "doc_id": doc_id,
  31. "data": data,
  32. }
  33. class GithubLoader(BaseLoader):
  34. def load_data(self, repo_url):
  35. """Load data from a git repo."""
  36. try:
  37. from git import Repo
  38. except ImportError as e:
  39. raise ValueError(
  40. "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[git]'`"
  41. ) from e
  42. mdx_loader = MdxLoader()
  43. json_loader = JSONLoader()
  44. data = []
  45. data_urls = []
  46. def _fetch_or_clone_repo(repo_url: str, local_path: str):
  47. if os.path.exists(local_path):
  48. logging.info("Repository already exists. Fetching updates...")
  49. repo = Repo(local_path)
  50. origin = repo.remotes.origin
  51. origin.fetch()
  52. logging.info("Fetch completed.")
  53. else:
  54. logging.info("Cloning repository...")
  55. Repo.clone_from(repo_url, local_path)
  56. logging.info("Clone completed.")
  57. def _load_file(file_path: str):
  58. try:
  59. data_type = detect_datatype(file_path).value
  60. except Exception:
  61. data_type = "unstructured"
  62. if data_type == "mdx":
  63. data = mdx_loader.load_data(file_path)
  64. elif data_type == "json":
  65. data = json_loader.load_data(file_path)
  66. else:
  67. data = _load_file_data(file_path)
  68. return data.get("data", [])
  69. def _is_file_empty(file_path):
  70. return os.path.getsize(file_path) == 0
  71. def _is_whitelisted(file_path):
  72. whitelisted_extensions = ["md", "txt", "html", "json", "py", "js", "jsx", "ts", "tsx", "mdx", "rst"]
  73. _, file_extension = os.path.splitext(file_path)
  74. return file_extension[1:] in whitelisted_extensions
  75. def _add_repo_files(repo_path: str):
  76. with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
  77. future_to_file = {
  78. executor.submit(_load_file, os.path.join(root, filename)): os.path.join(root, filename)
  79. for root, _, files in os.walk(repo_path)
  80. for filename in files
  81. if _is_whitelisted(os.path.join(root, filename))
  82. and not _is_file_empty(os.path.join(root, filename)) # noqa:E501
  83. }
  84. for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(future_to_file)):
  85. file = future_to_file[future]
  86. try:
  87. results = future.result()
  88. if results:
  89. data.extend(results)
  90. data_urls.extend([result.get("meta_data").get("url") for result in results])
  91. except Exception as e:
  92. logging.warn(f"Failed to process {file}: {e}")
  93. source_hash = hashlib.sha256(repo_url.encode()).hexdigest()
  94. repo_path = f"/tmp/{source_hash}"
  95. _fetch_or_clone_repo(repo_url=repo_url, local_path=repo_path)
  96. _add_repo_files(repo_path)
  97. doc_id = hashlib.sha256((repo_url + ", ".join(data_urls)).encode()).hexdigest()
  98. return {
  99. "doc_id": doc_id,
  100. "data": data,
  101. }