Sfoglia il codice sorgente

[Feature Improvement] Update JSON Loader to support loading data from more sources (#898)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 anno fa
parent
commit
53037b5ed8

+ 21 - 4
docs/data-sources/json.mdx

@@ -2,9 +2,26 @@
 title: '📃 JSON'
 ---
 
-To add any json file, use the data_type as `json`. `json` allows remote urls and conventional file paths. Headers are included for each line, so if you have an `age` column, `18` will be added as `age: 18`. Eg:
+To add any json file, use the data_type as `json`. Headers are included for each line, so if you have an `age` column, `18` will be added as `age: 18`. Eg:
 
-```python
+Here are the supported sources for loading `json`:
+```
+1. URL - valid url to json file that ends with ".json" extension.
+2. Local file - valid url to local json file that ends with ".json" extension.
+3. String - valid json string (e.g. - app.add('{"foo": "bar"}'))
+```
+
+If you would like to add other data structures (e.x. list, dict etc.), do:
+```
+    import json
+    a = {"foo": "bar"}
+    valid_json_string_data = json.dumps(a, indent=0)
+
+    b = [{"foo": "bar"}]
+    valid_json_string_data = json.dumps(b, indent=0)
+```
+Example:
+```
 import os
 
 from embedchain.apps.app import App
@@ -25,8 +42,8 @@ response = app.query("What is the net worth of Elon Musk as of October 2023?")
 print(response)
 "As of October 2023, Elon Musk's net worth is $255.2 billion."
 ```
-
-```temp.json
+temp.json
+```
 {
     "question": "What is your net worth, Elon Musk?",
     "answer": "As of October 2023, Elon Musk's net worth is $255.2 billion, making him one of the wealthiest individuals in the world."

+ 25 - 1
embedchain/embedchain.py

@@ -20,7 +20,7 @@ from embedchain.loaders.base_loader import BaseLoader
 from embedchain.models.data_type import (DataType, DirectDataType,
                                          IndirectDataType, SpecialDataType)
 from embedchain.telemetry.posthog import AnonymousTelemetry
-from embedchain.utils import detect_datatype
+from embedchain.utils import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB
 
 load_dotenv()
@@ -175,11 +175,27 @@ class EmbedChain(JSONSerializable):
         if data_type:
             try:
                 data_type = DataType(data_type)
+                if data_type == DataType.JSON:
+                    if isinstance(source, str):
+                        if not is_valid_json_string(source):
+                            raise ValueError(
+                                f"Invalid json input: {source}",
+                                "Provide the correct JSON formatted source, \
+                                    refer `https://docs.embedchain.ai/data-sources/json`",
+                            )
+                    elif not isinstance(source, str):
+                        raise ValueError(
+                            "Invaid content input. \
+                            If you want to upload (list, dict, etc.), do \
+                                `json.dump(data, indent=0)` and add the stringified JSON. \
+                                    Check - `https://docs.embedchain.ai/data-sources/json`"
+                        )
             except ValueError:
                 raise ValueError(
                     f"Invalid data_type: '{data_type}'.",
                     f"Please use one of the following: {[data_type.value for data_type in DataType]}",
                 ) from None
+
         if not data_type:
             data_type = detect_datatype(source)
 
@@ -287,6 +303,10 @@ class EmbedChain(JSONSerializable):
             # These types have a indirect source reference
             # As long as the reference is the same, they can be updated.
             where = {"url": src}
+            if chunker.data_type == DataType.JSON and is_valid_json_string(src):
+                url = hashlib.sha256((src).encode("utf-8")).hexdigest()
+                where = {"url": url}
+
             if self.config.id is not None:
                 where.update({"app_id": self.config.id})
 
@@ -368,6 +388,10 @@ class EmbedChain(JSONSerializable):
 
         # get existing ids, and discard doc if any common id exist.
         where = {"url": src}
+        if chunker.data_type == DataType.JSON and is_valid_json_string(src):
+            url = hashlib.sha256((src).encode("utf-8")).hexdigest()
+            where = {"url": url}
+
         # if data type is qna_pair, we check for question
         if chunker.data_type == DataType.QNA_PAIR:
             where = {"question": src[0]}

+ 21 - 13
embedchain/loaders/json.py

@@ -6,33 +6,37 @@ import re
 import requests
 
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.utils import clean_string
+from embedchain.utils import clean_string, is_valid_json_string
 
 VALID_URL_PATTERN = "^https:\/\/[0-9A-z.]+.[0-9A-z.]+.[a-z]+\/.*\.json$"
 
 
 class JSONLoader(BaseLoader):
     @staticmethod
