import concurrent.futures
import hashlib
import logging
import os

from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.json import JSONLoader
from embedchain.loaders.mdx import MdxLoader
from embedchain.loaders.unstructured_file import UnstructuredLoader
from embedchain.utils import detect_datatype


class GithubLoader(BaseLoader):
    def load_data(self, repo_url):
        """Load data from a git repo."""
        try:
            from git import Repo
        except ImportError as e:
            raise ValueError(
                "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[git]'`"
            ) from e

        mdx_loader = MdxLoader()
        json_loader = JSONLoader()
        unstructured_loader = UnstructuredLoader()
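
        # Shared accumulators; the loader threads below extend these.
        # (Extending a list is thread-safe under CPython's GIL, so no lock is needed.)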
        data = []
        data_urls = []

        def _fetch_or_clone_repo(repo_url: str, local_path: str):
            if os.path.exists(local_path):
                logging.info("Repository already exists. Fetching updates...")
                repo = Repo(local_path)
                origin = repo.remotes.origin
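                # Note: fetch() only updates remote-tracking refs; the working
                # tree is left as last checked out, so files on disk may be stale.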
                origin.fetch()
                logging.info("Fetch completed.")
            else:
                logging.info("Cloning repository...")
                Repo.clone_from(repo_url, local_path)
                logging.info("Clone completed.")

        def _load_file(file_path: str):
            try:
                data_type = detect_datatype(file_path).value
            except Exception:
                data_type = "unstructured"
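
            # Route the file to the loader matching its detected type; anything
            # unrecognized falls back to the generic unstructured loader.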
            if data_type == "mdx":
                loaded = mdx_loader.load_data(file_path)
            elif data_type == "json":
                loaded = json_loader.load_data(file_path)
            else:
                loaded = unstructured_loader.load_data(file_path)
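
            # Each loader returns {"doc_id": ..., "data": [...]}; only the
            # per-file chunk list is propagated upward.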
            return loaded.get("data", [])

        def _add_repo_files(repo_path: str):
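            # Walk the checkout and load every file concurrently, merging
            # results into the shared lists as futures complete.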
            with concurrent.futures.ThreadPoolExecutor() as executor:
                future_to_file = {
                    executor.submit(_load_file, os.path.join(root, filename)): os.path.join(root, filename)  # noqa: E501
                    for root, _, files in os.walk(repo_path)
                    for filename in files
                }
                for future in concurrent.futures.as_completed(future_to_file):
                    file = future_to_file[future]
                    try:
                        results = future.result()
                        if results:
                            data.extend(results)
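                            # Assumes each chunk carries meta_data["url"], as the
                            # bundled loaders set it; a chunk without it surfaces
                            # as a logged failure for this file.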
                            data_urls.extend([result.get("meta_data").get("url") for result in results])
                    except Exception as e:
                        logging.error(f"Failed to process {file}: {e}")
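
        # Cache the checkout under a URL-derived path so repeated loads reuse
        # the same clone. (The fixed /tmp prefix assumes a POSIX system;
        # tempfile.gettempdir() would be the portable alternative.)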
        source_hash = hashlib.sha256(repo_url.encode()).hexdigest()
        repo_path = f"/tmp/{source_hash}"
        _fetch_or_clone_repo(repo_url=repo_url, local_path=repo_path)
        _add_repo_files(repo_path)
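
        # doc_id is a deterministic hash over the repo URL plus every loaded
        # file URL, so the same content yields the same id across runs.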
        doc_id = hashlib.sha256((repo_url + ", ".join(data_urls)).encode()).hexdigest()

        return {
            "doc_id": doc_id,
            "data": data,
        }
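

# Usage sketch (illustrative; the repo URL below is a placeholder):
#
#   loader = GithubLoader()
#   result = loader.load_data("https://github.com/some-org/some-repo.git")
#   print(result["doc_id"], len(result["data"]))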