소스 검색

Add local file path support to SitemapLoader (#954)

Prikshit 1 년 전
부모
커밋
106a338371
1개의 변경된 파일에 16개의 추가작업 그리고 12개의 삭제
  1. 16 12
      embedchain/loaders/sitemap.py

+ 16 - 12
embedchain/loaders/sitemap.py

@@ -1,6 +1,7 @@
 import concurrent.futures
 import hashlib
 import logging
+import os
 from urllib.parse import urlparse
 
 import requests
@@ -22,31 +23,34 @@ from embedchain.loaders.web_page import WebPageLoader
 @register_deserializable
 class SitemapLoader(BaseLoader):
     """
-    This method takes a sitemap URL as input and retrieves
+    This method takes a sitemap URL or local file path as input and retrieves
     all the URLs to use the WebPageLoader to load content
     of each page.
     """
 
-    def load_data(self, sitemap_url):
+    def load_data(self, sitemap_source):
         output = []
         web_page_loader = WebPageLoader()
 
-        if urlparse(sitemap_url).scheme not in ["file", "http", "https"]:
-            raise ValueError("Not a valid URL.")
-
-        if urlparse(sitemap_url).scheme in ["http", "https"]:
-            response = requests.get(sitemap_url)
-            response.raise_for_status()
-            soup = BeautifulSoup(response.text, "xml")
-        else:
-            with open(sitemap_url, "r") as file:
+        if urlparse(sitemap_source).scheme in ("http", "https"):
+            try:
+                response = requests.get(sitemap_source)
+                response.raise_for_status()
+                soup = BeautifulSoup(response.text, "xml")
+            except requests.RequestException as e:
+                logging.error(f"Error fetching sitemap from URL: {e}")
+                return
+        elif os.path.isfile(sitemap_source):
+            with open(sitemap_source, "r") as file:
                 soup = BeautifulSoup(file, "xml")
+        else:
+            raise ValueError("Invalid sitemap source. Please provide a valid URL or local file path.")
 
         links = [link.text for link in soup.find_all("loc") if link.parent.name == "url"]
         if len(links) == 0:
             links = [link.text for link in soup.find_all("loc")]
 
-        doc_id = hashlib.sha256((" ".join(links) + sitemap_url).encode()).hexdigest()
+        doc_id = hashlib.sha256((" ".join(links) + sitemap_source).encode()).hexdigest()
 
         def load_web_page(link):
             try: