Quellcode durchsuchen

[Feature] Discourse Loader (#948)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel vor 1 Jahr
Ursprung
Commit
95c0d47236

+ 44 - 0
docs/data-sources/discourse.mdx

@@ -0,0 +1,44 @@
+---
+title: '🗨️ Discourse'
+---
+
+You can now easily load data from your community built with [Discourse](https://discourse.org/).
+
+## Example
+
+1. Setup the Discourse Loader with your community url.
+```Python
+from embedchain.loaders.discourse import DiscourseLoader
+
+dicourse_loader = DiscourseLoader(config={"domain": "https://community.openai.com"})
+```
+
+2. Once you setup the loader, you can create an app and load data using the above discourse loader
+```Python
+import os
+from embedchain.pipeline import Pipeline as App
+
+os.environ["OPENAI_API_KEY"] = "sk-xxx"
+
+app = App()
+
+app.add("openai", data_type="discourse", loader=dicourse_loader)
+
+question = "Where can I find the OpenAI API status page?"
+app.query(question)
+# Answer: You can find the OpenAI API status page at https:/status.openai.com/.
+```
+
+NOTE: The `add` function of the app will accept any executable search query to load data. Refer [Discourse API Docs](https://docs.discourse.org/#tag/Search) to learn more about search queries.
+
+3. We automatically create a chunker to chunk your discourse data, however if you wish to provide your own chunker class. Here is how you can do that:
+```Python
+
+from embedchain.chunkers.discourse import DiscourseChunker
+from embedchain.config.add_config import ChunkerConfig
+
+discourse_chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+discourse_chunker = DiscourseChunker(config=discourse_chunker_config)
+
+app.add("openai", data_type='discourse', loader=dicourse_loader, chunker=discourse_chunker)
+```

+ 2 - 1
docs/data-sources/overview.mdx

@@ -18,11 +18,12 @@ Embedchain comes with built-in support for various data sources. We handle the c
   <Card title="🌐📄 web page" href="/data-sources/web-page"></Card>
   <Card title="🧾 xml" href="/data-sources/xml"></Card>
   <Card title="🙌 OpenAPI" href="/data-sources/openapi"></Card>
-  <Card title="🎥📺 youtube video" href="/data-sources/youtube-video"></Card>
+  <Card title="📺 youtube video" href="/data-sources/youtube-video"></Card>
   <Card title="📬 Gmail" href="/data-sources/gmail"></Card>
   <Card title="🐘 Postgres" href="/data-sources/postgres"></Card>
   <Card title="🐬 MySQL" href="/data-sources/mysql"></Card>
   <Card title="🤖 Slack" href="/data-sources/slack"></Card>
+  <Card title="🗨️ Discourse" href="/data-sources/discourse"></Card>
 </CardGroup>
 
 <br/ >

+ 1 - 1
docs/data-sources/youtube-video.mdx

@@ -1,5 +1,5 @@
 ---
-title: '🎥📺 Youtube video'
+title: '📺 Youtube video'
 ---
 
 

+ 22 - 0
embedchain/chunkers/discourse.py

@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class DiscourseChunker(BaseChunker):
+    """Chunker for discourse."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)

+ 2 - 0
embedchain/data_formatter/data_formatter.py

@@ -70,6 +70,7 @@ class DataFormatter(JSONSerializable):
                 DataType.POSTGRES,
                 DataType.MYSQL,
                 DataType.SLACK,
+                DataType.DISCOURSE,
             ]
         )
 
@@ -110,6 +111,7 @@ class DataFormatter(JSONSerializable):
             DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker",
             DataType.MYSQL: "embedchain.chunkers.mysql.MySQLChunker",
             DataType.SLACK: "embedchain.chunkers.slack.SlackChunker",
+            DataType.DISCOURSE: "embedchain.chunkers.discourse.DiscourseChunker",
         }
 
         if data_type in chunker_classes:

+ 2 - 1
embedchain/embedchain.py

@@ -16,7 +16,8 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+                                         IndirectDataType, SpecialDataType)
 from embedchain.telemetry.posthog import AnonymousTelemetry
 from embedchain.utils import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB

+ 72 - 0
embedchain/loaders/discourse.py

@@ -0,0 +1,72 @@
+import concurrent.futures
+import hashlib
+import logging
+from typing import Any, Dict, Optional
+
+import requests
+
+from embedchain.loaders.base_loader import BaseLoader
+from embedchain.utils import clean_string
+
+
+class DiscourseLoader(BaseLoader):
+    def __init__(self, config: Optional[Dict[str, Any]] = None):
+        super().__init__()
+        if not config:
+            raise ValueError(
+                "DiscourseLoader requires a config. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`"  # noqa: E501
+            )
+
+        self.domain = config.get("domain")
+        if not self.domain:
+            raise ValueError(
+                "DiscourseLoader requires a domain. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`"  # noqa: E501
+            )
+
+    def _check_query(self, query):
+        if not query or not isinstance(query, str):
+            raise ValueError(
+                "DiscourseLoader requires a query. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`"  # noqa: E501
+            )
+
+    def _load_post(self, post_id):
+        post_url = f"{self.domain}/posts/{post_id}.json"
+        response = requests.get(post_url)
+        response.raise_for_status()
+        response_data = response.json()
+        post_contents = clean_string(response_data.get("raw"))
+        meta_data = {
+            "url": post_url,
+            "created_at": response_data.get("created_at", ""),
+            "username": response_data.get("username", ""),
+            "topic_slug": response_data.get("topic_slug", ""),
+            "score": response_data.get("score", ""),
+        }
+        data = {
+            "content": post_contents,
+            "meta_data": meta_data,
+        }
+        return data
+
+    def load_data(self, query):
+        self._check_query(query)
+        data = []
+        data_contents = []
+        logging.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
+        search_url = f"{self.domain}/search.json?q={query}"
+        response = requests.get(search_url)
+        response.raise_for_status()
+        response_data = response.json()
+        post_ids = response_data.get("grouped_search_result").get("post_ids")
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future_to_post_id = {executor.submit(self._load_post, post_id): post_id for post_id in post_ids}
+            for future in concurrent.futures.as_completed(future_to_post_id):
+                post_id = future_to_post_id[future]
+                try:
+                    post_data = future.result()
+                    data.append(post_data)
+                except Exception as e:
+                    logging.error(f"Failed to load post {post_id}: {e}")
+        doc_id = hashlib.sha256((query + ", ".join(data_contents)).encode()).hexdigest()
+        response_data = {"doc_id": doc_id, "data": data}
+        return response_data

+ 2 - 0
embedchain/models/data_type.py

@@ -32,6 +32,7 @@ class IndirectDataType(Enum):
     POSTGRES = "postgres"
     MYSQL = "mysql"
     SLACK = "slack"
+    DISCOURSE = "discourse"
 
 
 class SpecialDataType(Enum):
@@ -63,3 +64,4 @@ class DataType(Enum):
     POSTGRES = IndirectDataType.POSTGRES.value
     MYSQL = IndirectDataType.MYSQL.value
     SLACK = IndirectDataType.SLACK.value
+    DISCOURSE = IndirectDataType.DISCOURSE.value

+ 55 - 0
embedchain/utils.py

@@ -5,11 +5,66 @@ import re
 import string
 from typing import Any
 
+from bs4 import BeautifulSoup
 from schema import Optional, Or, Schema
 
 from embedchain.models.data_type import DataType
 
 
+def parse_content(content, type):
+    implemented = ["html.parser", "lxml", "lxml-xml", "xml", "html5lib"]
+    if type not in implemented:
+        raise ValueError(f"Parser type {type} not implemented. Please choose one of {implemented}")
+
+    soup = BeautifulSoup(content, type)
+    original_size = len(str(soup.get_text()))
+
+    tags_to_exclude = [
+        "nav",
+        "aside",
+        "form",
+        "header",
+        "noscript",
+        "svg",
+        "canvas",
+        "footer",
+        "script",
+        "style",
+    ]
+    for tag in soup(tags_to_exclude):
+        tag.decompose()
+
+    ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
+    for id in ids_to_exclude:
+        tags = soup.find_all(id=id)
+        for tag in tags:
+            tag.decompose()
+
+    classes_to_exclude = [
+        "elementor-location-header",
+        "navbar-header",
+        "nav",
+        "header-sidebar-wrapper",
+        "blog-sidebar-wrapper",
+        "related-posts",
+    ]
+    for class_name in classes_to_exclude:
+        tags = soup.find_all(class_=class_name)
+        for tag in tags:
+            tag.decompose()
+
+    content = soup.get_text()
+    content = clean_string(content)
+
+    cleaned_size = len(content)
+    if original_size != 0:
+        logging.info(
+            f"Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)"  # noqa:E501
+        )
+
+    return content
+
+
 def clean_string(text):
     """
     This function takes in a string and performs a series of text cleaning operations.

+ 2 - 0
tests/chunkers/test_chunkers.py

@@ -1,3 +1,4 @@
+from embedchain.chunkers.discourse import DiscourseChunker
 from embedchain.chunkers.docs_site import DocsSiteChunker
 from embedchain.chunkers.docx_file import DocxFileChunker
 from embedchain.chunkers.gmail import GmailChunker
@@ -37,6 +38,7 @@ chunker_common_config = {
     GmailChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     PostgresChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     SlackChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
+    DiscourseChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
 }
 
 

+ 2 - 1
tests/helper_classes/test_json_serializable.py

@@ -4,7 +4,8 @@ from string import Template
 
 from embedchain import App
 from embedchain.config import AppConfig, BaseLlmConfig
-from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
+from embedchain.helper.json_serializable import (JSONSerializable,
+                                                 register_deserializable)
 
 
 class TestJsonSerializable(unittest.TestCase):

+ 118 - 0
tests/loaders/test_discourse.py

@@ -0,0 +1,118 @@
+import pytest
+import requests
+
+from embedchain.loaders.discourse import DiscourseLoader
+
+
+@pytest.fixture
+def discourse_loader_config():
+    return {
+        "domain": "https://example.com",
+    }
+
+
+@pytest.fixture
+def discourse_loader(discourse_loader_config):
+    return DiscourseLoader(config=discourse_loader_config)
+
+
+def test_discourse_loader_init_with_valid_config():
+    config = {"domain": "https://example.com"}
+    loader = DiscourseLoader(config=config)
+    assert loader.domain == "https://example.com"
+
+
+def test_discourse_loader_init_with_missing_config():
+    with pytest.raises(ValueError, match="DiscourseLoader requires a config"):
+        DiscourseLoader()
+
+
+def test_discourse_loader_init_with_missing_domain():
+    config = {"another_key": "value"}
+    with pytest.raises(ValueError, match="DiscourseLoader requires a domain"):
+        DiscourseLoader(config=config)
+
+
+def test_discourse_loader_check_query_with_valid_query(discourse_loader):
+    discourse_loader._check_query("sample query")
+
+
+def test_discourse_loader_check_query_with_empty_query(discourse_loader):
+    with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
+        discourse_loader._check_query("")
+
+
+def test_discourse_loader_check_query_with_invalid_query_type(discourse_loader):
+    with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
+        discourse_loader._check_query(123)
+
+
+def test_discourse_loader_load_post_with_valid_post_id(discourse_loader, monkeypatch):
+    def mock_get(*args, **kwargs):
+        class MockResponse:
+            def json(self):
+                return {"raw": "Sample post content"}
+
+            def raise_for_status(self):
+                pass
+
+        return MockResponse()
+
+    monkeypatch.setattr(requests, "get", mock_get)
+
+    post_data = discourse_loader._load_post(123)
+
+    assert post_data["content"] == "Sample post content"
+    assert "meta_data" in post_data
+
+
+def test_discourse_loader_load_post_with_invalid_post_id(discourse_loader, monkeypatch):
+    def mock_get(*args, **kwargs):
+        class MockResponse:
+            def raise_for_status(self):
+                raise requests.exceptions.RequestException("Test error")
+
+        return MockResponse()
+
+    monkeypatch.setattr(requests, "get", mock_get)
+
+    with pytest.raises(Exception, match="Test error"):
+        discourse_loader._load_post(123)
+
+
+def test_discourse_loader_load_data_with_valid_query(discourse_loader, monkeypatch):
+    def mock_get(*args, **kwargs):
+        class MockResponse:
+            def json(self):
+                return {"grouped_search_result": {"post_ids": [123, 456, 789]}}
+
+            def raise_for_status(self):
+                pass
+
+        return MockResponse()
+
+    monkeypatch.setattr(requests, "get", mock_get)
+
+    def mock_load_post(*args, **kwargs):
+        return {
+            "content": "Sample post content",
+            "meta_data": {
+                "url": "https://example.com/posts/123.json",
+                "created_at": "2021-01-01",
+                "username": "test_user",
+                "topic_slug": "test_topic",
+                "score": 10,
+            },
+        }
+
+    monkeypatch.setattr(discourse_loader, "_load_post", mock_load_post)
+
+    data = discourse_loader.load_data("sample query")
+
+    assert len(data["data"]) == 3
+    assert data["data"][0]["content"] == "Sample post content"
+    assert data["data"][0]["meta_data"]["url"] == "https://example.com/posts/123.json"
+    assert data["data"][0]["meta_data"]["created_at"] == "2021-01-01"
+    assert data["data"][0]["meta_data"]["username"] == "test_user"
+    assert data["data"][0]["meta_data"]["topic_slug"] == "test_topic"
+    assert data["data"][0]["meta_data"]["score"] == 10