github.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. import concurrent.futures
  2. import hashlib
  3. import logging
  4. import os
  5. import re
  6. import shlex
  7. from typing import Any, Dict, Optional
  8. from tqdm import tqdm
  9. from embedchain.loaders.base_loader import BaseLoader
  10. from embedchain.utils import clean_string
  11. GITHUB_URL = "https://github.com"
  12. GITHUB_API_URL = "https://api.github.com"
  13. VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"])
  14. class GithubLoader(BaseLoader):
  15. """Load data from github search query."""
  16. def __init__(self, config: Optional[Dict[str, Any]] = None):
  17. super().__init__()
  18. if not config:
  19. raise ValueError(
  20. "GithubLoader requires a personal access token to use github api. Check - `https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic`" # noqa: E501
  21. )
  22. try:
  23. from github import Github
  24. except ImportError as e:
  25. raise ValueError(
  26. "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[github]'`"
  27. ) from e
  28. self.config = config
  29. token = config.get("token")
  30. if not token:
  31. raise ValueError(
  32. "GithubLoader requires a personal access token to use github api. Check - `https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic`" # noqa: E501
  33. )
  34. try:
  35. self.client = Github(token)
  36. except Exception as e:
  37. logging.error(f"GithubLoader failed to initialize client: {e}")
  38. self.client = None
  39. def _github_search_code(self, query: str):
  40. """Search github code."""
  41. data = []
  42. results = self.client.search_code(query)
  43. for result in tqdm(results, total=results.totalCount, desc="Loading code files from github"):
  44. url = result.html_url
  45. logging.info(f"Added data from url: {url}")
  46. content = result.decoded_content.decode("utf-8")
  47. metadata = {
  48. "url": url,
  49. }
  50. data.append(
  51. {
  52. "content": clean_string(content),
  53. "meta_data": metadata,
  54. }
  55. )
  56. return data
  57. def _get_github_repo_data(self, repo_url: str):
  58. local_hash = hashlib.sha256(repo_url.encode()).hexdigest()
  59. local_path = f"/tmp/{local_hash}"
  60. data = []
  61. def _get_repo_tree(repo_url: str, local_path: str):
  62. try:
  63. from git import Repo
  64. except ImportError as e:
  65. raise ValueError(
  66. "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[github]'`" # noqa: E501
  67. ) from e
  68. if os.path.exists(local_path):
  69. logging.info("Repository already exists. Fetching updates...")
  70. repo = Repo(local_path)
  71. logging.info("Fetch completed.")
  72. else:
  73. logging.info("Cloning repository...")
  74. repo = Repo.clone_from(repo_url, local_path)
  75. logging.info("Clone completed.")
  76. return repo.head.commit.tree
  77. def _get_repo_tree_contents(repo_path, tree, progress_bar):
  78. for subtree in tree:
  79. if subtree.type == "tree":
  80. _get_repo_tree_contents(repo_path, subtree, progress_bar)
  81. else:
  82. assert subtree.type == "blob"
  83. try:
  84. contents = subtree.data_stream.read().decode("utf-8")
  85. except Exception:
  86. logging.warning(f"Failed to read file: {subtree.path}")
  87. progress_bar.update(1) if progress_bar else None
  88. continue
  89. url = f"{repo_url.rstrip('.git')}/blob/main/{subtree.path}"
  90. data.append(
  91. {
  92. "content": clean_string(contents),
  93. "meta_data": {
  94. "url": url,
  95. },
  96. }
  97. )
  98. if progress_bar is not None:
  99. progress_bar.update(1)
  100. repo_tree = _get_repo_tree(repo_url, local_path)
  101. tree_list = list(repo_tree.traverse())
  102. with tqdm(total=len(tree_list), desc="Loading files:", unit="item") as progress_bar:
  103. _get_repo_tree_contents(local_path, repo_tree, progress_bar)
  104. return data
  105. def _github_search_repo(self, query: str):
  106. """Search github repo."""
  107. data = []
  108. logging.info(f"Searching github repos with query: {query}")
  109. results = self.client.search_repositories(query)
  110. # Add repo urls and descriptions
  111. urls = list(map(lambda x: x.html_url, results))
  112. discriptions = list(map(lambda x: x.description, results))
  113. data.append(
  114. {
  115. "content": clean_string(desc),
  116. "meta_data": {
  117. "url": url,
  118. },
  119. }
  120. for url, desc in zip(urls, discriptions)
  121. )
  122. # Add repo contents
  123. for result in results:
  124. clone_url = result.clone_url
  125. logging.info(f"Cloning repository: {clone_url}")
  126. data = self._get_github_repo_data(clone_url)
  127. return data
  128. def _github_search_issues_and_pr(self, query: str, type: str):
  129. """Search github issues and PRs."""
  130. data = []
  131. query = f"{query} is:{type}"
  132. logging.info(f"Searching github for query: {query}")
  133. results = self.client.search_issues(query)
  134. logging.info(f"Total results: {results.totalCount}")
  135. for result in tqdm(results, total=results.totalCount, desc=f"Loading {type} from github"):
  136. url = result.html_url
  137. title = result.title
  138. body = result.body
  139. if not body:
  140. logging.warn(f"Skipping issue because empty content for: {url}")
  141. continue
  142. labels = " ".join([label.name for label in result.labels])
  143. issue_comments = result.get_comments()
  144. comments = []
  145. comments_created_at = []
  146. for comment in issue_comments:
  147. comments_created_at.append(str(comment.created_at))
  148. comments.append(f"{comment.user.name}:{comment.body}")
  149. content = "\n".join([title, labels, body, *comments])
  150. metadata = {
  151. "url": url,
  152. "created_at": str(result.created_at),
  153. "comments_created_at": " ".join(comments_created_at),
  154. }
  155. data.append(
  156. {
  157. "content": clean_string(content),
  158. "meta_data": metadata,
  159. }
  160. )
  161. return data
  162. # need to test more for discussion
  163. def _github_search_discussions(self, query: str):
  164. """Search github discussions."""
  165. data = []
  166. query = f"{query} is:discussion"
  167. logging.info(f"Searching github repo for query: {query}")
  168. repos_results = self.client.search_repositories(query)
  169. logging.info(f"Total repos found: {repos_results.totalCount}")
  170. for repo_result in tqdm(repos_results, total=repos_results.totalCount, desc="Loading discussions from github"):
  171. teams = repo_result.get_teams()
  172. for team in teams:
  173. team_discussions = team.get_discussions()
  174. for discussion in team_discussions:
  175. url = discussion.html_url
  176. title = discussion.title
  177. body = discussion.body
  178. if not body:
  179. logging.warn(f"Skipping discussion because empty content for: {url}")
  180. continue
  181. comments = []
  182. comments_created_at = []
  183. print("Discussion comments: ", discussion.comments_url)
  184. content = "\n".join([title, body, *comments])
  185. metadata = {
  186. "url": url,
  187. "created_at": str(discussion.created_at),
  188. "comments_created_at": " ".join(comments_created_at),
  189. }
  190. data.append(
  191. {
  192. "content": clean_string(content),
  193. "meta_data": metadata,
  194. }
  195. )
  196. return data
  197. def _search_github_data(self, search_type: str, query: str):
  198. """Search github data."""
  199. if search_type == "code":
  200. data = self._github_search_code(query)
  201. elif search_type == "repo":
  202. data = self._github_search_repo(query)
  203. elif search_type == "issue":
  204. data = self._github_search_issues_and_pr(query, search_type)
  205. elif search_type == "pr":
  206. data = self._github_search_issues_and_pr(query, search_type)
  207. elif search_type == "discussion":
  208. raise ValueError("GithubLoader does not support searching discussions yet.")
  209. return data
  210. def _get_valid_github_query(self, query: str):
  211. """Check if query is valid and return search types and valid github query."""
  212. query_terms = shlex.split(query)
  213. # query must provide repo to load data from
  214. if len(query_terms) < 1 or "repo:" not in query:
  215. raise ValueError(
  216. "GithubLoader requires a search query with `repo:` term. Refer docs - `https://docs.embedchain.ai/data-sources/github`" # noqa: E501
  217. )
  218. github_query = []
  219. types = set()
  220. type_pattern = r"type:([a-zA-Z,]+)"
  221. for term in query_terms:
  222. term_match = re.search(type_pattern, term)
  223. if term_match:
  224. search_types = term_match.group(1).split(",")
  225. types.update(search_types)
  226. else:
  227. github_query.append(term)
  228. # query must provide search type
  229. if len(types) == 0:
  230. raise ValueError(
  231. "GithubLoader requires a search query with `type:` term. Refer docs - `https://docs.embedchain.ai/data-sources/github`" # noqa: E501
  232. )
  233. for search_type in search_types:
  234. if search_type not in VALID_SEARCH_TYPES:
  235. raise ValueError(
  236. f"Invalid search type: {search_type}. Valid types are: {', '.join(VALID_SEARCH_TYPES)}"
  237. )
  238. query = " ".join(github_query)
  239. return types, query
  240. def load_data(self, search_query: str, max_results: int = 1000):
  241. """Load data from github search query."""
  242. if not self.client:
  243. raise ValueError(
  244. "GithubLoader client is not initialized, data will not be loaded. Refer docs - `https://docs.embedchain.ai/data-sources/github`" # noqa: E501
  245. )
  246. search_types, query = self._get_valid_github_query(search_query)
  247. logging.info(f"Searching github for query: {query}, with types: {', '.join(search_types)}")
  248. data = []
  249. with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
  250. futures_map = executor.map(self._search_github_data, search_types, [query] * len(search_types))
  251. for search_data in tqdm(futures_map, total=len(search_types), desc="Searching data from github"):
  252. data.extend(search_data)
  253. return {
  254. "doc_id": hashlib.sha256(query.encode()).hexdigest(),
  255. "data": data,
  256. }