Bläddra i källkod

Add folder and branch to GitHub (#1308)

Dev Khant 1 år sedan
förälder
incheckning
117824b32c
2 ändrade filer med 74 tillägg och 61 borttagningar
  1. 3 1
      docs/components/data-sources/github.mdx
  2. 71 60
      embedchain/loaders/github.py

+ 3 - 1
docs/components/data-sources/github.mdx

@@ -29,11 +29,13 @@ response = app.query("What is Embedchain?")
 ```
 The `add` function of the app will accept any valid github query with qualifiers. It only supports loading github code, repository, issues and pull-requests.
 <Note>
-You must provide qualifiers `type:` and `repo:` in the query. The `type:` qualifier can be a combination of `code`, `repo`, `pr`, `issue`. The `repo:` qualifier must be a valid github repository name.
+You must provide qualifiers `type:` and `repo:` in the query. The `type:` qualifier can be a combination of `code`, `repo`, `pr`, `issue`, `branch`, `file`. The `repo:` qualifier must be a valid github repository name.
 </Note>
 
 <Card title="Valid queries" icon="lightbulb" iconType="duotone" color="#ca8b04">
     - `repo:embedchain/embedchain type:repo` - to load the repository
+    - `repo:embedchain/embedchain type:branch name:feature_test` - to load the branch of the repository
+    - `repo:embedchain/embedchain type:file path:README.md` - to load the specific file of the repository
     - `repo:embedchain/embedchain type:issue,pr` - to load the issues and pull-requests of the repository
     - `repo:embedchain/embedchain type:issue state:closed` - to load the closed issues of the repository
 </Card>

+ 71 - 60
embedchain/loaders/github.py

@@ -1,7 +1,6 @@
 import concurrent.futures
 import hashlib
 import logging
-import os
 import re
 import shlex
 from typing import Any, Optional
@@ -14,7 +13,7 @@ from embedchain.utils.misc import clean_string
 GITHUB_URL = "https://github.com"
 GITHUB_API_URL = "https://api.github.com"
 
-VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"])
+VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion", "branch", "file"])
 
 
 class GithubLoader(BaseLoader):
@@ -66,85 +65,56 @@ class GithubLoader(BaseLoader):
             )
         return data
 
-    @staticmethod
-    def _get_github_repo_data(repo_url: str):
-        local_hash = hashlib.sha256(repo_url.encode()).hexdigest()
-        local_path = f"/tmp/{local_hash}"
+    def _get_github_repo_data(self, repo_name: str, branch_name: str = None, file_path: str = None) -> list[dict]:
+        """Get file contents from Repo"""
         data = []
 
-        def _get_repo_tree(repo_url: str, local_path: str):
-            try:
-                from git import Repo
-            except ImportError as e:
-                raise ValueError(
-                    "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[github]'`"  # noqa: E501
-                ) from e
+        repo = self.client.get_repo(repo_name)
+        repo_contents = repo.get_contents("")
 
-            if os.path.exists(local_path):
-                logging.info("Repository already exists. Fetching updates...")
-                repo = Repo(local_path)
-                logging.info("Fetch completed.")
-            else:
-                logging.info("Cloning repository...")
-                repo = Repo.clone_from(repo_url, local_path)
-                logging.info("Clone completed.")
-            return repo.head.commit.tree
-
-        def _get_repo_tree_contents(repo_path, tree, progress_bar):
-            for subtree in tree:
-                if subtree.type == "tree":
-                    _get_repo_tree_contents(repo_path, subtree, progress_bar)
+        if branch_name:
+            repo_contents = repo.get_contents("", ref=branch_name)
+        if file_path:
+            repo_contents = [repo.get_contents(file_path)]
+
+        with tqdm(desc="Loading files:", unit="item") as progress_bar:
+            while repo_contents:
+                file_content = repo_contents.pop(0)
+                if file_content.type == "dir":
+                    try:
+                        repo_contents.extend(repo.get_contents(file_content.path))
+                    except Exception:
+                        logging.warning(f"Failed to read directory: {file_content.path}")
+                        progress_bar.update(1)
+                        continue
                 else:
-                    assert subtree.type == "blob"
                     try:
