Browse Source

[Feature] JSON data loader support (#816)

Deven Patel 1 year ago
parent
commit
7641cba01d

+ 1 - 1
Makefile

@@ -38,7 +38,7 @@ lint:
 	poetry run ruff .
 
 test:
-	poetry run pytest
+	poetry run pytest $(file)
 
 coverage:
 	poetry run pytest --cov=$(PROJECT_NAME) --cov-report=xml

+ 1 - 0
README.md

@@ -45,6 +45,7 @@ Embedchain empowers you to create ChatGPT like apps, on your own dynamic dataset
 * Web page
 * Sitemap
 * Doc file
+* JSON file
 * Code documentation website loader
 * Notion and many more.
 

+ 22 - 0
embedchain/chunkers/json.py

@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class JSONChunker(BaseChunker):
+    """Chunker for json."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)

+ 4 - 0
embedchain/data_formatter/data_formatter.py

@@ -2,6 +2,7 @@ from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.chunkers.docs_site import DocsSiteChunker
 from embedchain.chunkers.docx_file import DocxFileChunker
 from embedchain.chunkers.images import ImagesChunker
+from embedchain.chunkers.json import JSONChunker
 from embedchain.chunkers.mdx import MdxChunker
 from embedchain.chunkers.notion import NotionChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
@@ -20,6 +21,7 @@ from embedchain.loaders.csv import CsvLoader
 from embedchain.loaders.docs_site_loader import DocsSiteLoader
 from embedchain.loaders.docx_file import DocxFileLoader
 from embedchain.loaders.images import ImagesLoader
+from embedchain.loaders.json import JSONLoader
 from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
 from embedchain.loaders.local_text import LocalTextLoader
 from embedchain.loaders.mdx import MdxLoader
@@ -75,6 +77,7 @@ class DataFormatter(JSONSerializable):
             DataType.CSV: CsvLoader,
             DataType.MDX: MdxLoader,
             DataType.IMAGES: ImagesLoader,
+            DataType.JSON: JSONLoader,
         }
         lazy_loaders = {DataType.NOTION}
         if data_type in loaders:
@@ -116,6 +119,7 @@ class DataFormatter(JSONSerializable):
             DataType.MDX: MdxChunker,
             DataType.IMAGES: ImagesChunker,
             DataType.XML: XmlChunker,
+            DataType.JSON: JSONChunker,
         }
         if data_type in chunker_classes:
             chunker_class: type = chunker_classes[data_type]

+ 23 - 0
embedchain/loaders/json.py

@@ -0,0 +1,23 @@
+import hashlib
+
+from langchain.document_loaders.json_loader import JSONLoader as LcJSONLoader
+
+from embedchain.loaders.base_loader import BaseLoader
+
+langchain_json_jq_schema = 'to_entries | map("\(.key): \(.value|tostring)") | .[]'
+
+
+class JSONLoader(BaseLoader):
+    @staticmethod
+    def load_data(content):
+        """Load a json file. Each data point is a key value pair."""
+        data = []
+        data_content = []
+        loader = LcJSONLoader(content, text_content=False, jq_schema=langchain_json_jq_schema)
+        docs = loader.load()
+        for doc in docs:
+            meta_data = doc.metadata
+            data.append({"content": doc.page_content, "meta_data": {"url": content, "row": meta_data["seq_num"]}})
+            data_content.append(doc.page_content)
+        doc_id = hashlib.sha256((content + ", ".join(data_content)).encode()).hexdigest()
+        return {"doc_id": doc_id, "data": data}

+ 2 - 0
embedchain/models/data_type.py

@@ -25,6 +25,7 @@ class IndirectDataType(Enum):
     CSV = "csv"
     MDX = "mdx"
     IMAGES = "images"
+    JSON = "json"
 
 
 class SpecialDataType(Enum):
@@ -49,3 +50,4 @@ class DataType(Enum):
     MDX = IndirectDataType.MDX.value
     QNA_PAIR = SpecialDataType.QNA_PAIR.value
     IMAGES = IndirectDataType.IMAGES.value
+    JSON = IndirectDataType.JSON.value

+ 8 - 0
embedchain/utils.py

@@ -155,6 +155,10 @@ def detect_datatype(source: Any) -> DataType:
             logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
             return DataType.DOCX
 
+        if url.path.endswith(".json"):
+            logging.debug(f"Source of `{formatted_source}` detected as `json_file`.")
+            return DataType.JSON
+
         if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
             # `docs_site` detection via path is not accepted for local filesystem URIs,
             # because that would mean all paths that contain `docs` are now doc sites, which is too aggressive.
@@ -194,6 +198,10 @@ def detect_datatype(source: Any) -> DataType:
             logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
             return DataType.XML
 
+        if source.endswith(".json"):
+            logging.debug(f"Source of `{formatted_source}` detected as `json`.")
+            return DataType.JSON
+
         # If the source is a valid file, that's not detectable as a type, an error is raised.
         # It does not fallback to text.
         raise ValueError(

+ 5 - 3
pyproject.toml

@@ -120,9 +120,10 @@ torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
 ftfy = { version = "6.1.1", optional = true }
 regex = { version = "2023.8.8", optional = true }
 huggingface_hub = { version = "^0.17.3", optional = true }
-pymilvus = { version="2.3.1", optional = true }
-google-cloud-aiplatform = { version="^1.26.1", optional = true }
-replicate = { version="^0.15.4", optional = true }
+pymilvus = { version = "2.3.1", optional = true }
+google-cloud-aiplatform = { version = "^1.26.1", optional = true }
+replicate = { version = "^0.15.4", optional = true }
+jq = { version=">=1.6.0", optional = true}
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"
@@ -163,6 +164,7 @@ dataloaders=[
     "docx2txt",
     "unstructured",
     "sentence-transformers",
+    "jq",
 ]
 vertexai = ["google-cloud-aiplatform"]
 llama2 = ["replicate"]

+ 2 - 0
tests/chunkers/test_chunkers.py

@@ -10,6 +10,7 @@ from embedchain.chunkers.text import TextChunker
 from embedchain.chunkers.web_page import WebPageChunker
 from embedchain.chunkers.xml import XmlChunker
 from embedchain.chunkers.youtube_video import YoutubeVideoChunker
+from embedchain.chunkers.json import JSONChunker
 from embedchain.config.add_config import ChunkerConfig
 
 chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
@@ -27,6 +28,7 @@ chunker_common_config = {
     WebPageChunker: {"chunk_size": 500, "chunk_overlap": 0, "length_function": len},
     XmlChunker: {"chunk_size": 500, "chunk_overlap": 50, "length_function": len},
     YoutubeVideoChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
+    JSONChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
 }
 
 

+ 31 - 0
tests/loaders/test_json.py

@@ -0,0 +1,31 @@
+import hashlib
+from unittest.mock import patch
+
+from langchain.docstore.document import Document
+from langchain.document_loaders.json_loader import JSONLoader as LcJSONLoader
+
+from embedchain.loaders.json import JSONLoader
+
+
+def test_load_data():
+    mock_document = [
+        Document(page_content="content1", metadata={"seq_num": 1}),
+        Document(page_content="content2", metadata={"seq_num": 2}),
+    ]
+    with patch.object(LcJSONLoader, "load", return_value=mock_document):
+        content = "temp.json"
+
+        result = JsonLoader.load_data(content)
+
+        assert "doc_id" in result
+        assert "data" in result
+
+        expected_data = [
+            {"content": "content1", "meta_data": {"url": content, "row": 1}},
+            {"content": "content2", "meta_data": {"url": content, "row": 2}},
+        ]
+
+        assert result["data"] == expected_data
+
+        expected_doc_id = hashlib.sha256((content + ", ".join(["content1", "content2"])).encode()).hexdigest()
+        assert result["doc_id"] == expected_doc_id