-    def load_data(content):
-        """Load a json file. Each data point is a key value pair."""
+    def _get_llama_hub_loader():
         try:
             from llama_hub.jsondata.base import \
-                JSONDataReader as LLHBUBJSONLoader
-        except ImportError:
+                JSONDataReader as LLHUBJSONLoader
+        except ImportError as e:
             raise Exception(
-                f"Couldn't import the required packages to load {content}, \
-                Do `pip install --upgrade 'embedchain[json]`"
+                f"Failed to install required packages: {e}, \
+                install them using `pip install --upgrade 'embedchain[json]`"
             )
 
-        loader = LLHBUBJSONLoader()
+        return LLHUBJSONLoader()
+
+    @staticmethod
+    def load_data(content):
+        """Load a json file. Each data point is a key value pair."""
 
-        if not isinstance(content, str):
-            print(f"Invaid content input. Provide the correct path to the json file saved locally in {content}")
+        loader = JSONLoader._get_llama_hub_loader()
 
         data = []
         data_content = []
 
-        # Load json data from various sources. TODO: add support for dictionary
+        content_url_str = content
+
+        # Load json data from various sources.
         if os.path.isfile(content):
             with open(content, "r", encoding="utf-8") as json_file:
                 json_data = json.load(json_file)
@@ -45,13 +49,17 @@ class JSONLoader(BaseLoader):
                     f"Loading data from the given url: {content} failed. \
                     Make sure the url is working."
                 )
+        elif is_valid_json_string(content):
+            json_data = content
+            content_url_str = hashlib.sha256((content).encode("utf-8")).hexdigest()
         else:
             raise ValueError(f"Invalid content to load json data from: {content}")
 
         docs = loader.load_data(json_data)
         for doc in docs:
             doc_content = clean_string(doc.text)
-            data.append({"content": doc_content, "meta_data": {"url": content}})
+            data.append({"content": doc_content, "meta_data": {"url": content_url_str}})
             data_content.append(doc_content)
-        doc_id = hashlib.sha256((content + ", ".join(data_content)).encode()).hexdigest()
+
+        doc_id = hashlib.sha256((content_url_str + ", ".join(data_content)).encode()).hexdigest()
         return {"doc_id": doc_id, "data": data}

+ 19 - 0
embedchain/utils.py

@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 import re
@@ -261,6 +262,24 @@ def detect_datatype(source: Any) -> DataType:
 
         # TODO: check if source is gmail query
 
+        # check if the source is valid json string
+        if is_valid_json_string(source):
+            logging.debug(f"Source of `{formatted_source}` detected as `json`.")
+            return DataType.JSON
+
         # Use text as final fallback.
         logging.debug(f"Source of `{formatted_source}` detected as `text`.")
         return DataType.TEXT
+
+
+# check if the source is valid json string
+def is_valid_json_string(source: str):
+    try:
+        _ = json.loads(source)
+        return True
+    except json.JSONDecodeError:
+        logging.error(
+            "Insert valid string format of JSON. \
+            Check the docs to see the supported formats - `https://docs.embedchain.ai/data-sources/json`"
+        )
+        return False

+ 59 - 55
tests/embedchain/test_embedchain.py

@@ -1,61 +1,65 @@
 import os
-import unittest
-from unittest.mock import patch
+
+import pytest
+from chromadb.api.models.Collection import Collection
 
 from embedchain import App
 from embedchain.config import AppConfig, ChromaDbConfig
+from embedchain.embedchain import EmbedChain
+from embedchain.llm.base import BaseLlm
+
+os.environ["OPENAI_API_KEY"] = "test-api-key"
+
+
+@pytest.fixture
+def app_instance():
+    config = AppConfig(log_level="DEBUG", collect_metrics=False)
+    return App(config)
+
+
+def test_whole_app(app_instance, mocker):
+    knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"
+
+    mocker.patch.object(EmbedChain, "add")
+    mocker.patch.object(EmbedChain, "retrieve_from_database")
+    mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
+    mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
+    mocker.patch.object(BaseLlm, "generate_prompt")
+
+    app_instance.add(knowledge, data_type="text")
+    app_instance.query("What text did I give you?")
+    app_instance.chat("What text did I give you?")
+
+    assert BaseLlm.generate_prompt.call_count == 2
+    app_instance.reset()
+
+
+def test_add_after_reset(app_instance, mocker):
+    config = AppConfig(log_level="DEBUG", collect_metrics=False)
+    chroma_config = {"allow_reset": True}
+
+    app_instance = App(config=config, db_config=ChromaDbConfig(**chroma_config))
+    app_instance.reset()
+
+    app_instance.db.client.heartbeat()
+
+    mocker.patch.object(Collection, "add")
+
+    app_instance.db.collection.add(
+        embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]],
+        metadatas=[
+            {"chapter": "3", "verse": "16"},
+            {"chapter": "3", "verse": "5"},
+            {"chapter": "29", "verse": "11"},
+        ],
+        ids=["id1", "id2", "id3"],
+    )
+
+    app_instance.reset()
+
 
