Bladeren bron

Embedchain json url support (#878)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 jaar geleden
bovenliggende
commit
5255a37c93
2 gewijzigde bestanden met toevoegingen van 84 en 8 verwijderingen
  1. 27 8
      embedchain/loaders/json.py
  2. 57 0
      tests/loaders/test_json.py

+ 27 - 8
embedchain/loaders/json.py

@@ -1,9 +1,14 @@
 import hashlib
 import json
 import os
+import re
+
+import requests
 
 from embedchain.loaders.base_loader import BaseLoader
 
+VALID_URL_PATTERN = "^https:\/\/[0-9A-z.]+.[0-9A-z.]+.[a-z]+\/.*\.json$"
+
 
 class JSONLoader(BaseLoader):
     @staticmethod
@@ -20,18 +25,32 @@ class JSONLoader(BaseLoader):
 
         loader = LLHBUBJSONLoader()
 
-        if not isinstance(content, str) and not os.path.isfile(content):
+        if not isinstance(content, str):
             print(f"Invaid content input. Provide the correct path to the json file saved locally in {content}")
 
         data = []
         data_content = []
 
-        with open(content, "r") as json_file:
-            json_data = json.load(json_file)
-            docs = loader.load_data(json_data)
-            for doc in docs:
-                doc_content = doc.text
-                data.append({"content": doc_content, "meta_data": {"url": content}})
-                data_content.append(doc_content)
+        # Load json data from various sources. TODO: add support for dictionary
+        if os.path.isfile(content):
+            with open(content, "r") as json_file:
+                json_data = json.load(json_file)
+        elif re.match(VALID_URL_PATTERN, content):
+            response = requests.get(content)
+            if response.status_code == 200:
+                json_data = response.json()
+            else:
+                raise ValueError(
+                    f"Loading data from the given url: {content} failed. \
+                    Make sure the url is working."
+                )
+        else:
+            raise ValueError(f"Invalid content to load json data from: {content}")
+
+        docs = loader.load_data(json_data)
+        for doc in docs:
+            doc_content = doc.text
+            data.append({"content": doc_content, "meta_data": {"url": content}})
+            data_content.append(doc_content)
         doc_id = hashlib.sha256((content + ", ".join(data_content)).encode()).hexdigest()
         return {"doc_id": doc_id, "data": data}

+ 57 - 0
tests/loaders/test_json.py

@@ -1,5 +1,8 @@
 import hashlib
 
+import pytest
+from llama_index.readers.schema.base import Document
+
 from embedchain.loaders.json import JSONLoader
 
 
@@ -32,3 +35,57 @@ def test_load_data(mocker):
 
     expected_doc_id = hashlib.sha256((content + ", ".join(["content1", "content2"])).encode()).hexdigest()
     assert result["doc_id"] == expected_doc_id
+
+
+def test_load_data_url(mocker):
+    content = "https://example.com/posts.json"
+
+    mocker.patch("os.path.isfile", return_value=False)  # Mocking os.path.isfile to simulate a URL case
+    mocker.patch(
+        "llama_hub.jsondata.base.JSONDataReader.load_data",
+        return_value=[Document(text="content1"), Document(text="content2")],
+    )
+
+    mock_response = mocker.Mock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = {"document1": "content1", "document2": "content2"}
+
+    mocker.patch("requests.get", return_value=mock_response)
+
+    result = JSONLoader.load_data(content)
+
+    assert "doc_id" in result
+    assert "data" in result
+
+    expected_data = [
+        {"content": "content1", "meta_data": {"url": content}},
+        {"content": "content2", "meta_data": {"url": content}},
+    ]
+
+    assert result["data"] == expected_data
+
+    expected_doc_id = hashlib.sha256((content + ", ".join(["content1", "content2"])).encode()).hexdigest()
+    assert result["doc_id"] == expected_doc_id
+
+
+def test_load_data_invalid_content(mocker):
+    mocker.patch("os.path.isfile", return_value=False)
+    mocker.patch("requests.get")
+
+    content = "123"
+
+    with pytest.raises(ValueError, match="Invalid content to load json data from"):
+        JSONLoader.load_data(content)
+
+
+def test_load_data_invalid_url(mocker):
+    mocker.patch("os.path.isfile", return_value=False)
+
+    mock_response = mocker.Mock()
+    mock_response.status_code = 404
+    mocker.patch("requests.get", return_value=mock_response)
+
+    content = "http://invalid-url.com/"
+
+    with pytest.raises(ValueError, match=f"Invalid content to load json data from: {content}"):
+        JSONLoader.load_data(content)