Selaa lähdekoodia

[Improvement] fix discourse loader to avoid rate limit (#953)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 vuosi sitten
vanhempi
commit
c14bd7b73b

+ 1 - 1
docs/data-sources/discourse.mdx

@@ -22,7 +22,7 @@ os.environ["OPENAI_API_KEY"] = "sk-xxx"
 
 app = App()
 
-app.add("openai", data_type="discourse", loader=dicourse_loader)
+app.add("openai after:2023-10-1", data_type="discourse", loader=dicourse_loader)
 
 question = "Where can I find the OpenAI API status page?"
 app.query(question)

+ 17 - 12
embedchain/loaders/discourse.py

@@ -1,6 +1,6 @@
-import concurrent.futures
 import hashlib
 import logging
+import time
 from typing import Any, Dict, Optional
 
 import requests
@@ -32,7 +32,11 @@ class DiscourseLoader(BaseLoader):
     def _load_post(self, post_id):
         post_url = f"{self.domain}posts/{post_id}.json"
         response = requests.get(post_url)
-        response.raise_for_status()
+        try:
+            response.raise_for_status()
+        except Exception as e:
+            logging.error(f"Failed to load post {post_id}: {e}")
+            return
         response_data = response.json()
         post_contents = clean_string(response_data.get("raw"))
         meta_data = {
@@ -55,18 +59,19 @@ class DiscourseLoader(BaseLoader):
         logging.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
         search_url = f"{self.domain}search.json?q={query}"
         response = requests.get(search_url)
-        response.raise_for_status()
+        try:
+            response.raise_for_status()
+        except Exception as e:
+            raise ValueError(f"Failed to search query {query}: {e}")
         response_data = response.json()
         post_ids = response_data.get("grouped_search_result").get("post_ids")
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            future_to_post_id = {executor.submit(self._load_post, post_id): post_id for post_id in post_ids}
-            for future in concurrent.futures.as_completed(future_to_post_id):
-                post_id = future_to_post_id[future]
-                try:
-                    post_data = future.result()
-                    data.append(post_data)
-                except Exception as e:
-                    logging.error(f"Failed to load post {post_id}: {e}")
+        for id in post_ids:
+            post_data = self._load_post(id)
+            if post_data:
+                data.append(post_data)
+                data_contents.append(post_data.get("content"))
+            # Sleep for 0.4 sec, to avoid rate limiting. Check `https://meta.discourse.org/t/api-rate-limits/208405/6`
+            time.sleep(0.4)
         doc_id = hashlib.sha256((query + ", ".join(data_contents)).encode()).hexdigest()
         response_data = {"doc_id": doc_id, "data": data}
         return response_data

+ 2 - 1
embedchain/loaders/substack.py

@@ -1,6 +1,7 @@
-import time
 import hashlib
 import logging
+import time
+
 import requests
 
 from embedchain.helper.json_serializable import register_deserializable

+ 1 - 1
embedchain/loaders/youtube_video.py

@@ -19,7 +19,7 @@ class YoutubeVideoLoader(BaseLoader):
         doc = loader.load()
         output = []
         if not len(doc):
-            raise ValueError("No data found")
+            raise ValueError(f"No data found for url: {url}")
         content = doc[0].page_content
         content = clean_string(content)
         meta_data = doc[0].metadata

+ 4 - 3
tests/loaders/test_discourse.py

@@ -66,7 +66,7 @@ def test_discourse_loader_load_post_with_valid_post_id(discourse_loader, monkeyp
     assert "meta_data" in post_data
 
 
-def test_discourse_loader_load_post_with_invalid_post_id(discourse_loader, monkeypatch):
+def test_discourse_loader_load_post_with_invalid_post_id(discourse_loader, monkeypatch, caplog):
     def mock_get(*args, **kwargs):
         class MockResponse:
             def raise_for_status(self):
@@ -76,8 +76,9 @@ def test_discourse_loader_load_post_with_invalid_post_id(discourse_loader, monke
 
     monkeypatch.setattr(requests, "get", mock_get)
 
-    with pytest.raises(Exception, match="Test error"):
-        discourse_loader._load_post(123)
+    discourse_loader._load_post(123)
+
+    assert "Failed to load post" in caplog.text
 
 
 def test_discourse_loader_load_data_with_valid_query(discourse_loader, monkeypatch):