Преглед на файлове

[Feature] Google Drive Folder support as a data source (#1106)

Joe Sleiman преди 1 година
родител
ревизия
b4ec14382b

+ 28 - 0
docs/components/data-sources/google-drive.mdx

@@ -0,0 +1,28 @@
+---
+title: 'Google Drive'
+---
+
+To use GoogleDriveLoader you must install the extra dependencies with `pip install --upgrade embedchain[googledrive]`.
+
+The data_type must be `google_drive`. Otherwise, it will be considered a regular web page.
+
+Google Drive requires the setup of credentials. This can be done by following the steps below:
+
+1. Go to the [Google Cloud Console](https://console.cloud.google.com/apis/credentials).
+2. Create a project if you don't have one already.
+3. Enable the [Google Drive API](https://console.cloud.google.com/flows/enableapi?apiid=drive.googleapis.com)
+4. [Authorize credentials for desktop app](https://developers.google.com/drive/api/quickstart/python#authorize_credentials_for_a_desktop_application)
+5. When done, you will be able to download the credentials in `json` format. Rename the downloaded file to `credentials.json` and save it in `~/.credentials/credentials.json`
+6. Set the environment variable `GOOGLE_APPLICATION_CREDENTIALS=~/.credentials/credentials.json`
+
+The first time you use the loader, you will be prompted to enter your Google account credentials.
+
+
+```python
+from embedchain import Pipeline as App
+
+app = App()
+
+url = "https://drive.google.com/drive/u/0/folders/xxx-xxx"
+app.add(url, data_type="google_drive")
+```

+ 22 - 0
embedchain/chunkers/google_drive.py

@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class GoogleDriveChunker(BaseChunker):
+    """Chunker for google drive folder."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)

+ 2 - 0
embedchain/data_formatter/data_formatter.py

@@ -74,6 +74,7 @@ class DataFormatter(JSONSerializable):
             DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
             DataType.RSSFEED: "embedchain.loaders.rss_feed.RSSFeedLoader",
             DataType.BEEHIIV: "embedchain.loaders.beehiiv.BeehiivLoader",
+            DataType.GOOGLE_DRIVE: "embedchain.loaders.google_drive.GoogleDriveLoader",
             DataType.DIRECTORY: "embedchain.loaders.directory_loader.DirectoryLoader",
             DataType.SLACK: "embedchain.loaders.slack.SlackLoader",
             DataType.DROPBOX: "embedchain.loaders.dropbox.DropboxLoader",
@@ -120,6 +121,7 @@ class DataFormatter(JSONSerializable):
             DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
             DataType.RSSFEED: "embedchain.chunkers.rss_feed.RSSFeedChunker",
             DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker",
+            DataType.GOOGLE_DRIVE: "embedchain.chunkers.google_drive.GoogleDriveChunker",
             DataType.DIRECTORY: "embedchain.chunkers.common_chunker.CommonChunker",
             DataType.SLACK: "embedchain.chunkers.common_chunker.CommonChunker",
             DataType.DROPBOX: "embedchain.chunkers.common_chunker.CommonChunker",

+ 54 - 0
embedchain/loaders/google_drive.py

@@ -0,0 +1,54 @@
+import hashlib
+import re
+
+try:
+    from googleapiclient.errors import HttpError
+except ImportError:
+    raise ImportError(
+        "Google Drive requires extra dependencies. Install with `pip install embedchain[googledrive]`"
+    ) from None
+
+from langchain.document_loaders import GoogleDriveLoader as Loader, UnstructuredFileIOLoader
+
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.loaders.base_loader import BaseLoader
+
+
+@register_deserializable
+class GoogleDriveLoader(BaseLoader):
+    @staticmethod
+    def _get_drive_id_from_url(url: str):
+        regex = r"^https:\/\/drive\.google\.com\/drive\/(?:u\/\d+\/)folders\/([a-zA-Z0-9_-]+)$"
+        if re.match(regex, url):
+            return url.split("/")[-1]
+        raise ValueError(
+            f"The url provided {url} does not match a google drive folder url. Example drive url: "
+            f"https://drive.google.com/drive/u/0/folders/xxxx"
+        )
+
+    def load_data(self, url: str):
+        """Load data from a Google drive folder."""
+        folder_id: str = self._get_drive_id_from_url(url)
+
+        try:
+            loader = Loader(
+                folder_id=folder_id,
+                recursive=True,
+                file_loader_cls=UnstructuredFileIOLoader,
+            )
+
+            data = []
+            all_content = []
+
+            docs = loader.load()
+            for doc in docs:
+                all_content.append(doc.page_content)
+                # renames source to url for later use.
+                doc.metadata["url"] = doc.metadata.pop("source")
+                data.append({"content": doc.page_content, "meta_data": doc.metadata})
+
+            doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
+            return {"doc_id": doc_id, "data": data}
+
+        except HttpError:
+            raise FileNotFoundError("Unable to locate folder or files, check provided drive URL and try again")

+ 2 - 0
embedchain/models/data_type.py

@@ -35,6 +35,7 @@ class IndirectDataType(Enum):
     CUSTOM = "custom"
     RSSFEED = "rss_feed"
     BEEHIIV = "beehiiv"
+    GOOGLE_DRIVE = "google_drive"
     DIRECTORY = "directory"
     SLACK = "slack"
     DROPBOX = "dropbox"
@@ -73,6 +74,7 @@ class DataType(Enum):
     CUSTOM = IndirectDataType.CUSTOM.value
     RSSFEED = IndirectDataType.RSSFEED.value
     BEEHIIV = IndirectDataType.BEEHIIV.value
+    GOOGLE_DRIVE = IndirectDataType.GOOGLE_DRIVE.value
     DIRECTORY = IndirectDataType.DIRECTORY.value
     SLACK = IndirectDataType.SLACK.value
     DROPBOX = IndirectDataType.DROPBOX.value

+ 10 - 2
embedchain/utils.py

@@ -183,6 +183,11 @@ def detect_datatype(source: Any) -> DataType:
         # currently the following two fields are required in openapi spec yaml config
         return "openapi" in yaml_content and "info" in yaml_content
 
+    def is_google_drive_folder(url):
+        # checks if url is a Google Drive folder url against a regex
+        regex = r"^drive\.google\.com\/drive\/(?:u\/\d+\/)folders\/([a-zA-Z0-9_-]+)$"
+        return re.match(regex, url)
+
     try:
         if not isinstance(source, str):
             raise ValueError("Source is not a string and thus cannot be a URL.")
@@ -196,8 +201,7 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)
 
     if url:
-        from langchain.document_loaders.youtube import \
-            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
 
         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
@@ -266,6 +270,10 @@ def detect_datatype(source: Any) -> DataType:
             logging.debug(f"Source of `{formatted_source}` detected as `github`.")
             return DataType.GITHUB
 
+        if is_google_drive_folder(url.netloc + url.path):
+            logging.debug(f"Source of `{formatted_source}` detected as `google drive folder`.")
+            return DataType.GOOGLE_DRIVE_FOLDER
+
         # If none of the above conditions are met, it's a general web page
         logging.debug(f"Source of `{formatted_source}` detected as `web_page`.")
         return DataType.WEB_PAGE

+ 27 - 0
poetry.lock

@@ -4677,6 +4677,32 @@ files = [
     {file = "protobuf-4.21.12.tar.gz", hash = "sha256:7cd532c4566d0e6feafecc1059d04c7915aec8e182d1cf7adee8b24ef1e2e6ab"},
 ]
 
+[[package]]
+name = "psutil"
+version = "5.9.5"
+description = "Cross-platform lib for process and system monitoring in Python."
+optional = true
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+    {file = "psutil-5.9.5-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f"},
+    {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5"},
+    {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4"},
+    {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:ea8518d152174e1249c4f2a1c89e3e6065941df2fa13a1ab45327716a23c2b48"},
+    {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:acf2aef9391710afded549ff602b5887d7a2349831ae4c26be7c807c0a39fac4"},
+    {file = "psutil-5.9.5-cp27-none-win32.whl", hash = "sha256:5b9b8cb93f507e8dbaf22af6a2fd0ccbe8244bf30b1baad6b3954e935157ae3f"},
+    {file = "psutil-5.9.5-cp27-none-win_amd64.whl", hash = "sha256:8c5f7c5a052d1d567db4ddd231a9d27a74e8e4a9c3f44b1032762bd7b9fdcd42"},
+    {file = "psutil-5.9.5-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217"},
+    {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da"},
+    {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4"},
+    {file = "psutil-5.9.5-cp36-abi3-win32.whl", hash = "sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d"},
+    {file = "psutil-5.9.5-cp36-abi3-win_amd64.whl", hash = "sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9"},
+    {file = "psutil-5.9.5-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30"},
+    {file = "psutil-5.9.5.tar.gz", hash = "sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c"},
+]
+
+[package.extras]
+test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
+
 [[package]]
 name = "psycopg"
 version = "3.1.12"
@@ -8095,6 +8121,7 @@ elasticsearch = ["elasticsearch"]
 github = ["PyGithub", "gitpython"]
 gmail = ["google-api-core", "google-api-python-client", "google-auth", "google-auth-httplib2", "google-auth-oauthlib", "requests"]
 google = ["google-generativeai"]
+googledrive = ["google-api-python-client", "google-auth-httplib2", "google-auth-oauthlib"]
 huggingface-hub = ["huggingface_hub"]
 llama2 = ["replicate"]
 milvus = ["pymilvus"]

+ 1 - 0
pyproject.toml

@@ -197,6 +197,7 @@ gmail = [
     "google-auth-httplib2",
     "google-api-core",
 ]
+googledrive = ["google-api-python-client", "google-auth-oauthlib", "google-auth-httplib2"]
 postgres = ["psycopg", "psycopg-binary", "psycopg-pool"]
 mysql = ["mysql-connector-python"]
 github = ["PyGithub", "gitpython"]

+ 2 - 0
tests/chunkers/test_chunkers.py

@@ -3,6 +3,7 @@ from embedchain.chunkers.discourse import DiscourseChunker
 from embedchain.chunkers.docs_site import DocsSiteChunker
 from embedchain.chunkers.docx_file import DocxFileChunker
 from embedchain.chunkers.gmail import GmailChunker
+from embedchain.chunkers.google_drive import GoogleDriveChunker
 from embedchain.chunkers.json import JSONChunker
 from embedchain.chunkers.mdx import MdxChunker
 from embedchain.chunkers.notion import NotionChunker
@@ -41,6 +42,7 @@ chunker_common_config = {
     SlackChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     DiscourseChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     CommonChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
+    GoogleDriveChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
 }
 
 

+ 37 - 0
tests/loaders/test_google_drive.py

@@ -0,0 +1,37 @@
+import pytest
+
+from embedchain.loaders.google_drive import GoogleDriveLoader
+
+
+@pytest.fixture
+def google_drive_folder_loader():
+    return GoogleDriveLoader()
+
+
+def test_load_data_invalid_drive_url(google_drive_folder_loader):
+    mock_invalid_drive_url = "https://example.com"
+    with pytest.raises(
+        ValueError,
+        match="The url provided https://example.com does not match a google drive folder url. Example "
+        "drive url: https://drive.google.com/drive/u/0/folders/xxxx",
+    ):
+        google_drive_folder_loader.load_data(mock_invalid_drive_url)
+
+
+@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.")
+def test_load_data_incorrect_drive_url(google_drive_folder_loader):
+    mock_invalid_drive_url = "https://drive.google.com/drive/u/0/folders/xxxx"
+    with pytest.raises(
+        FileNotFoundError, match="Unable to locate folder or files, check provided drive URL and try again"
+    ):
+        google_drive_folder_loader.load_data(mock_invalid_drive_url)
+
+
+@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.")
+def test_load_data(google_drive_folder_loader):
+    mock_valid_url = "YOUR_VALID_URL"
+    result = google_drive_folder_loader.load_data(mock_valid_url)
+    assert "doc_id" in result
+    assert "data" in result
+    assert "content" in result["data"][0]
+    assert "meta_data" in result["data"][0]