+def test_add_with_incorrect_content(app_instance, mocker):
+    content = [{"foo": "bar"}]
 
-class TestChromaDbHostsLoglevel(unittest.TestCase):
-    os.environ["OPENAI_API_KEY"] = "test_key"
-
-    @patch("chromadb.api.models.Collection.Collection.add")
-    @patch("embedchain.embedchain.EmbedChain.retrieve_from_database")
-    @patch("embedchain.llm.base.BaseLlm.get_answer_from_llm")
-    @patch("embedchain.llm.base.BaseLlm.get_llm_model_answer")
-    def test_whole_app(
-        self,
-        _mock_add,
-        _mock_ec_retrieve_from_database,
-        _mock_get_answer_from_llm,
-        mock_ec_get_llm_model_answer,
-    ):
-        """
-        Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
-        """
-        config = AppConfig(log_level="DEBUG", collect_metrics=False)
-
-        app = App(config)
-
-        knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"
-
-        app.add(knowledge, data_type="text")
-
-        app.query("What text did I give you?")
-        app.chat("What text did I give you?")
-
-        self.assertEqual(mock_ec_get_llm_model_answer.call_args[1]["documents"], [knowledge])
-
-    def test_add_after_reset(self):
-        """
-        Test if the `App` instance is correctly reconstructed after a reset.
-        """
-        config = AppConfig(log_level="DEBUG", collect_metrics=False)
-        chroma_config = {"allow_reset": True}
-        app = App(config=config, db_config=ChromaDbConfig(**chroma_config))
-        app.reset()
-
-        # Make sure the client is still healthy
-        app.db.client.heartbeat()
-        # Make sure the collection exists, and can be added to
-        app.db.collection.add(
-            embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]],
-            metadatas=[
-                {"chapter": "3", "verse": "16"},
-                {"chapter": "3", "verse": "5"},
-                {"chapter": "29", "verse": "11"},
-            ],
-            ids=["id1", "id2", "id3"],
-        )
-
-        app.reset()
+    with pytest.raises(ValueError):
+        app_instance.add(content, data_type="json")

+ 30 - 3
tests/loaders/test_json.py

@@ -40,7 +40,7 @@ def test_load_data(mocker):
 def test_load_data_url(mocker):
     content = "https://example.com/posts.json"
 
-    mocker.patch("os.path.isfile", return_value=False)  # Mocking os.path.isfile to simulate a URL case
+    mocker.patch("os.path.isfile", return_value=False)
     mocker.patch(
         "llama_hub.jsondata.base.JSONDataReader.load_data",
         return_value=[Document(text="content1"), Document(text="content2")],
@@ -68,11 +68,11 @@ def test_load_data_url(mocker):
     assert result["doc_id"] == expected_doc_id
 
 
-def test_load_data_invalid_content(mocker):
+def test_load_data_invalid_string_content(mocker):
     mocker.patch("os.path.isfile", return_value=False)
     mocker.patch("requests.get")
 
-    content = "123"
+    content = "123: 345}"
 
     with pytest.raises(ValueError, match="Invalid content to load json data from"):
         JSONLoader.load_data(content)
@@ -89,3 +89,30 @@ def test_load_data_invalid_url(mocker):
 
     with pytest.raises(ValueError, match=f"Invalid content to load json data from: {content}"):
         JSONLoader.load_data(content)
+
+
+def test_load_data_from_json_string(mocker):
+    content = '{"foo": "bar"}'
+
+    content_url_str = hashlib.sha256((content).encode("utf-8")).hexdigest()
+
+    mocker.patch("os.path.isfile", return_value=False)
+    mocker.patch(
+        "llama_hub.jsondata.base.JSONDataReader.load_data",
+        return_value=[Document(text="content1"), Document(text="content2")],
+    )
+
+    result = JSONLoader.load_data(content)
+
+    assert "doc_id" in result
+    assert "data" in result
+
+    expected_data = [
+        {"content": "content1", "meta_data": {"url": content_url_str}},
+        {"content": "content2", "meta_data": {"url": content_url_str}},
+    ]
+
+    assert result["data"] == expected_data
+
+    expected_doc_id = hashlib.sha256((content_url_str + ", ".join(["content1", "content2"])).encode()).hexdigest()
+    assert result["doc_id"] == expected_doc_id