123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298 |
- import concurrent.futures
- import hashlib
- import logging
- import os
- import re
- import shlex
- from typing import Any, Dict, Optional
- from tqdm import tqdm
- from embedchain.loaders.base_loader import BaseLoader
- from embedchain.utils import clean_string
GITHUB_URL = "https://github.com"
GITHUB_API_URL = "https://api.github.com"

# Search types accepted in the `type:` term of a loader query.
# Note: "discussion" is recognized by query validation but not yet
# supported by the search dispatcher.
VALID_SEARCH_TYPES = {"code", "repo", "pr", "issue", "discussion"}
class GithubLoader(BaseLoader):
    """Load data from a github search query.

    Requires a personal access token supplied as ``config={"token": ...}``.
    The query must contain a ``repo:`` term and at least one ``type:`` term;
    supported types are code, repo, pr and issue (discussion is recognized
    but not supported yet).
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__()
        if not config:
            raise ValueError(
                "GithubLoader requires a personal access token to use github api. Check - `https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic`"  # noqa: E501
            )
        try:
            from github import Github
        except ImportError as e:
            raise ValueError(
                "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[github]'`"
            ) from e
        self.config = config
        token = config.get("token")
        if not token:
            raise ValueError(
                "GithubLoader requires a personal access token to use github api. Check - `https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic`"  # noqa: E501
            )
        try:
            self.client = Github(token)
        except Exception as e:
            # Leave the client unset; load_data() raises a clear error later.
            logging.error(f"GithubLoader failed to initialize client: {e}")
            self.client = None

    def _github_search_code(self, query: str):
        """Search github code; return a list of {content, meta_data} dicts."""
        data = []
        results = self.client.search_code(query)
        for result in tqdm(results, total=results.totalCount, desc="Loading code files from github"):
            url = result.html_url
            logging.info(f"Added data from url: {url}")
            content = result.decoded_content.decode("utf-8")
            data.append(
                {
                    "content": clean_string(content),
                    "meta_data": {"url": url},
                }
            )
        return data

    def _get_github_repo_data(self, repo_url: str):
        """Clone (or reuse a cached clone of) `repo_url` and load every readable blob.

        The clone is cached under /tmp keyed by a hash of the URL.
        Returns a list of {content, meta_data} dicts.
        """
        local_hash = hashlib.sha256(repo_url.encode()).hexdigest()
        local_path = f"/tmp/{local_hash}"
        data = []

        def _get_repo_tree(repo_url: str, local_path: str):
            # Return the tree of the repo's HEAD commit, cloning if needed.
            try:
                from git import Repo
            except ImportError as e:
                raise ValueError(
                    "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[github]'`"  # noqa: E501
                ) from e
            if os.path.exists(local_path):
                # NOTE(review): despite the log message, no fetch/pull is
                # performed — the cached copy may be stale.
                logging.info("Repository already exists. Fetching updates...")
                repo = Repo(local_path)
                logging.info("Fetch completed.")
            else:
                logging.info("Cloning repository...")
                # Clone exactly once. The previous implementation called
                # clone_from twice; the second clone into a now-non-empty
                # destination directory fails.
                repo = Repo.clone_from(repo_url, local_path)
                logging.info("Clone completed.")
            return repo.head.commit.tree

        def _get_repo_tree_contents(repo_path, tree, progress_bar):
            # Recursively walk the tree, appending one entry per readable blob.
            for subtree in tree:
                if subtree.type == "tree":
                    _get_repo_tree_contents(repo_path, subtree, progress_bar)
                else:
                    assert subtree.type == "blob"
                    try:
                        contents = subtree.data_stream.read().decode("utf-8")
                    except Exception:
                        # Binary / non-utf8 files are skipped, not fatal.
                        logging.warning(f"Failed to read file: {subtree.path}")
                        if progress_bar is not None:
                            progress_bar.update(1)
                        continue
                    # rstrip(".git") strips trailing '.', 'g', 'i', 't'
                    # *characters* (mangling URLs like ".../foo-git"); strip
                    # the literal suffix instead.
                    base_url = repo_url[: -len(".git")] if repo_url.endswith(".git") else repo_url
                    # NOTE(review): assumes the default branch is `main` —
                    # confirm against the repo's actual default branch.
                    url = f"{base_url}/blob/main/{subtree.path}"
                    data.append(
                        {
                            "content": clean_string(contents),
                            "meta_data": {"url": url},
                        }
                    )
                    if progress_bar is not None:
                        progress_bar.update(1)

        repo_tree = _get_repo_tree(repo_url, local_path)
        tree_list = list(repo_tree.traverse())
        with tqdm(total=len(tree_list), desc="Loading files:", unit="item") as progress_bar:
            _get_repo_tree_contents(local_path, repo_tree, progress_bar)
        return data

    def _github_search_repo(self, query: str):
        """Search github repos; load each repo's description and full contents."""
        data = []
        logging.info(f"Searching github repos with query: {query}")
        results = self.client.search_repositories(query)
        # Add repo urls and descriptions. The previous implementation appended
        # a single generator object to `data` (referencing names that only
        # existed inside it) instead of the description dicts themselves.
        for result in results:
            if result.description:
                data.append(
                    {
                        "content": clean_string(result.description),
                        "meta_data": {"url": result.html_url},
                    }
                )
        # Add repo contents. Extend rather than overwrite `data`, so the
        # descriptions above and every repo's files are all kept (previously
        # each iteration replaced `data` wholesale).
        for result in results:
            clone_url = result.clone_url
            logging.info(f"Cloning repository: {clone_url}")
            data.extend(self._get_github_repo_data(clone_url))
        return data

    def _github_search_issues_and_pr(self, query: str, type: str):
        """Search github issues or PRs (`type` is "issue" or "pr").

        Each result's title, labels, body and comments are joined into one
        content string; skips results with an empty body.
        """
        data = []
        query = f"{query} is:{type}"
        logging.info(f"Searching github for query: {query}")
        results = self.client.search_issues(query)
        logging.info(f"Total results: {results.totalCount}")
        for result in tqdm(results, total=results.totalCount, desc=f"Loading {type} from github"):
            url = result.html_url
            title = result.title
            body = result.body
            if not body:
                # logging.warn is deprecated; warning() is the supported API.
                logging.warning(f"Skipping issue because empty content for: {url}")
                continue
            labels = " ".join(label.name for label in result.labels)
            comments = []
            comments_created_at = []
            for comment in result.get_comments():
                comments_created_at.append(str(comment.created_at))
                comments.append(f"{comment.user.name}:{comment.body}")
            content = "\n".join([title, labels, body, *comments])
            metadata = {
                "url": url,
                "created_at": str(result.created_at),
                "comments_created_at": " ".join(comments_created_at),
            }
            data.append(
                {
                    "content": clean_string(content),
                    "meta_data": metadata,
                }
            )
        return data

    # need to test more for discussion
    def _github_search_discussions(self, query: str):
        """Search github team discussions (experimental; not reachable via load_data)."""
        data = []
        query = f"{query} is:discussion"
        logging.info(f"Searching github repo for query: {query}")
        repos_results = self.client.search_repositories(query)
        logging.info(f"Total repos found: {repos_results.totalCount}")
        for repo_result in tqdm(repos_results, total=repos_results.totalCount, desc="Loading discussions from github"):
            teams = repo_result.get_teams()
            for team in teams:
                team_discussions = team.get_discussions()
                for discussion in team_discussions:
                    url = discussion.html_url
                    title = discussion.title
                    body = discussion.body
                    if not body:
                        logging.warning(f"Skipping discussion because empty content for: {url}")
                        continue
                    # TODO: comments are not fetched yet — only the endpoint
                    # (discussion.comments_url) is known, so both lists below
                    # stay empty.
                    comments = []
                    comments_created_at = []
                    content = "\n".join([title, body, *comments])
                    metadata = {
                        "url": url,
                        "created_at": str(discussion.created_at),
                        "comments_created_at": " ".join(comments_created_at),
                    }
                    data.append(
                        {
                            "content": clean_string(content),
                            "meta_data": metadata,
                        }
                    )
        return data

    def _search_github_data(self, search_type: str, query: str):
        """Dispatch `query` to the search implementation for `search_type`."""
        if search_type == "code":
            return self._github_search_code(query)
        if search_type == "repo":
            return self._github_search_repo(query)
        if search_type in ("issue", "pr"):
            return self._github_search_issues_and_pr(query, search_type)
        if search_type == "discussion":
            raise ValueError("GithubLoader does not support searching discussions yet.")
        # Previously an unrecognized type fell through and raised
        # UnboundLocalError on `data`; fail with an explicit error instead.
        raise ValueError(f"Invalid search type: {search_type}. Valid types are: {', '.join(VALID_SEARCH_TYPES)}")

    def _get_valid_github_query(self, query: str):
        """Validate the query; return (search types, query without `type:` terms).

        Raises ValueError when the query lacks a `repo:` term, lacks a
        `type:` term, or requests an unsupported search type.
        """
        query_terms = shlex.split(query)
        # query must provide repo to load data from
        if len(query_terms) < 1 or "repo:" not in query:
            raise ValueError(
                "GithubLoader requires a search query with `repo:` term. Refer docs - `https://docs.embedchain.ai/data-sources/github`"  # noqa: E501
            )
        github_query = []
        types = set()
        type_pattern = r"type:([a-zA-Z,]+)"
        for term in query_terms:
            term_match = re.search(type_pattern, term)
            if term_match:
                # A single term may carry several comma-separated types.
                types.update(term_match.group(1).split(","))
            else:
                github_query.append(term)
        # query must provide search type
        if not types:
            raise ValueError(
                "GithubLoader requires a search query with `type:` term. Refer docs - `https://docs.embedchain.ai/data-sources/github`"  # noqa: E501
            )
        # Validate every requested type. The previous loop iterated the stale
        # local from the *last* regex match, so types from earlier `type:`
        # terms escaped validation.
        for search_type in types:
            if search_type not in VALID_SEARCH_TYPES:
                raise ValueError(
                    f"Invalid search type: {search_type}. Valid types are: {', '.join(VALID_SEARCH_TYPES)}"
                )
        return types, " ".join(github_query)

    def load_data(self, search_query: str, max_results: int = 1000):
        """Load github data for `search_query`.

        Returns {"doc_id": sha256 of the cleaned query, "data": [...]}.
        Searches of different types run concurrently in a thread pool.
        """
        if not self.client:
            raise ValueError(
                "GithubLoader client is not initialized, data will not be loaded. Refer docs - `https://docs.embedchain.ai/data-sources/github`"  # noqa: E501
            )
        search_types, query = self._get_valid_github_query(search_query)
        logging.info(f"Searching github for query: {query}, with types: {', '.join(search_types)}")
        data = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
            futures_map = executor.map(self._search_github_data, search_types, [query] * len(search_types))
            for search_data in tqdm(futures_map, total=len(search_types), desc="Searching data from github"):
                data.extend(search_data)
        return {
            "doc_id": hashlib.sha256(query.encode()).hexdigest(),
            "data": data,
        }
|