Procházet zdrojové kódy

[feat]: Add support for XML file format (#757)

Ojuswi Rastogi před 1 rokem
rodič
revize
540a0a3685

+ 13 - 0
docs/data-sources/xml.mdx

@@ -0,0 +1,13 @@
+---
+title: 'XML File'
+---
+
+### XML file
+
+To add any xml file, use the data_type as `xml`. Eg:
+
+```python
+app.add('content/data.xml')
+```
+
+Note: Only the text content of the xml file will be added to the app. The tags will be ignored.

+ 22 - 0
embedchain/chunkers/xml.py

@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class XmlChunker(BaseChunker):
+    """Chunker for XML files."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)

+ 4 - 0
embedchain/data_formatter/data_formatter.py

@@ -9,6 +9,7 @@ from embedchain.chunkers.qna_pair import QnaPairChunker
 from embedchain.chunkers.table import TableChunker
 from embedchain.chunkers.text import TextChunker
 from embedchain.chunkers.web_page import WebPageChunker
+from embedchain.chunkers.xml import XmlChunker
 from embedchain.chunkers.youtube_video import YoutubeVideoChunker
 from embedchain.config import AddConfig
 from embedchain.config.add_config import ChunkerConfig, LoaderConfig
@@ -24,6 +25,7 @@ from embedchain.loaders.mdx import MdxLoader
 from embedchain.loaders.pdf_file import PdfFileLoader
 from embedchain.loaders.sitemap import SitemapLoader
 from embedchain.loaders.web_page import WebPageLoader
+from embedchain.loaders.xml import XmlLoader
 from embedchain.loaders.youtube_video import YoutubeVideoLoader
 from embedchain.models.data_type import DataType
 
@@ -67,6 +69,7 @@ class DataFormatter(JSONSerializable):
             DataType.TEXT: LocalTextLoader,
             DataType.DOCX: DocxFileLoader,
             DataType.SITEMAP: SitemapLoader,
+            DataType.XML: XmlLoader,
             DataType.DOCS_SITE: DocsSiteLoader,
             DataType.CSV: CsvLoader,
             DataType.MDX: MdxLoader,
@@ -110,6 +113,7 @@ class DataFormatter(JSONSerializable):
             DataType.CSV: TableChunker,
             DataType.MDX: MdxChunker,
             DataType.IMAGES: ImagesChunker,
+            DataType.XML: XmlChunker,
         }
         if data_type in chunker_classes:
             chunker_class: type = chunker_classes[data_type]

+ 26 - 0
embedchain/loaders/xml.py

@@ -0,0 +1,26 @@
+import hashlib
+
+from langchain.document_loaders import UnstructuredXMLLoader
+
+from embedchain.helper.json_serializable import register_deserializable
+from embedchain.loaders.base_loader import BaseLoader
+from embedchain.utils import clean_string
+
+
+@register_deserializable
+class XmlLoader(BaseLoader):
+    def load_data(self, xml_url):
+        """Load data from a XML file."""
+        loader = UnstructuredXMLLoader(xml_url)
+        data = loader.load()
+        content = data[0].page_content
+        content = clean_string(content)
+        meta_data = data[0].metadata
+        meta_data["url"] = meta_data["source"]
+        del meta_data["source"]
+        output = [{"content": content, "meta_data": meta_data}]
+        doc_id = hashlib.sha256((content + xml_url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": output,
+        }

+ 2 - 0
embedchain/models/data_type.py

@@ -18,6 +18,7 @@ class IndirectDataType(Enum):
     PDF_FILE = "pdf_file"
     WEB_PAGE = "web_page"
     SITEMAP = "sitemap"
+    XML = "xml"
     DOCX = "docx"
     DOCS_SITE = "docs_site"
     NOTION = "notion"
@@ -40,6 +41,7 @@ class DataType(Enum):
     PDF_FILE = IndirectDataType.PDF_FILE.value
     WEB_PAGE = IndirectDataType.WEB_PAGE.value
     SITEMAP = IndirectDataType.SITEMAP.value
+    XML = IndirectDataType.XML.value
     DOCX = IndirectDataType.DOCX.value
     DOCS_SITE = IndirectDataType.DOCS_SITE.value
     NOTION = IndirectDataType.NOTION.value

+ 5 - 2
embedchain/utils.py

@@ -128,8 +128,7 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)
 
     if url:
-        from langchain.document_loaders.youtube import \
-            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
 
         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
@@ -190,6 +189,10 @@ def detect_datatype(source: Any) -> DataType:
             logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
             return DataType.CSV
 
+        if source.endswith(".xml"):
+            logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
+            return DataType.XML
+
         # If the source is a valid file, that's not detectable as a type, an error is raised.
         # It does not fallback to text.
         raise ValueError(

+ 1 - 0
pyproject.toml

@@ -106,6 +106,7 @@ fastapi-poe = { version = "0.0.16", optional = true }
 discord = { version = "^2.3.2", optional = true }
 slack-sdk = { version = "3.21.3", optional = true }
 docx2txt = "^0.8"
+unstructured = {extras = ["local-inference"], version = "^0.10.18"}
 pillow = { version = "10.0.1", optional = true }
 torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
 ftfy = { version = "6.1.1", optional = true }

+ 62 - 0
tests/loaders/test_xml.py

@@ -0,0 +1,62 @@
+import tempfile
+
+import pytest
+
+from embedchain.loaders.xml import XmlLoader
+
+# Taken from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/tests/integration_tests/examples/factbook.xml
+SAMPLE_XML = """<?xml version="1.0" encoding="UTF-8"?>
+<factbook>
+  <country>
+    <name>United States</name>
+    <capital>Washington, DC</capital>
+    <leader>Joe Biden</leader>
+    <sport>Baseball</sport>
+  </country>
+  <country>
+    <name>Canada</name>
+    <capital>Ottawa</capital>
+    <leader>Justin Trudeau</leader>
+    <sport>Hockey</sport>
+  </country>
+  <country>
+    <name>France</name>
+    <capital>Paris</capital>
+    <leader>Emmanuel Macron</leader>
+    <sport>Soccer</sport>
+  </country>
+  <country>
+    <name>Trinidad &amp; Tobado</name>
+    <capital>Port of Spain</capital>
+    <leader>Keith Rowley</leader>
+    <sport>Track &amp; Field</sport>
+  </country>
+</factbook>"""
+
+
+@pytest.mark.parametrize("xml", [SAMPLE_XML])
+def test_load_data(xml: str):
+    """
+    Test XML loader
+
+    Tests that XML file is loaded, metadata is correct and content is correct
+    """
+    # Creating temporary XML file
+    with tempfile.NamedTemporaryFile(mode="w+") as tmpfile:
+        tmpfile.write(xml)
+
+        tmpfile.seek(0)
+        filename = tmpfile.name
+
+        # Loading CSV using XmlLoader
+        loader = XmlLoader()
+        result = loader.load_data(filename)
+        data = result["data"]
+
+        # Assertions
+        assert len(data) == 1
+        assert "United States Washington, DC Joe Biden" in data[0]["content"]
+        assert "Canada Ottawa Justin Trudeau" in data[0]["content"]
+        assert "France Paris Emmanuel Macron" in data[0]["content"]
+        assert "Trinidad & Tobado Port of Spain Keith Rowley" in data[0]["content"]
+        assert data[0]["meta_data"]["url"] == filename