test_docs_site.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import hashlib
  2. from unittest.mock import Mock, patch
  3. import pytest
  4. from requests import Response
  5. from embedchain.loaders.docs_site_loader import DocsSiteLoader
  @pytest.fixture
  def mock_requests_get():
      """Patch ``requests.get`` while a test runs and yield the replacement mock."""
      with patch("requests.get") as patched_get:
          yield patched_get
  @pytest.fixture
  def docs_site_loader():
      """Provide a newly constructed ``DocsSiteLoader`` to the requesting test."""
      loader = DocsSiteLoader()
      return loader
  def test_get_child_links_recursive(mock_requests_get, docs_site_loader):
      """Check that crawling a 200 page with two relative links records both.

      The mocked ``requests.get`` always returns the same page, and the test
      expects exactly the two absolute URLs to end up in ``visited_links``.
      """
      page = Mock(status_code=200)
      # HTML body is kept flush with the block margin; re-indenting would
      # change the string handed to the loader.
      page.text = """
  <html>
  <a href="/page1">Page 1</a>
  <a href="/page2">Page 2</a>
  </html>
  """
      mock_requests_get.return_value = page

      docs_site_loader._get_child_links_recursive("https://example.com")

      assert len(docs_site_loader.visited_links) == 2
      for expected_link in (
          "https://example.com/page1",
          "https://example.com/page2",
      ):
          assert expected_link in docs_site_loader.visited_links
  def test_get_child_links_recursive_status_not_200(mock_requests_get, docs_site_loader):
      """Check that a 404 response leaves ``visited_links`` empty."""
      mock_requests_get.return_value = Mock(status_code=404)

      docs_site_loader._get_child_links_recursive("https://example.com")

      assert len(docs_site_loader.visited_links) == 0
  def test_get_all_urls(mock_requests_get, docs_site_loader):
      """Check that ``_get_all_urls`` returns all three links found on the page.

      The page has two relative links and one absolute link; the test expects
      three URLs back, each in absolute form.
      """
      page = Mock(status_code=200)
      # HTML body is kept flush with the block margin; re-indenting would
      # change the string handed to the loader.
      page.text = """
  <html>
  <a href="/page1">Page 1</a>
  <a href="/page2">Page 2</a>
  <a href="https://example.com/external">External</a>
  </html>
  """
      mock_requests_get.return_value = page

      all_urls = docs_site_loader._get_all_urls("https://example.com")

      assert len(all_urls) == 3
      for expected_url in (
          "https://example.com/page1",
          "https://example.com/page2",
          "https://example.com/external",
      ):
          assert expected_url in all_urls
  def test_load_data_from_url(mock_requests_get, docs_site_loader):
      """Check that a 200 page yields one entry holding the article text and URL.

      The page has a ``<nav>`` section and a ``bd-article`` article; the test
      expects the single returned entry's content to be ``"Article Content"``.
      """
      page_url = "https://example.com/page1"
      page = Mock(status_code=200)
      # HTML body is kept flush with the block margin; re-indenting would
      # change the bytes handed to the loader.
      page.content = """
  <html>
  <nav>
  <h1>Navigation</h1>
  </nav>
  <article class="bd-article">
  <p>Article Content</p>
  </article>
  </html>
  """.encode()
      mock_requests_get.return_value = page

      data = docs_site_loader._load_data_from_url(page_url)

      assert len(data) == 1
      entry = data[0]
      assert entry["content"] == "Article Content"
      assert entry["meta_data"]["url"] == "https://example.com/page1"
  def test_load_data_from_url_status_not_200(mock_requests_get, docs_site_loader):
      """Check that a 404 response makes ``_load_data_from_url`` return ``[]``."""
      mock_requests_get.return_value = Mock(status_code=404)
      page_url = "https://example.com/page1"

      data = docs_site_loader._load_data_from_url(page_url)

      # Both checks are retained from the original test.
      assert data == []
      assert len(data) == 0
  def test_load_data(mock_requests_get, docs_site_loader):
      """Check ``load_data`` against a 200 page that has two links.

      Expects two data entries and a ``doc_id`` equal to the SHA-256 hex digest
      of the space-joined ``visited_links`` followed by the start URL.
      """
      url = "https://example.com"

      response = Response()
      response.status_code = 200
      # HTML body is kept flush with the block margin; re-indenting would
      # change the bytes handed to the loader.
      response._content = """
  <html>
  <a href="/page1">Page 1</a>
  <a href="/page2">Page 2</a>
  """.encode()
      mock_requests_get.return_value = response

      data = docs_site_loader.load_data(url)

      # Built after load_data so visited_links reflects the crawl.
      joined_links = " ".join(docs_site_loader.visited_links)
      digest_input = (joined_links + url).encode()
      expected_doc_id = hashlib.sha256(digest_input).hexdigest()

      assert len(data["data"]) == 2
      assert data["doc_id"] == expected_doc_id
  def test_if_response_status_not_200(mock_requests_get, docs_site_loader):
      """Check ``load_data`` when every request returns 404.

      Expects no data entries and a ``doc_id`` equal to the SHA-256 hex digest
      of the space-joined ``visited_links`` followed by the start URL.
      """
      url = "https://example.com"

      response = Response()
      response.status_code = 404
      mock_requests_get.return_value = response

      data = docs_site_loader.load_data(url)

      # Built after load_data so visited_links reflects the crawl.
      joined_links = " ".join(docs_site_loader.visited_links)
      digest_input = (joined_links + url).encode()
      expected_doc_id = hashlib.sha256(digest_input).hexdigest()

      assert len(data["data"]) == 0
      assert data["doc_id"] == expected_doc_id