# discourse.py (2.9 KB)
  1. import concurrent.futures
  2. import hashlib
  3. import logging
  4. from typing import Any, Dict, Optional
  5. import requests
  6. from embedchain.loaders.base_loader import BaseLoader
  7. from embedchain.utils import clean_string
  8. class DiscourseLoader(BaseLoader):
  9. def __init__(self, config: Optional[Dict[str, Any]] = None):
  10. super().__init__()
  11. if not config:
  12. raise ValueError(
  13. "DiscourseLoader requires a config. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`" # noqa: E501
  14. )
  15. self.domain = config.get("domain")
  16. if not self.domain:
  17. raise ValueError(
  18. "DiscourseLoader requires a domain. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`" # noqa: E501
  19. )
  20. def _check_query(self, query):
  21. if not query or not isinstance(query, str):
  22. raise ValueError(
  23. "DiscourseLoader requires a query. Check the documentation for the correct format - `https://docs.embedchain.ai/data-sources/discourse`" # noqa: E501
  24. )
  25. def _load_post(self, post_id):
  26. post_url = f"{self.domain}/posts/{post_id}.json"
  27. response = requests.get(post_url)
  28. response.raise_for_status()
  29. response_data = response.json()
  30. post_contents = clean_string(response_data.get("raw"))
  31. meta_data = {
  32. "url": post_url,
  33. "created_at": response_data.get("created_at", ""),
  34. "username": response_data.get("username", ""),
  35. "topic_slug": response_data.get("topic_slug", ""),
  36. "score": response_data.get("score", ""),
  37. }
  38. data = {
  39. "content": post_contents,
  40. "meta_data": meta_data,
  41. }
  42. return data
  43. def load_data(self, query):
  44. self._check_query(query)
  45. data = []
  46. data_contents = []
  47. logging.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
  48. search_url = f"{self.domain}/search.json?q={query}"
  49. response = requests.get(search_url)
  50. response.raise_for_status()
  51. response_data = response.json()
  52. post_ids = response_data.get("grouped_search_result").get("post_ids")
  53. with concurrent.futures.ThreadPoolExecutor() as executor:
  54. future_to_post_id = {executor.submit(self._load_post, post_id): post_id for post_id in post_ids}
  55. for future in concurrent.futures.as_completed(future_to_post_id):
  56. post_id = future_to_post_id[future]
  57. try:
  58. post_data = future.result()
  59. data.append(post_data)
  60. except Exception as e:
  61. logging.error(f"Failed to load post {post_id}: {e}")
  62. doc_id = hashlib.sha256((query + ", ".join(data_contents)).encode()).hexdigest()
  63. response_data = {"doc_id": doc_id, "data": data}
  64. return response_data