浏览代码

[fix]: Fix sitemap loader (#753)

Richard Awoyemi 1 年之前
父节点
当前提交
1741d3bef6

+ 22 - 0
embedchain/chunkers/sitemap.py

@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class SitemapChunker(BaseChunker):
+    """Chunker for sitemap."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)

+ 2 - 0
embedchain/data_formatter/data_formatter.py

@@ -6,6 +6,7 @@ from embedchain.chunkers.mdx import MdxChunker
 from embedchain.chunkers.notion import NotionChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
 from embedchain.chunkers.qna_pair import QnaPairChunker
+from embedchain.chunkers.sitemap import SitemapChunker
 from embedchain.chunkers.table import TableChunker
 from embedchain.chunkers.text import TextChunker
 from embedchain.chunkers.web_page import WebPageChunker
@@ -109,6 +110,7 @@ class DataFormatter(JSONSerializable):
             DataType.TEXT: TextChunker,
             DataType.DOCX: DocxFileChunker,
             DataType.DOCS_SITE: DocsSiteChunker,
+            DataType.SITEMAP: SitemapChunker,
             DataType.NOTION: NotionChunker,
             DataType.CSV: TableChunker,
             DataType.MDX: MdxChunker,

+ 2 - 3
embedchain/loaders/sitemap.py

@@ -36,9 +36,8 @@ class SitemapLoader(BaseLoader):
         for link in links:
             try:
                 each_load_data = web_page_loader.load_data(link)
-
-                if is_readable(each_load_data[0].get("content")):
-                    output.append(each_load_data)
+                if is_readable(each_load_data.get("data")[0].get("content")):
+                    output.append(each_load_data.get("data"))
                 else:
                     logging.warning(f"Page is not readable (too many invalid characters): {link}")
             except ParserRejectedMarkup as e:

+ 8 - 0
tests/embedchain/test_add.py

@@ -27,6 +27,14 @@ class TestApp(unittest.TestCase):
         self.app.add("https://example.com", metadata={"meta": "meta-data"})
         self.assertEqual(self.app.user_asks, [["https://example.com", "web_page", {"meta": "meta-data"}]])
 
+    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+    def test_add_sitemap(self):
+        """
+        In addition to the test_add function, this test checks that sitemaps can be added with the correct data type.
+        """
+        self.app.add("https://www.google.com/sitemap.xml", metadata={"meta": "meta-data"})
+        self.assertEqual(self.app.user_asks, [["https://www.google.com/sitemap.xml", "sitemap", {"meta": "meta-data"}]])
+
     @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
     def test_add_forced_type(self):
         """