# github.py
import concurrent.futures
import hashlib
import logging
import os
import re
import shlex
import tempfile
from typing import Any, Optional

from tqdm import tqdm

from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
  11. GITHUB_URL = "https://github.com"
  12. GITHUB_API_URL = "https://api.github.com"
  13. VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"])
  14. class GithubLoader(BaseLoader):
  15. """Load data from GitHub search query."""
  16. def __init__(self, config: Optional[dict[str, Any]] = None):
  17. super().__init__()
  18. if not config:
  19. raise ValueError(
  20. "GithubLoader requires a personal access token to use github api. Check - `https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic`" # noqa: E501
  21. )
  22. try:
  23. from github import Github
  24. except ImportError as e:
  25. raise ValueError(
  26. "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[github]'`"
  27. ) from e
  28. self.config = config
  29. token = config.get("token")
  30. if not token:
  31. raise ValueError(
  32. "GithubLoader requires a personal access token to use github api. Check - `https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic`" # noqa: E501
  33. )
  34. try:
  35. self.client = Github(token)
  36. except Exception as e:
  37. logging.error(f"GithubLoader failed to initialize client: {e}")
  38. self.client = None
  39. def _github_search_code(self, query: str):
  40. """Search GitHub code."""
  41. data = []
  42. results = self.client.search_code(query)
  43. for result in tqdm(results, total=results.totalCount, desc="Loading code files from github"):
  44. url = result.html_url
  45. logging.info(f"Added data from url: {url}")
  46. content = result.decoded_content.decode("utf-8")
  47. metadata = {
  48. "url": url,
  49. }
  50. data.append(
  51. {
  52. "content": clean_string(content),
  53. "meta_data": metadata,
  54. }
  55. )
  56. return data
  57. @staticmethod
  58. def _get_github_repo_data(repo_url: str):
  59. local_hash = hashlib.sha256(repo_url.encode()).hexdigest()
  60. local_path = f"/tmp/{local_hash}"
  61. data = []
  62. def _get_repo_tree(repo_url: str, local_path: str):
  63. try:
  64. from git import Repo
  65. except ImportError as e:
  66. raise ValueError(
  67. "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[github]'`" # noqa: E501
  68. ) from e
  69. if os.path.exists(local_path):
  70. logging.info("Repository already exists. Fetching updates...")
  71. repo = Repo(local_path)
  72. logging.info("Fetch completed.")
  73. else:
  74. logging.info("Cloning repository...")
  75. repo = Repo.clone_from(repo_url, local_path)
  76. logging.info("Clone completed.")
  77. return repo.head.commit.tree
  78. def _get_repo_tree_contents(repo_path, tree, progress_bar):
  79. for subtree in tree:
  80. if subtree.type == "tree":
  81. _get_repo_tree_contents(repo_path, subtree, progress_bar)
  82. else:
  83. assert subtree.type == "blob"
  84. try:
  85. contents = subtree.data_stream.read().decode("utf-8")
  86. except Exception:
  87. logging.warning(f"Failed to read file: {subtree.path}")
  88. progress_bar.update(1) if progress_bar else None
  89. continue
  90. url = f"{repo_url.rstrip('.git')}/blob/main/{subtree.path}"
  91. data.append(
  92. {
  93. "content": clean_string(contents),
  94. "meta_data": {
  95. "url": url,
  96. },
  97. }
  98. )
  99. if progress_bar is not None:
  100. progress_bar.update(1)
  101. repo_tree = _get_repo_tree(repo_url, local_path)
  102. tree_list = list(repo_tree.traverse())
  103. with tqdm(total=len(tree_list), desc="Loading files:", unit="item") as progress_bar:
  104. _get_repo_tree_contents(local_path, repo_tree, progress_bar)
  105. return data
  106. def _github_search_repo(self, query: str) -> list[dict]:
  107. """Search GitHub repo."""
  108. data = []
  109. logging.info(f"Searching github repos with query: {query}")
  110. results = self.client.search_repositories(query)
  111. # Add repo urls and descriptions
  112. urls = list(map(lambda x: x.html_url, results))
  113. descriptions = list(map(lambda x: x.description, results))
  114. data.append(
  115. {
  116. "content": clean_string(desc),
  117. "meta_data": {
  118. "url": url,
  119. },
  120. }
  121. for url, desc in zip(urls, descriptions)
  122. )
  123. # Add repo contents
  124. for result in results:
  125. clone_url = result.clone_url
  126. logging.info(f"Cloning repository: {clone_url}")
  127. data = self._get_github_repo_data(clone_url)
  128. return data
  129. def _github_search_issues_and_pr(self, query: str, type: str) -> list[dict]:
  130. """Search GitHub issues and PRs."""
  131. data = []
  132. query = f"{query} is:{type}"
  133. logging.info(f"Searching github for query: {query}")
  134. results = self.client.search_issues(query)
  135. logging.info(f"Total results: {results.totalCount}")
  136. for result in tqdm(results, total=results.totalCount, desc=f"Loading {type} from github"):
  137. url = result.html_url
  138. title = result.title
  139. body = result.body
  140. if not body:
  141. logging.warning(f"Skipping issue because empty content for: {url}")
  142. continue
  143. labels = " ".join([label.name for label in result.labels])
  144. issue_comments = result.get_comments()
  145. comments = []
  146. comments_created_at = []
  147. for comment in issue_comments:
  148. comments_created_at.append(str(comment.created_at))
  149. comments.append(f"{comment.user.name}:{comment.body}")
  150. content = "\n".join([title, labels, body, *comments])
  151. metadata = {
  152. "url": url,
  153. "created_at": str(result.created_at),
  154. "comments_created_at": " ".join(comments_created_at),
  155. }
  156. data.append(
  157. {
  158. "content": clean_string(content),
  159. "meta_data": metadata,
  160. }
  161. )
  162. return data
  163. # need to test more for discussion
  164. def _github_search_discussions(self, query: str):
  165. """Search GitHub discussions."""
  166. data = []
  167. query = f"{query} is:discussion"
  168. logging.info(f"Searching github repo for query: {query}")
  169. repos_results = self.client.search_repositories(query)
  170. logging.info(f"Total repos found: {repos_results.totalCount}")
  171. for repo_result in tqdm(repos_results, total=repos_results.totalCount, desc="Loading discussions from github"):
  172. teams = repo_result.get_teams()
  173. for team in teams:
  174. team_discussions = team.get_discussions()
  175. for discussion in team_discussions:
  176. url = discussion.html_url
  177. title = discussion.title
  178. body = discussion.body
  179. if not body:
  180. logging.warning(f"Skipping discussion because empty content for: {url}")
  181. continue
  182. comments = []
  183. comments_created_at = []
  184. print("Discussion comments: ", discussion.comments_url)
  185. content = "\n".join([title, body, *comments])
  186. metadata = {
  187. "url": url,
  188. "created_at": str(discussion.created_at),
  189. "comments_created_at": " ".join(comments_created_at),
  190. }
  191. data.append(
  192. {
  193. "content": clean_string(content),
  194. "meta_data": metadata,
  195. }
  196. )
  197. return data
  198. def _search_github_data(self, search_type: str, query: str):
  199. """Search github data."""
  200. if search_type == "code":
  201. data = self._github_search_code(query)
  202. elif search_type == "repo":
  203. data = self._github_search_repo(query)
  204. elif search_type == "issue":
  205. data = self._github_search_issues_and_pr(query, search_type)
  206. elif search_type == "pr":
  207. data = self._github_search_issues_and_pr(query, search_type)
  208. elif search_type == "discussion":
  209. raise ValueError("GithubLoader does not support searching discussions yet.")
  210. else:
  211. raise NotImplementedError(f"{search_type} not supported")
  212. return data
  213. @staticmethod
  214. def _get_valid_github_query(query: str):
  215. """Check if query is valid and return search types and valid GitHub query."""
  216. query_terms = shlex.split(query)
  217. # query must provide repo to load data from
  218. if len(query_terms) < 1 or "repo:" not in query:
  219. raise ValueError(
  220. "GithubLoader requires a search query with `repo:` term. Refer docs - `https://docs.embedchain.ai/data-sources/github`" # noqa: E501
  221. )
  222. github_query = []
  223. types = set()
  224. type_pattern = r"type:([a-zA-Z,]+)"
  225. for term in query_terms:
  226. term_match = re.search(type_pattern, term)
  227. if term_match:
  228. search_types = term_match.group(1).split(",")
  229. types.update(search_types)
  230. else:
  231. github_query.append(term)
  232. # query must provide search type
  233. if len(types) == 0:
  234. raise ValueError(
  235. "GithubLoader requires a search query with `type:` term. Refer docs - `https://docs.embedchain.ai/data-sources/github`" # noqa: E501
  236. )
  237. for search_type in search_types:
  238. if search_type not in VALID_SEARCH_TYPES:
  239. raise ValueError(
  240. f"Invalid search type: {search_type}. Valid types are: {', '.join(VALID_SEARCH_TYPES)}"
  241. )
  242. query = " ".join(github_query)
  243. return types, query
  244. def load_data(self, search_query: str, max_results: int = 1000):
  245. """Load data from GitHub search query."""
  246. if not self.client:
  247. raise ValueError(
  248. "GithubLoader client is not initialized, data will not be loaded. Refer docs - `https://docs.embedchain.ai/data-sources/github`" # noqa: E501
  249. )
  250. search_types, query = self._get_valid_github_query(search_query)
  251. logging.info(f"Searching github for query: {query}, with types: {', '.join(search_types)}")
  252. data = []
  253. with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
  254. futures_map = executor.map(self._search_github_data, search_types, [query] * len(search_types))
  255. for search_data in tqdm(futures_map, total=len(search_types), desc="Searching data from github"):
  256. data.extend(search_data)
  257. return {
  258. "doc_id": hashlib.sha256(query.encode()).hexdigest(),
  259. "data": data,
  260. }