-                        contents = subtree.data_stream.read().decode("utf-8")
+                        file_text = file_content.decoded_content.decode()
                     except Exception:
-                        logging.warning(f"Failed to read file: {subtree.path}")
-                        progress_bar.update(1) if progress_bar else None
+                        logging.warning(f"Failed to read file: {file_content.path}")
+                        progress_bar.update(1)
                         continue
 
-                    url = f"{repo_url.rstrip('.git')}/blob/main/{subtree.path}"
+                    file_path = file_content.path
                     data.append(
                         {
-                            "content": clean_string(contents),
+                            "content": clean_string(file_text),
                             "meta_data": {
-                                "url": url,
+                                "path": file_path,
                             },
                         }
                     )
-                if progress_bar is not None:
-                    progress_bar.update(1)
 
-        repo_tree = _get_repo_tree(repo_url, local_path)
-        tree_list = list(repo_tree.traverse())
-        with tqdm(total=len(tree_list), desc="Loading files:", unit="item") as progress_bar:
-            _get_repo_tree_contents(local_path, repo_tree, progress_bar)
+                progress_bar.update(1)
 
         return data
 
     def _github_search_repo(self, query: str) -> list[dict]:
         """Search GitHub repo."""
-        data = []
+
         logging.info(f"Searching github repos with query: {query}")
-        results = self.client.search_repositories(query)
-        # Add repo urls and descriptions
-        urls = list(map(lambda x: x.html_url, results))
-        descriptions = list(map(lambda x: x.description, results))
-        data.append(
-            {
-                "content": clean_string(desc),
-                "meta_data": {
-                    "url": url,
-                },
-            }
-            for url, desc in zip(urls, descriptions)
-        )
-
-        # Add repo contents
-        for result in results:
-            clone_url = result.clone_url
-            logging.info(f"Cloning repository: {clone_url}")
-            data = self._get_github_repo_data(clone_url)
+        updated_query = query.split(":")[-1]
+        data = self._get_github_repo_data(updated_query)
         return data
 
     def _github_search_issues_and_pr(self, query: str, type: str) -> list[dict]:
@@ -222,6 +192,43 @@ class GithubLoader(BaseLoader):
                     )
         return data
 
+    def _get_github_repo_branch(self, query: str, type: str) -> list[dict]:
+        """Get file contents for specific branch"""
+
+        logging.info(f"Searching github repo for query: {query} is:{type}")
+        pattern = r"repo:(\S+) name:(\S+)"
+        match = re.search(pattern, query)
+
+        if match:
+            repo_name = match.group(1)
+            branch_name = match.group(2)
+        else:
+            raise ValueError(
+                f"Repository name and Branch name not found, instead found this \
+                    Repo: {repo_name}, Branch: {branch_name}"
+            )
+
+        data = self._get_github_repo_data(repo_name=repo_name, branch_name=branch_name)
+        return data
+
+    def _get_github_repo_file(self, query: str, type: str) -> list[dict]:
+        """Get specific file content"""
+
+        logging.info(f"Searching github repo for query: {query} is:{type}")
+        pattern = r"repo:(\S+) path:(\S+)"
+        match = re.search(pattern, query)
+
+        if match:
+            repo_name = match.group(1)
+            file_path = match.group(2)
+        else:
+            raise ValueError(
+                f"Repository name and File name not found, instead found this Repo: {repo_name}, File: {file_path}"
+            )
+
+        data = self._get_github_repo_data(repo_name=repo_name, file_path=file_path)
+        return data
+
     def _search_github_data(self, search_type: str, query: str):
         """Search github data."""
         if search_type == "code":
@@ -232,6 +239,10 @@ class GithubLoader(BaseLoader):
             data = self._github_search_issues_and_pr(query, search_type)
         elif search_type == "pr":
             data = self._github_search_issues_and_pr(query, search_type)
+        elif search_type == "branch":
+            data = self._get_github_repo_branch(query, search_type)
+        elif search_type == "file":
+            data = self._get_github_repo_file(query, search_type)
         elif search_type == "discussion":
             raise ValueError("GithubLoader does not support searching discussions yet.")
         else: