test_discourse.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import pytest
  2. import requests
  3. from embedchain.loaders.discourse import DiscourseLoader
  4. @pytest.fixture
  5. def discourse_loader_config():
  6. return {
  7. "domain": "https://example.com/",
  8. }
  9. @pytest.fixture
  10. def discourse_loader(discourse_loader_config):
  11. return DiscourseLoader(config=discourse_loader_config)
  12. def test_discourse_loader_init_with_valid_config():
  13. config = {"domain": "https://example.com/"}
  14. loader = DiscourseLoader(config=config)
  15. assert loader.domain == "https://example.com/"
  16. def test_discourse_loader_init_with_missing_config():
  17. with pytest.raises(ValueError, match="DiscourseLoader requires a config"):
  18. DiscourseLoader()
  19. def test_discourse_loader_init_with_missing_domain():
  20. config = {"another_key": "value"}
  21. with pytest.raises(ValueError, match="DiscourseLoader requires a domain"):
  22. DiscourseLoader(config=config)
  23. def test_discourse_loader_check_query_with_valid_query(discourse_loader):
  24. discourse_loader._check_query("sample query")
  25. def test_discourse_loader_check_query_with_empty_query(discourse_loader):
  26. with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
  27. discourse_loader._check_query("")
  28. def test_discourse_loader_check_query_with_invalid_query_type(discourse_loader):
  29. with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
  30. discourse_loader._check_query(123)
  31. def test_discourse_loader_load_post_with_valid_post_id(discourse_loader, monkeypatch):
  32. def mock_get(*args, **kwargs):
  33. class MockResponse:
  34. def json(self):
  35. return {"raw": "Sample post content"}
  36. def raise_for_status(self):
  37. pass
  38. return MockResponse()
  39. monkeypatch.setattr(requests, "get", mock_get)
  40. post_data = discourse_loader._load_post(123)
  41. assert post_data["content"] == "Sample post content"
  42. assert "meta_data" in post_data
  43. def test_discourse_loader_load_post_with_invalid_post_id(discourse_loader, monkeypatch, caplog):
  44. def mock_get(*args, **kwargs):
  45. class MockResponse:
  46. def raise_for_status(self):
  47. raise requests.exceptions.RequestException("Test error")
  48. return MockResponse()
  49. monkeypatch.setattr(requests, "get", mock_get)
  50. discourse_loader._load_post(123)
  51. assert "Failed to load post" in caplog.text
  52. def test_discourse_loader_load_data_with_valid_query(discourse_loader, monkeypatch):
  53. def mock_get(*args, **kwargs):
  54. class MockResponse:
  55. def json(self):
  56. return {"grouped_search_result": {"post_ids": [123, 456, 789]}}
  57. def raise_for_status(self):
  58. pass
  59. return MockResponse()
  60. monkeypatch.setattr(requests, "get", mock_get)
  61. def mock_load_post(*args, **kwargs):
  62. return {
  63. "content": "Sample post content",
  64. "meta_data": {
  65. "url": "https://example.com/posts/123.json",
  66. "created_at": "2021-01-01",
  67. "username": "test_user",
  68. "topic_slug": "test_topic",
  69. "score": 10,
  70. },
  71. }
  72. monkeypatch.setattr(discourse_loader, "_load_post", mock_load_post)
  73. data = discourse_loader.load_data("sample query")
  74. assert len(data["data"]) == 3
  75. assert data["data"][0]["content"] == "Sample post content"
  76. assert data["data"][0]["meta_data"]["url"] == "https://example.com/posts/123.json"
  77. assert data["data"][0]["meta_data"]["created_at"] == "2021-01-01"
  78. assert data["data"][0]["meta_data"]["username"] == "test_user"
  79. assert data["data"][0]["meta_data"]["topic_slug"] == "test_topic"
  80. assert data["data"][0]["meta_data"]["score"] == 10