# github.py — GitHub loader for embedchain (loads code, repos, issues, PRs, branches, files).
  1. import concurrent.futures
  2. import hashlib
  3. import logging
  4. import re
  5. import shlex
  6. from typing import Any, Optional
  7. from tqdm import tqdm
  8. from embedchain.loaders.base_loader import BaseLoader
  9. from embedchain.utils.misc import clean_string
  10. GITHUB_URL = "https://github.com"
  11. GITHUB_API_URL = "https://api.github.com"
  12. VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion", "branch", "file"])
  13. class GithubLoader(BaseLoader):
  14. """Load data from GitHub search query."""
  15. def __init__(self, config: Optional[dict[str, Any]] = None):
  16. super().__init__()
  17. if not config:
  18. raise ValueError(
  19. "GithubLoader requires a personal access token to use github api. Check - `https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic`" # noqa: E501
  20. )
  21. try:
  22. from github import Github
  23. except ImportError as e:
  24. raise ValueError(
  25. "GithubLoader requires extra dependencies. \
  26. Install with `pip install gitpython==3.1.38 PyGithub==1.59.1`"
  27. ) from e
  28. self.config = config
  29. token = config.get("token")
  30. if not token:
  31. raise ValueError(
  32. "GithubLoader requires a personal access token to use github api. Check - `https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic`" # noqa: E501
  33. )
  34. try:
  35. self.client = Github(token)
  36. except Exception as e:
  37. logging.error(f"GithubLoader failed to initialize client: {e}")
  38. self.client = None
  39. def _github_search_code(self, query: str):
  40. """Search GitHub code."""
  41. data = []
  42. results = self.client.search_code(query)
  43. for result in tqdm(results, total=results.totalCount, desc="Loading code files from github"):
  44. url = result.html_url
  45. logging.info(f"Added data from url: {url}")
  46. content = result.decoded_content.decode("utf-8")
  47. metadata = {
  48. "url": url,
  49. }
  50. data.append(
  51. {
  52. "content": clean_string(content),
  53. "meta_data": metadata,
  54. }
  55. )
  56. return data
  57. def _get_github_repo_data(self, repo_name: str, branch_name: str = None, file_path: str = None) -> list[dict]:
  58. """Get file contents from Repo"""
  59. data = []
  60. repo = self.client.get_repo(repo_name)
  61. repo_contents = repo.get_contents("")
  62. if branch_name:
  63. repo_contents = repo.get_contents("", ref=branch_name)
  64. if file_path:
  65. repo_contents = [repo.get_contents(file_path)]
  66. with tqdm(desc="Loading files:", unit="item") as progress_bar:
  67. while repo_contents:
  68. file_content = repo_contents.pop(0)
  69. if file_content.type == "dir":
  70. try:
  71. repo_contents.extend(repo.get_contents(file_content.path))
  72. except Exception:
  73. logging.warning(f"Failed to read directory: {file_content.path}")
  74. progress_bar.update(1)
  75. continue
  76. else:
  77. try:
  78. file_text = file_content.decoded_content.decode()
  79. except Exception:
  80. logging.warning(f"Failed to read file: {file_content.path}")
  81. progress_bar.update(1)
  82. continue
  83. file_path = file_content.path
  84. data.append(
  85. {
  86. "content": clean_string(file_text),
  87. "meta_data": {
  88. "path": file_path,
  89. },
  90. }
  91. )
  92. progress_bar.update(1)
  93. return data
  94. def _github_search_repo(self, query: str) -> list[dict]:
  95. """Search GitHub repo."""
  96. logging.info(f"Searching github repos with query: {query}")
  97. updated_query = query.split(":")[-1]
  98. data = self._get_github_repo_data(updated_query)
  99. return data
  100. def _github_search_issues_and_pr(self, query: str, type: str) -> list[dict]:
  101. """Search GitHub issues and PRs."""
  102. data = []
  103. query = f"{query} is:{type}"
  104. logging.info(f"Searching github for query: {query}")
  105. results = self.client.search_issues(query)
  106. logging.info(f"Total results: {results.totalCount}")
  107. for result in tqdm(results, total=results.totalCount, desc=f"Loading {type} from github"):
  108. url = result.html_url
  109. title = result.title
  110. body = result.body
  111. if not body:
  112. logging.warning(f"Skipping issue because empty content for: {url}")
  113. continue
  114. labels = " ".join([label.name for label in result.labels])
  115. issue_comments = result.get_comments()
  116. comments = []
  117. comments_created_at = []
  118. for comment in issue_comments:
  119. comments_created_at.append(str(comment.created_at))
  120. comments.append(f"{comment.user.name}:{comment.body}")
  121. content = "\n".join([title, labels, body, *comments])
  122. metadata = {
  123. "url": url,
  124. "created_at": str(result.created_at),
  125. "comments_created_at": " ".join(comments_created_at),
  126. }
  127. data.append(
  128. {
  129. "content": clean_string(content),
  130. "meta_data": metadata,
  131. }
  132. )
  133. return data
  134. # need to test more for discussion
  135. def _github_search_discussions(self, query: str):
  136. """Search GitHub discussions."""
  137. data = []
  138. query = f"{query} is:discussion"
  139. logging.info(f"Searching github repo for query: {query}")
  140. repos_results = self.client.search_repositories(query)
  141. logging.info(f"Total repos found: {repos_results.totalCount}")
  142. for repo_result in tqdm(repos_results, total=repos_results.totalCount, desc="Loading discussions from github"):
  143. teams = repo_result.get_teams()
  144. for team in teams:
  145. team_discussions = team.get_discussions()
  146. for discussion in team_discussions:
  147. url = discussion.html_url
  148. title = discussion.title
  149. body = discussion.body
  150. if not body:
  151. logging.warning(f"Skipping discussion because empty content for: {url}")
  152. continue
  153. comments = []
  154. comments_created_at = []
  155. print("Discussion comments: ", discussion.comments_url)
  156. content = "\n".join([title, body, *comments])
  157. metadata = {
  158. "url": url,
  159. "created_at": str(discussion.created_at),
  160. "comments_created_at": " ".join(comments_created_at),
  161. }
  162. data.append(
  163. {
  164. "content": clean_string(content),
  165. "meta_data": metadata,
  166. }
  167. )
  168. return data
  169. def _get_github_repo_branch(self, query: str, type: str) -> list[dict]:
  170. """Get file contents for specific branch"""
  171. logging.info(f"Searching github repo for query: {query} is:{type}")
  172. pattern = r"repo:(\S+) name:(\S+)"
  173. match = re.search(pattern, query)
  174. if match:
  175. repo_name = match.group(1)
  176. branch_name = match.group(2)
  177. else:
  178. raise ValueError(
  179. f"Repository name and Branch name not found, instead found this \
  180. Repo: {repo_name}, Branch: {branch_name}"
  181. )
  182. data = self._get_github_repo_data(repo_name=repo_name, branch_name=branch_name)
  183. return data
  184. def _get_github_repo_file(self, query: str, type: str) -> list[dict]:
  185. """Get specific file content"""
  186. logging.info(f"Searching github repo for query: {query} is:{type}")
  187. pattern = r"repo:(\S+) path:(\S+)"
  188. match = re.search(pattern, query)
  189. if match:
  190. repo_name = match.group(1)
  191. file_path = match.group(2)
  192. else:
  193. raise ValueError(
  194. f"Repository name and File name not found, instead found this Repo: {repo_name}, File: {file_path}"
  195. )
  196. data = self._get_github_repo_data(repo_name=repo_name, file_path=file_path)
  197. return data
  198. def _search_github_data(self, search_type: str, query: str):
  199. """Search github data."""
  200. if search_type == "code":
  201. data = self._github_search_code(query)
  202. elif search_type == "repo":
  203. data = self._github_search_repo(query)
  204. elif search_type == "issue":
  205. data = self._github_search_issues_and_pr(query, search_type)
  206. elif search_type == "pr":
  207. data = self._github_search_issues_and_pr(query, search_type)
  208. elif search_type == "branch":
  209. data = self._get_github_repo_branch(query, search_type)
  210. elif search_type == "file":
  211. data = self._get_github_repo_file(query, search_type)
  212. elif search_type == "discussion":
  213. raise ValueError("GithubLoader does not support searching discussions yet.")
  214. else:
  215. raise NotImplementedError(f"{search_type} not supported")
  216. return data
  217. @staticmethod
  218. def _get_valid_github_query(query: str):
  219. """Check if query is valid and return search types and valid GitHub query."""
  220. query_terms = shlex.split(query)
  221. # query must provide repo to load data from
  222. if len(query_terms) < 1 or "repo:" not in query:
  223. raise ValueError(
  224. "GithubLoader requires a search query with `repo:` term. Refer docs - `https://docs.embedchain.ai/data-sources/github`" # noqa: E501
  225. )
  226. github_query = []
  227. types = set()
  228. type_pattern = r"type:([a-zA-Z,]+)"
  229. for term in query_terms:
  230. term_match = re.search(type_pattern, term)
  231. if term_match:
  232. search_types = term_match.group(1).split(",")
  233. types.update(search_types)
  234. else:
  235. github_query.append(term)
  236. # query must provide search type
  237. if len(types) == 0:
  238. raise ValueError(
  239. "GithubLoader requires a search query with `type:` term. Refer docs - `https://docs.embedchain.ai/data-sources/github`" # noqa: E501
  240. )
  241. for search_type in search_types:
  242. if search_type not in VALID_SEARCH_TYPES:
  243. raise ValueError(
  244. f"Invalid search type: {search_type}. Valid types are: {', '.join(VALID_SEARCH_TYPES)}"
  245. )
  246. query = " ".join(github_query)
  247. return types, query
  248. def load_data(self, search_query: str, max_results: int = 1000):
  249. """Load data from GitHub search query."""
  250. if not self.client:
  251. raise ValueError(
  252. "GithubLoader client is not initialized, data will not be loaded. Refer docs - `https://docs.embedchain.ai/data-sources/github`" # noqa: E501
  253. )
  254. search_types, query = self._get_valid_github_query(search_query)
  255. logging.info(f"Searching github for query: {query}, with types: {', '.join(search_types)}")
  256. data = []
  257. with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
  258. futures_map = executor.map(self._search_github_data, search_types, [query] * len(search_types))
  259. for search_data in tqdm(futures_map, total=len(search_types), desc="Searching data from github"):
  260. data.extend(search_data)
  261. return {
  262. "doc_id": hashlib.sha256(query.encode()).hexdigest(),
  263. "data": data,
  264. }