discourse.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import hashlib
  2. import logging
  3. import time
  4. from typing import Any, Optional
  5. import requests
  6. from embedchain.loaders.base_loader import BaseLoader
  7. from embedchain.utils.misc import clean_string
  8. logger = logging.getLogger(__name__)
  9. class DiscourseLoader(BaseLoader):
  10. def __init__(self, config: Optional[dict[str, Any]] = None):
  11. super().__init__()
  12. if not config:
  13. raise ValueError(
  14. "DiscourseLoader requires a config. Check the documentation for the correct format - `https://docs.embedchain.ai/components/data-sources/discourse`" # noqa: E501
  15. )
  16. self.domain = config.get("domain")
  17. if not self.domain:
  18. raise ValueError(
  19. "DiscourseLoader requires a domain. Check the documentation for the correct format - `https://docs.embedchain.ai/components/data-sources/discourse`" # noqa: E501
  20. )
  21. def _check_query(self, query):
  22. if not query or not isinstance(query, str):
  23. raise ValueError(
  24. "DiscourseLoader requires a query. Check the documentation for the correct format - `https://docs.embedchain.ai/components/data-sources/discourse`" # noqa: E501
  25. )
  26. def _load_post(self, post_id):
  27. post_url = f"{self.domain}posts/{post_id}.json"
  28. response = requests.get(post_url)
  29. try:
  30. response.raise_for_status()
  31. except Exception as e:
  32. logger.error(f"Failed to load post {post_id}: {e}")
  33. return
  34. response_data = response.json()
  35. post_contents = clean_string(response_data.get("raw"))
  36. metadata = {
  37. "url": post_url,
  38. "created_at": response_data.get("created_at", ""),
  39. "username": response_data.get("username", ""),
  40. "topic_slug": response_data.get("topic_slug", ""),
  41. "score": response_data.get("score", ""),
  42. }
  43. data = {
  44. "content": post_contents,
  45. "meta_data": metadata,
  46. }
  47. return data
  48. def load_data(self, query):
  49. self._check_query(query)
  50. data = []
  51. data_contents = []
  52. logger.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
  53. search_url = f"{self.domain}search.json?q={query}"
  54. response = requests.get(search_url)
  55. try:
  56. response.raise_for_status()
  57. except Exception as e:
  58. raise ValueError(f"Failed to search query {query}: {e}")
  59. response_data = response.json()
  60. post_ids = response_data.get("grouped_search_result").get("post_ids")
  61. for id in post_ids:
  62. post_data = self._load_post(id)
  63. if post_data:
  64. data.append(post_data)
  65. data_contents.append(post_data.get("content"))
  66. # Sleep for 0.4 sec, to avoid rate limiting. Check `https://meta.discourse.org/t/api-rate-limits/208405/6`
  67. time.sleep(0.4)
  68. doc_id = hashlib.sha256((query + ", ".join(data_contents)).encode()).hexdigest()
  69. response_data = {"doc_id": doc_id, "data": data}
  70. return response_data