浏览代码

feat: add support for mdx file (#604)

Taranjeet Singh 1 年之前
父节点
当前提交
36b26e08c3

+ 8 - 0
docs/advanced/data_types.mdx

@@ -102,6 +102,14 @@ app.add("my-page-cfbc134ca6464fc980d0391613959196", "notion")
 app.add("https://www.notion.so/my-page-cfbc134ca6464fc980d0391613959196", "notion")
 ```
 
+### Mdx file
+
+To add any mdx file to your app, use the data_type (first argument to `.add()` method) as `mdx`. Note that this supports support mdx file present on machine, so this should be a file path. Eg:
+
+```python
+app.add('path/to/file.mdx', data_type='mdx')
+```
+
 ## Local Data Types
 
 ### Text

+ 22 - 0
embedchain/chunkers/mdx.py

@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.AddConfig import ChunkerConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class MdxChunker(BaseChunker):
+    """Chunker for mdx files."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)

+ 4 - 0
embedchain/data_formatter/data_formatter.py

@@ -1,6 +1,7 @@
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.chunkers.docs_site import DocsSiteChunker
 from embedchain.chunkers.docx_file import DocxFileChunker
+from embedchain.chunkers.mdx import MdxChunker
 from embedchain.chunkers.notion import NotionChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
 from embedchain.chunkers.qna_pair import QnaPairChunker
@@ -17,6 +18,7 @@ from embedchain.loaders.docs_site_loader import DocsSiteLoader
 from embedchain.loaders.docx_file import DocxFileLoader
 from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
 from embedchain.loaders.local_text import LocalTextLoader
+from embedchain.loaders.mdx import MdxLoader
 from embedchain.loaders.pdf_file import PdfFileLoader
 from embedchain.loaders.sitemap import SitemapLoader
 from embedchain.loaders.web_page import WebPageLoader
@@ -65,6 +67,7 @@ class DataFormatter(JSONSerializable):
             DataType.SITEMAP: SitemapLoader,
             DataType.DOCS_SITE: DocsSiteLoader,
             DataType.CSV: CsvLoader,
+            DataType.MDX: MdxLoader,
         }
         lazy_loaders = {DataType.NOTION}
         if data_type in loaders:
@@ -103,6 +106,7 @@ class DataFormatter(JSONSerializable):
             DataType.DOCS_SITE: DocsSiteChunker,
             DataType.NOTION: NotionChunker,
             DataType.CSV: TableChunker,
+            DataType.MDX: MdxChunker,
         }
         if data_type in chunker_classes:
             chunker_class: type = chunker_classes[data_type]

+ 28 - 0
embedchain/loaders/mdx.py

@@ -0,0 +1,28 @@
+import hashlib
+
+from langchain.document_loaders import PyPDFLoader
+
+from embedchain.helper.json_serializable import register_deserializable
+from embedchain.loaders.base_loader import BaseLoader
+from embedchain.utils import clean_string
+
+
+@register_deserializable
+class MdxLoader(BaseLoader):
+    def load_data(self, url):
+        """Load data from a mdx file."""
+        with open(url, 'r', encoding="utf-8") as infile:
+            content = infile.read()
+        meta_data = {
+            "url": url,
+        }
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
+                {
+                    "content": content,
+                    "meta_data": meta_data,
+                }
+            ],
+        }

+ 1 - 0
embedchain/models/data_type.py

@@ -12,3 +12,4 @@ class DataType(Enum):
     QNA_PAIR = "qna_pair"
     NOTION = "notion"
     CSV = "csv"
+    MDX = "mdx"