瀏覽代碼

[Feature] Add Slack Loader (#932)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 年之前
父節點
當前提交
539286aafd

+ 1 - 0
docs/data-sources/overview.mdx

@@ -21,6 +21,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
   <Card title="🎥📺 youtube video" href="/data-sources/youtube-video"></Card>
   <Card title="📬 Gmail" href="/data-sources/gmail"></Card>
   <Card title="🐘 Postgres" href="/data-sources/postgres"></Card>
+  <Card title="🤖 Slack" href="/data-sources/slack"></Card>
 </CardGroup>
 
 <br/ >

+ 54 - 0
docs/data-sources/slack.mdx

@@ -0,0 +1,54 @@
+---
+title: '🤖 Slack'
+---
+
+## Pre-requisite
+- Download required packages by running `pip install --upgrade "embedchain[slack]"`.
+- Configure your slack bot token as environment variable `SLACK_USER_TOKEN`.
+    - Find your user token on your [Slack Account](https://api.slack.com/authentication/token-types)
+    - Make sure your slack user token includes [search](https://api.slack.com/scopes/search:read) scope.
+
+## Example
+1. Setup the Slack loader by configuring the Slack Webclient.
+```Python
+from embedchain.loaders.slack import SlackLoader
+
+os.environ["SLACK_USER_TOKEN"] = "xoxp-*"
+
+loader = SlackLoader()
+
+"""
+config = {
+    'base_url': slack_app_url,
+    'headers': web_headers,
+    'team_id': slack_team_id,
+}
+
+loader = SlackLoader(config)
+"""
+```
+
+NOTE: you can also pass the `config` with `base_url`, `headers`, `team_id` to setup your SlackLoader.
+
+2. Once you setup the loader, you can create an app and load data using the above slack loader
+```Python
+import os
+from embedchain.pipeline import Pipeline as App
+
+app = App()
+
+app.add("in:random", data_type="slack", loader=loader)
+question = "Which bots are available in the slack workspace's random channel?"
+# Answer: The available bot in the slack workspace's random channel is the Embedchain bot.
+```
+
+3. We automatically create a chunker to chunk your slack data, however if you wish to provide your own chunker class. Here is how you can do that:
+```Python
+from embedchain.chunkers.slack import SlackChunker
+from embedchain.config.add_config import ChunkerConfig
+
+slack_chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+slack_chunker = SlackChunker(config=slack_chunker_config)
+
+app.add(slack_chunker, data_type="slack", loader=loader, chunker=slack_chunker)
+```

+ 22 - 0
embedchain/chunkers/slack.py

@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class SlackChunker(BaseChunker):
+    """Chunker for postgres."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)

+ 2 - 0
embedchain/data_formatter/data_formatter.py

@@ -68,6 +68,7 @@ class DataFormatter(JSONSerializable):
         custom_loaders = set(
             [
                 DataType.POSTGRES,
+                DataType.SLACK,
             ]
         )
 
@@ -106,6 +107,7 @@ class DataFormatter(JSONSerializable):
             DataType.GMAIL: "embedchain.chunkers.gmail.GmailChunker",
             DataType.NOTION: "embedchain.chunkers.notion.NotionChunker",
             DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker",
+            DataType.SLACK: "embedchain.chunkers.slack.SlackChunker",
         }
 
         if data_type in chunker_classes:

+ 1 - 3
embedchain/loaders/postgres.py

@@ -40,9 +40,7 @@ class PostgresLoader(BaseLoader):
     def _check_query(self, query):
         if not isinstance(query, str):
             raise ValueError(
-                f"Invalid postgres query: {query}",
-                "Provide the valid source to add from postgres, \
-                    make sure you are following `https://docs.embedchain.ai/data-sources/postgres`",
+                f"Invalid postgres query: {query}. Provide the valid source to add from postgres, make sure you are following `https://docs.embedchain.ai/data-sources/postgres`",  # noqa:E501
             )
 
     def load_data(self, query):

+ 108 - 0
embedchain/loaders/slack.py

@@ -0,0 +1,108 @@
+import hashlib
+import logging
+import os
+import ssl
+from typing import Any, Dict, Optional
+
+import certifi
+
+from embedchain.loaders.base_loader import BaseLoader
+from embedchain.utils import clean_string
+
+SLACK_API_BASE_URL = "https://www.slack.com/api/"
+
+
+class SlackLoader(BaseLoader):
+    def __init__(self, config: Optional[Dict[str, Any]] = None):
+        super().__init__()
+
+        if config is not None:
+            self.config = config
+        else:
+            self.config = {"base_url": SLACK_API_BASE_URL}
+
+        self.client = None
+        self._setup_loader(self.config)
+
+    def _setup_loader(self, config: Dict[str, Any]):
+        try:
+            from slack_sdk import WebClient
+        except ImportError as e:
+            raise ImportError(
+                "Slack loader requires extra dependencies. \
+                Install with `pip install --upgrade embedchain[slack]`"
+            ) from e
+
+        if os.getenv("SLACK_USER_TOKEN") is None:
+            raise ValueError(
+                "SLACK_USER_TOKEN environment variables not provided. Check `https://docs.embedchain.ai/data-sources/slack` to learn more."  # noqa:E501
+            )
+
+        logging.info(f"Creating Slack Loader with config: {config}")
+        # get slack client config params
+        slack_bot_token = os.getenv("SLACK_USER_TOKEN")
+        ssl_cert = ssl.create_default_context(cafile=certifi.where())
+        base_url = config.get("base_url", SLACK_API_BASE_URL)
+        headers = config.get("headers")
+        # for Org-Wide App
+        team_id = config.get("team_id")
+
+        self.client = WebClient(
+            token=slack_bot_token,
+            base_url=base_url,
+            ssl=ssl_cert,
+            headers=headers,
+            team_id=team_id,
+        )
+        logging.info("Slack Loader setup successful!")
+
+    def _check_query(self, query):
+        if not isinstance(query, str):
+            raise ValueError(
+                f"Invalid query passed to Slack loader, found: {query}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more."  # noqa:E501
+            )
+
+    def load_data(self, query):
+        self._check_query(query)
+        try:
+            data = []
+            data_content = []
+
+            logging.info(f"Searching slack conversations for query: {query}")
+            results = self.client.search_messages(
+                query=query,
+                sort="timestamp",
+                sort_dir="desc",
+                count=1000,
+            )
+
+            messages = results.get("messages")
+            num_message = results.get("total")
+            logging.info(f"Found {num_message} messages for query: {query}")
+
+            matches = messages.get("matches", [])
+            for message in matches:
+                url = message.get("permalink")
+                text = message.get("text")
+                content = clean_string(text)
+
+                message_meta_data_keys = ["channel", "iid", "team", "ts", "type", "user", "username"]
+                meta_data = message.fromkeys(message_meta_data_keys, "")
+                meta_data.update({"url": url})
+                data.append(
+                    {
+                        "content": content,
+                        "meta_data": meta_data,
+                    }
+                )
+                data_content.append(content)
+            doc_id = hashlib.md5((query + ", ".join(data_content)).encode()).hexdigest()
+            return {
+                "doc_id": doc_id,
+                "data": data,
+            }
+        except Exception as e:
+            logging.warning(f"Error in loading slack data: {e}")
+            raise ValueError(
+                f"Error in loading slack data: {e}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more."  # noqa:E501
+            ) from e

+ 2 - 0
embedchain/models/data_type.py

@@ -30,6 +30,7 @@ class IndirectDataType(Enum):
     OPENAPI = "openapi"
     GMAIL = "gmail"
     POSTGRES = "postgres"
+    SLACK = "slack"
 
 
 class SpecialDataType(Enum):
@@ -59,3 +60,4 @@ class DataType(Enum):
     OPENAPI = IndirectDataType.OPENAPI.value
     GMAIL = IndirectDataType.GMAIL.value
     POSTGRES = IndirectDataType.POSTGRES.value
+    SLACK = IndirectDataType.SLACK.value

+ 9 - 9
examples/rest-api/main.py

@@ -83,7 +83,7 @@ async def create_app_using_default_config(app_id: str, config: UploadFile = None
 
         return DefaultResponse(response=f"App created successfully. App ID: {app_id}")
     except Exception as e:
-        logging.warn(str(e))
+        logging.warning(str(e))
         raise HTTPException(detail=f"Error creating app: {str(e)}", status_code=400)
 
 
@@ -113,13 +113,13 @@ async def get_datasources_associated_with_app_id(app_id: str, db: Session = Depe
         response = app.get_data_sources()
         return {"results": response}
     except ValueError as ve:
-        logging.warn(str(ve))
+        logging.warning(str(ve))
         raise HTTPException(
             detail=generate_error_message_for_api_keys(ve),
             status_code=400,
         )
     except Exception as e:
-        logging.warn(str(e))
+        logging.warning(str(e))
         raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)
 
 
@@ -152,13 +152,13 @@ async def add_datasource_to_an_app(body: SourceApp, app_id: str, db: Session = D
         response = app.add(source=body.source, data_type=body.data_type)
         return DefaultResponse(response=response)
     except ValueError as ve:
-        logging.warn(str(ve))
+        logging.warning(str(ve))
         raise HTTPException(
             detail=generate_error_message_for_api_keys(ve),
             status_code=400,
         )
     except Exception as e:
-        logging.warn(str(e))
+        logging.warning(str(e))
         raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)
 
 
@@ -190,13 +190,13 @@ async def query_an_app(body: QueryApp, app_id: str, db: Session = Depends(get_db
         response = app.query(body.query)
         return DefaultResponse(response=response)
     except ValueError as ve:
-        logging.warn(str(ve))
+        logging.warning(str(ve))
         raise HTTPException(
             detail=generate_error_message_for_api_keys(ve),
             status_code=400,
         )
     except Exception as e:
-        logging.warn(str(e))
+        logging.warning(str(e))
         raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)
 
 
@@ -273,13 +273,13 @@ async def deploy_app(body: DeployAppRequest, app_id: str, db: Session = Depends(
         app.deploy()
         return DefaultResponse(response="App deployed successfully.")
     except ValueError as ve:
-        logging.warn(str(ve))
+        logging.warning(str(ve))
         raise HTTPException(
             detail=generate_error_message_for_api_keys(ve),
             status_code=400,
         )
     except Exception as e:
-        logging.warn(str(e))
+        logging.warning(str(e))
         raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)
 
 

+ 2 - 0
tests/chunkers/test_chunkers.py

@@ -9,6 +9,7 @@ from embedchain.chunkers.pdf_file import PdfFileChunker
 from embedchain.chunkers.postgres import PostgresChunker
 from embedchain.chunkers.qna_pair import QnaPairChunker
 from embedchain.chunkers.sitemap import SitemapChunker
+from embedchain.chunkers.slack import SlackChunker
 from embedchain.chunkers.table import TableChunker
 from embedchain.chunkers.text import TextChunker
 from embedchain.chunkers.web_page import WebPageChunker
@@ -35,6 +36,7 @@ chunker_common_config = {
     OpenAPIChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     GmailChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     PostgresChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
+    SlackChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
 }
 
 

+ 47 - 0
tests/loaders/test_slack.py

@@ -0,0 +1,47 @@
+import pytest
+
+from embedchain.loaders.slack import SlackLoader
+
+
+@pytest.fixture
+def slack_loader(mocker, monkeypatch):
+    # Mocking necessary dependencies
+    mocker.patch("slack_sdk.WebClient")
+    mocker.patch("ssl.create_default_context")
+    mocker.patch("certifi.where")
+
+    monkeypatch.setenv("SLACK_USER_TOKEN", "slack_user_token")
+
+    return SlackLoader()
+
+
+def test_slack_loader_initialization(slack_loader):
+    assert slack_loader.client is not None
+    assert slack_loader.config == {"base_url": "https://www.slack.com/api/"}
+
+
+def test_slack_loader_setup_loader(slack_loader):
+    slack_loader._setup_loader({"base_url": "https://custom.slack.api/"})
+
+    assert slack_loader.client is not None
+
+
+def test_slack_loader_check_query(slack_loader):
+    valid_json_query = "test_query"
+    invalid_query = 123
+
+    slack_loader._check_query(valid_json_query)
+
+    with pytest.raises(ValueError):
+        slack_loader._check_query(invalid_query)
+
+
+def test_slack_loader_load_data(slack_loader, mocker):
+    valid_json_query = "in:random"
+
+    mocker.patch.object(slack_loader.client, "search_messages", return_value={"messages": {}})
+
+    result = slack_loader.load_data(valid_json_query)
+
+    assert "doc_id" in result
+    assert "data" in result