test_discourse.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import pytest
  2. import requests
  3. from embedchain.loaders.discourse import DiscourseLoader
  4. @pytest.fixture
  5. def discourse_loader_config():
  6. return {
  7. "domain": "https://example.com/",
  8. }
  9. @pytest.fixture
  10. def discourse_loader(discourse_loader_config):
  11. return DiscourseLoader(config=discourse_loader_config)
  12. def test_discourse_loader_init_with_valid_config():
  13. config = {"domain": "https://example.com/"}
  14. loader = DiscourseLoader(config=config)
  15. assert loader.domain == "https://example.com/"
  16. def test_discourse_loader_init_with_missing_config():
  17. with pytest.raises(ValueError, match="DiscourseLoader requires a config"):
  18. DiscourseLoader()
  19. def test_discourse_loader_init_with_missing_domain():
  20. config = {"another_key": "value"}
  21. with pytest.raises(ValueError, match="DiscourseLoader requires a domain"):
  22. DiscourseLoader(config=config)
  23. def test_discourse_loader_check_query_with_valid_query(discourse_loader):
  24. discourse_loader._check_query("sample query")
  25. def test_discourse_loader_check_query_with_empty_query(discourse_loader):
  26. with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
  27. discourse_loader._check_query("")
  28. def test_discourse_loader_check_query_with_invalid_query_type(discourse_loader):
  29. with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
  30. discourse_loader._check_query(123)
  31. def test_discourse_loader_load_post_with_valid_post_id(discourse_loader, monkeypatch):
  32. def mock_get(*args, **kwargs):
  33. class MockResponse:
  34. def json(self):
  35. return {"raw": "Sample post content"}
  36. def raise_for_status(self):
  37. pass
  38. return MockResponse()
  39. monkeypatch.setattr(requests, "get", mock_get)
  40. post_data = discourse_loader._load_post(123)
  41. assert post_data["content"] == "Sample post content"
  42. assert "meta_data" in post_data
  43. def test_discourse_loader_load_data_with_valid_query(discourse_loader, monkeypatch):
  44. def mock_get(*args, **kwargs):
  45. class MockResponse:
  46. def json(self):
  47. return {"grouped_search_result": {"post_ids": [123, 456, 789]}}
  48. def raise_for_status(self):
  49. pass
  50. return MockResponse()
  51. monkeypatch.setattr(requests, "get", mock_get)
  52. def mock_load_post(*args, **kwargs):
  53. return {
  54. "content": "Sample post content",
  55. "meta_data": {
  56. "url": "https://example.com/posts/123.json",
  57. "created_at": "2021-01-01",
  58. "username": "test_user",
  59. "topic_slug": "test_topic",
  60. "score": 10,
  61. },
  62. }
  63. monkeypatch.setattr(discourse_loader, "_load_post", mock_load_post)
  64. data = discourse_loader.load_data("sample query")
  65. assert len(data["data"]) == 3
  66. assert data["data"][0]["content"] == "Sample post content"
  67. assert data["data"][0]["meta_data"]["url"] == "https://example.com/posts/123.json"
  68. assert data["data"][0]["meta_data"]["created_at"] == "2021-01-01"
  69. assert data["data"][0]["meta_data"]["username"] == "test_user"
  70. assert data["data"][0]["meta_data"]["topic_slug"] == "test_topic"
  71. assert data["data"][0]["meta_data"]["score"] == 10