github.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
import concurrent.futures
import hashlib
import logging
import os
import tempfile

from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.json import JSONLoader
from embedchain.loaders.mdx import MdxLoader
from embedchain.loaders.unstructured_file import UnstructuredLoader
from embedchain.utils import detect_datatype
  10. class GithubLoader(BaseLoader):
  11. def load_data(self, repo_url):
  12. """Load data from a git repo."""
  13. try:
  14. from git import Repo
  15. except ImportError as e:
  16. raise ValueError(
  17. "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[git]'`"
  18. ) from e
  19. mdx_loader = MdxLoader()
  20. json_loader = JSONLoader()
  21. unstructured_loader = UnstructuredLoader()
  22. data = []
  23. data_urls = []
  24. def _fetch_or_clone_repo(repo_url: str, local_path: str):
  25. if os.path.exists(local_path):
  26. logging.info("Repository already exists. Fetching updates...")
  27. repo = Repo(local_path)
  28. origin = repo.remotes.origin
  29. origin.fetch()
  30. logging.info("Fetch completed.")
  31. else:
  32. logging.info("Cloning repository...")
  33. Repo.clone_from(repo_url, local_path)
  34. logging.info("Clone completed.")
  35. def _load_file(file_path: str):
  36. try:
  37. data_type = detect_datatype(file_path).value
  38. except Exception:
  39. data_type = "unstructured"
  40. if data_type == "mdx":
  41. data = mdx_loader.load_data(file_path)
  42. elif data_type == "json":
  43. data = json_loader.load_data(file_path)
  44. else:
  45. data = unstructured_loader.load_data(file_path)
  46. return data.get("data", [])
  47. def _add_repo_files(repo_path: str):
  48. with concurrent.futures.ThreadPoolExecutor() as executor:
  49. future_to_file = {
  50. executor.submit(_load_file, os.path.join(root, filename)): os.path.join(root, filename)
  51. for root, _, files in os.walk(repo_path)
  52. for filename in files
  53. } # noqa: E501
  54. for future in concurrent.futures.as_completed(future_to_file):
  55. file = future_to_file[future]
  56. try:
  57. results = future.result()
  58. if results:
  59. data.extend(results)
  60. data_urls.extend([result.get("meta_data").get("url") for result in results])
  61. except Exception as e:
  62. logging.warn(f"Failed to process {file}: {e}")
  63. source_hash = hashlib.sha256(repo_url.encode()).hexdigest()
  64. repo_path = f"/tmp/{source_hash}"
  65. _fetch_or_clone_repo(repo_url=repo_url, local_path=repo_path)
  66. _add_repo_files(repo_path)
  67. doc_id = hashlib.sha256((repo_url + ", ".join(data_urls)).encode()).hexdigest()
  68. return {
  69. "doc_id": doc_id,
  70. "data": data,
  71. }