|
@@ -1,7 +1,6 @@
|
|
|
import concurrent.futures
|
|
|
import hashlib
|
|
|
import logging
|
|
|
-import os
|
|
|
import re
|
|
|
import shlex
|
|
|
from typing import Any, Optional
|
|
@@ -14,7 +13,7 @@ from embedchain.utils.misc import clean_string
|
|
|
# Base URLs for the GitHub web site and the GitHub REST API.
GITHUB_URL = "https://github.com"
GITHUB_API_URL = "https://api.github.com"

# Search types accepted in a loader query. "branch" and "file" load repository
# contents for a specific branch or a specific path respectively.
VALID_SEARCH_TYPES = {"code", "repo", "pr", "issue", "discussion", "branch", "file"}
|
|
|
|
|
|
|
|
|
class GithubLoader(BaseLoader):
|
|
@@ -66,85 +65,56 @@ class GithubLoader(BaseLoader):
|
|
|
)
|
|
|
return data
|
|
|
|
|
|
def _get_github_repo_data(
    self, repo_name: str, branch_name: Optional[str] = None, file_path: Optional[str] = None
) -> list[dict]:
    """Get file contents from Repo.

    Walks the repository obtained from ``self.client.get_repo`` in
    first-in-first-out order, starting at ``file_path`` (or at the repository
    root when no path is given), and collects the cleaned text of every file
    whose content can be decoded.

    Args:
        repo_name: Repository identifier passed to ``self.client.get_repo``.
        branch_name: Optional ref to read from. It is forwarded to every
            ``get_contents`` call, including those made while descending into
            sub-directories, so nested files come from the same branch as the
            starting listing. When omitted, no ``ref`` is sent.
        file_path: Optional path to start from. It may name a single file or
            a directory; a directory is expanded like any other directory
            met during the walk.

    Returns:
        One ``{"content": <cleaned text>, "meta_data": {"path": <path>}}``
        dict per file that was read. Entries that cannot be listed or decoded
        are logged with a warning and skipped.
    """
    data = []

    repo = self.client.get_repo(repo_name)
    # Only send `ref` when a branch was requested; otherwise let the client
    # fall back to its own default.
    ref_kwargs = {"ref": branch_name} if branch_name else {}

    # One starting request: root when no path is given, otherwise the path.
    repo_contents = repo.get_contents(file_path or "", **ref_kwargs)
    if not isinstance(repo_contents, list):
        # A single file comes back as one object rather than a list.
        repo_contents = [repo_contents]

    with tqdm(desc="Loading files:", unit="item") as progress_bar:
        while repo_contents:
            file_content = repo_contents.pop(0)
            if file_content.type == "dir":
                try:
                    # Keep the same ref so nested entries match the branch.
                    repo_contents.extend(repo.get_contents(file_content.path, **ref_kwargs))
                except Exception:
                    # Best effort: one unreadable directory must not abort the load.
                    logging.warning(f"Failed to read directory: {file_content.path}")
            else:
                try:
                    file_text = file_content.decoded_content.decode()
                except Exception:
                    # Best effort: skip anything whose content cannot be decoded.
                    logging.warning(f"Failed to read file: {file_content.path}")
                else:
                    data.append(
                        {
                            "content": clean_string(file_text),
                            "meta_data": {
                                "path": file_content.path,
                            },
                        }
                    )
            # Every dequeued entry counts once, whether it was read or skipped.
            progress_bar.update(1)

    return data
|
|
|
|
|
|
def _github_search_repo(self, query: str) -> list[dict]:
|
|
|
"""Search GitHub repo."""
|
|
|
- data = []
|
|
|
+
|
|
|
logging.info(f"Searching github repos with query: {query}")
|
|
|
- results = self.client.search_repositories(query)
|
|
|
- # Add repo urls and descriptions
|
|
|
- urls = list(map(lambda x: x.html_url, results))
|
|
|
- descriptions = list(map(lambda x: x.description, results))
|
|
|
- data.append(
|
|
|
- {
|
|
|
- "content": clean_string(desc),
|
|
|
- "meta_data": {
|
|
|
- "url": url,
|
|
|
- },
|
|
|
- }
|
|
|
- for url, desc in zip(urls, descriptions)
|
|
|
- )
|
|
|
-
|
|
|
- # Add repo contents
|
|
|
- for result in results:
|
|
|
- clone_url = result.clone_url
|
|
|
- logging.info(f"Cloning repository: {clone_url}")
|
|
|
- data = self._get_github_repo_data(clone_url)
|
|
|
+ updated_query = query.split(":")[-1]
|
|
|
+ data = self._get_github_repo_data(updated_query)
|
|
|
return data
|
|
|
|
|
|
def _github_search_issues_and_pr(self, query: str, type: str) -> list[dict]:
|
|
@@ -222,6 +192,43 @@ class GithubLoader(BaseLoader):
|
|
|
)
|
|
|
return data
|
|
|
|
|
|
+ def _get_github_repo_branch(self, query: str, type: str) -> list[dict]:
|
|
|
+ """Get file contents for specific branch"""
|
|
|
+
|
|
|
+ logging.info(f"Searching github repo for query: {query} is:{type}")
|
|
|
+ pattern = r"repo:(\S+) name:(\S+)"
|
|
|
+ match = re.search(pattern, query)
|
|
|
+
|
|
|
+ if match:
|
|
|
+ repo_name = match.group(1)
|
|
|
+ branch_name = match.group(2)
|
|
|
+ else:
|
|
|
+ raise ValueError(
|
|
|
+ f"Repository name and Branch name not found, instead found this \
|
|
|
+ Repo: {repo_name}, Branch: {branch_name}"
|
|
|
+ )
|
|
|
+
|
|
|
+ data = self._get_github_repo_data(repo_name=repo_name, branch_name=branch_name)
|
|
|
+ return data
|
|
|
+
|
|
|
+ def _get_github_repo_file(self, query: str, type: str) -> list[dict]:
|
|
|
+ """Get specific file content"""
|
|
|
+
|
|
|
+ logging.info(f"Searching github repo for query: {query} is:{type}")
|
|
|
+ pattern = r"repo:(\S+) path:(\S+)"
|
|
|
+ match = re.search(pattern, query)
|
|
|
+
|
|
|
+ if match:
|
|
|
+ repo_name = match.group(1)
|
|
|
+ file_path = match.group(2)
|
|
|
+ else:
|
|
|
+ raise ValueError(
|
|
|
+ f"Repository name and File name not found, instead found this Repo: {repo_name}, File: {file_path}"
|
|
|
+ )
|
|
|
+
|
|
|
+ data = self._get_github_repo_data(repo_name=repo_name, file_path=file_path)
|
|
|
+ return data
|
|
|
+
|
|
|
def _search_github_data(self, search_type: str, query: str):
|
|
|
"""Search github data."""
|
|
|
if search_type == "code":
|
|
@@ -232,6 +239,10 @@ class GithubLoader(BaseLoader):
|
|
|
data = self._github_search_issues_and_pr(query, search_type)
|
|
|
elif search_type == "pr":
|
|
|
data = self._github_search_issues_and_pr(query, search_type)
|
|
|
+ elif search_type == "branch":
|
|
|
+ data = self._get_github_repo_branch(query, search_type)
|
|
|
+ elif search_type == "file":
|
|
|
+ data = self._get_github_repo_file(query, search_type)
|
|
|
elif search_type == "discussion":
|
|
|
raise ValueError("GithubLoader does not support searching discussions yet.")
|
|
|
else:
|