docs_site_loader.py

import hashlib
import logging
from urllib.parse import urljoin, urlparse

import requests

try:
    from bs4 import BeautifulSoup
except ImportError:
    raise ImportError(
        'DocsSite requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
    ) from None

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader

logger = logging.getLogger(__name__)


@register_deserializable
class DocsSiteLoader(BaseLoader):
    """Recursively crawl a documentation site and extract the main text content of each page."""

    def __init__(self):
        self.visited_links = set()

    def _get_child_links_recursive(self, url):
        # Skip pages that were already crawled.
        if url in self.visited_links:
            return
        self.visited_links.add(url)

        parsed_url = urlparse(url)
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
        current_path = parsed_url.path

        response = requests.get(url)
        if response.status_code != 200:
            logger.info(f"Failed to fetch the website: {response.status_code}")
            return

        soup = BeautifulSoup(response.text, "html.parser")
        all_links = (link.get("href") for link in soup.find_all("a", href=True))
        # Follow only links that point deeper into the current path.
        child_links = (link for link in all_links if link.startswith(current_path) and link != current_path)
        absolute_paths = set(urljoin(base_url, link) for link in child_links)

        for link in absolute_paths:
            self._get_child_links_recursive(link)

    def _get_all_urls(self, url):
        self.visited_links = set()
        self._get_child_links_recursive(url)
        urls = [link for link in self.visited_links if urlparse(link).netloc == urlparse(url).netloc]
        return urls

    @staticmethod
    def _load_data_from_url(url: str) -> list:
        response = requests.get(url)
        if response.status_code != 200:
            logger.info(f"Failed to fetch the website: {response.status_code}")
            return []

        soup = BeautifulSoup(response.content, "html.parser")
        # Containers that commonly wrap the main content of a documentation page,
        # tried from most to least specific.
        selectors = [
            "article.bd-article",
            'article[role="main"]',
            "div.md-content",
            'div[role="main"]',
            "div.container",
            "div.section",
            "article",
            "main",
        ]

        output = []
        for selector in selectors:
            element = soup.select_one(selector)
            if element:
                content = element.prettify()
                break
        else:
            # No known container matched; fall back to the page's full text.
            content = soup.get_text()

        # Re-parse the extracted content and strip navigation and other boilerplate tags.
        soup = BeautifulSoup(content, "html.parser")
        ignored_tags = [
            "nav",
            "aside",
            "form",
            "header",
            "noscript",
            "svg",
            "canvas",
            "footer",
            "script",
            "style",
        ]
        for tag in soup(ignored_tags):
            tag.decompose()

        content = " ".join(soup.stripped_strings)
        output.append(
            {
                "content": content,
                "meta_data": {"url": url},
            }
        )
        return output

    def load_data(self, url):
        all_urls = self._get_all_urls(url)

        output = []
        for u in all_urls:
            output.extend(self._load_data_from_url(u))

        # doc_id is a SHA-256 hash of the crawled URLs joined with the starting URL.
        doc_id = hashlib.sha256((" ".join(all_urls) + url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": output,
        }
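

# A minimal usage sketch (not part of the original module), assuming a reachable
# documentation site; the URL below is only a placeholder.
if __name__ == "__main__":
    loader = DocsSiteLoader()
    result = loader.load_data("https://docs.example.com/")  # placeholder URL
    print(f"doc_id: {result['doc_id']}")
    for page in result["data"]:
        print(page["meta_data"]["url"], f"{len(page['content'])} chars")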