Browse Source

[feat]: Add openapi spec data loader (#818)

Deven Patel 1 year ago
parent
commit
797bb567c6

+ 1 - 0
README.md

@@ -47,6 +47,7 @@ Embedchain empowers you to create ChatGPT like apps, on your own dynamic dataset
 * Doc file
 * JSON file
 * Code documentation website loader
+* OpenAPI specs
 * Notion
 * Unstructured file loader and many more
 

+ 36 - 0
docs/data-sources/json.mdx

@@ -0,0 +1,36 @@
+---
+title: '📃 JSON'
+---
+
+To add any json file, use the data_type as `json`. `json` allows remote urls and conventional file paths. Headers are included for each line, so if you have an `age` column, `18` will be added as `age: 18`. Eg:
+
+```python
+import os
+
+from embedchain.apps.app import App
+
+os.environ["OPENAI_API_KEY"] = "openai_api_key"
+
+app = App()
+
+response = app.query("What is the net worth of Elon Musk as of October 2023?")
+
+print(response)
+"I'm sorry, but I don't have access to real-time information or future predictions. Therefore, I don't know the net worth of Elon Musk as of October 2023."
+
+source_id = app.add("temp.json")
+
+response = app.query("What is the net worth of Elon Musk as of October 2023?")
+
+print(response)
+"As of October 2023, Elon Musk's net worth is $255.2 billion."
+```
+
+```temp.json
+{
+    "question": "What is your net worth, Elon Musk?",
+    "answer": "As of October 2023, Elon Musk's net worth is $255.2 billion, making him one of the wealthiest individuals in the world."
+}
+```
+
+

+ 23 - 0
docs/data-sources/openapi.mdx

@@ -0,0 +1,23 @@
+---
+title: 🙌 OpenAPI
+---
+
+To add any OpenAPI spec yaml file (currently the json file will be detected as JSON data type), use the data_type as 'openapi'. 'openapi' allows remote urls and conventional file paths. Headers are included for each line, so if you have an `age` column, `18` will be added as `age: 18`. Eg:
+
+```python
+from embedchain.apps.app import App
+import os
+
+os.environ["OPENAI_API_KEY"] = "sk-xxx"
+
+app = App()
+
+app.add("https://github.com/openai/openai-openapi/blob/master/openapi.yaml", data_type="openapi")
+# Or add using the local file path
+# app.add("configs/openai_openapi.yaml", data_type="openapi")
+
+response = app.query("What can OpenAI API endpoint do? Can you list the things it can learn from?")
+# Answer: The OpenAI API endpoint allows users to interact with OpenAI's models and perform various tasks such as generating text, answering questions, summarizing documents, translating languages, and more. The specific capabilities and tasks that the API can learn from may vary depending on the models and features provided by OpenAI. For more detailed information, it is recommended to refer to the OpenAI API documentation at https://platform.openai.com/docs/api-reference.
+```
+
+NOTE: The yaml file added to the App must have the required OpenAPI fields otherwise the adding OpenAPI spec will fail. Please refer to [OpenAPI Spec Doc](https://spec.openapis.org/oas/v3.1.0)

+ 2 - 0
docs/data-sources/overview.mdx

@@ -6,6 +6,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
 
 <CardGroup cols={4}>
   <Card title="📊 csv" href="/data-sources/csv"></Card>
+  <Card title="📃 JSON" href="/data-sources/json"></Card>
   <Card title="📚🌐 docs site" href="/data-sources/docs-site"></Card>
   <Card title="📄 docx" href="/data-sources/docx"></Card>
   <Card title="📝 mdx" href="/data-sources/mdx"></Card>
@@ -16,6 +17,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
   <Card title="📝 text" href="/data-sources/text"></Card>
   <Card title="🌐📄 web page" href="/data-sources/web-page"></Card>
   <Card title="🧾 xml" href="/data-sources/xml"></Card>
+  <Card title="🙌 OpenApi" href="/data-sources/openapi"></Card>
   <Card title="🎥📺 youtube video" href="/data-sources/youtube-video"></Card>
 </CardGroup>
 

+ 2 - 0
docs/mint.json

@@ -46,6 +46,7 @@
           "group": "Supported data sources",
           "pages": [
             "data-sources/csv",
+            "data-sources/json",
             "data-sources/docs-site",
             "data-sources/docx",
             "data-sources/mdx",
@@ -55,6 +56,7 @@
             "data-sources/sitemap",
             "data-sources/text",
             "data-sources/web-page",
+            "data-sources/openapi",
             "data-sources/youtube-video"
           ]
         },

+ 18 - 0
embedchain/chunkers/openapi.py

@@ -0,0 +1,18 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+
+
+class OpenAPIChunker(BaseChunker):
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)

+ 4 - 0
embedchain/data_formatter/data_formatter.py

@@ -5,6 +5,7 @@ from embedchain.chunkers.images import ImagesChunker
 from embedchain.chunkers.json import JSONChunker
 from embedchain.chunkers.mdx import MdxChunker
 from embedchain.chunkers.notion import NotionChunker
+from embedchain.chunkers.openapi import OpenAPIChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
 from embedchain.chunkers.qna_pair import QnaPairChunker
 from embedchain.chunkers.sitemap import SitemapChunker
@@ -26,6 +27,7 @@ from embedchain.loaders.json import JSONLoader
 from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
 from embedchain.loaders.local_text import LocalTextLoader
 from embedchain.loaders.mdx import MdxLoader
+from embedchain.loaders.openapi import OpenAPILoader
 from embedchain.loaders.pdf_file import PdfFileLoader
 from embedchain.loaders.sitemap import SitemapLoader
 from embedchain.loaders.unstructured_file import UnstructuredLoader
@@ -81,6 +83,7 @@ class DataFormatter(JSONSerializable):
             DataType.IMAGES: ImagesLoader,
             DataType.UNSTRUCTURED: UnstructuredLoader,
             DataType.JSON: JSONLoader,
+            DataType.OPENAPI: OpenAPILoader,
         }
         lazy_loaders = {DataType.NOTION}
         if data_type in loaders:
@@ -124,6 +127,7 @@ class DataFormatter(JSONSerializable):
             DataType.XML: XmlChunker,
             DataType.UNSTRUCTURED: UnstructuredFileChunker,
             DataType.JSON: JSONChunker,
+            DataType.OPENAPI: OpenAPIChunker,
         }
         if data_type in chunker_classes:
             chunker_class: type = chunker_classes[data_type]

+ 42 - 0
embedchain/loaders/openapi.py

@@ -0,0 +1,42 @@
+import hashlib
+from io import StringIO
+from urllib.parse import urlparse
+
+import requests
+import yaml
+
+from embedchain.loaders.base_loader import BaseLoader
+
+
+class OpenAPILoader(BaseLoader):
+    @staticmethod
+    def _get_file_content(content):
+        url = urlparse(content)
+        if all([url.scheme, url.netloc]) and url.scheme not in ["file", "http", "https"]:
+            raise ValueError("Not a valid URL.")
+
+        if url.scheme in ["http", "https"]:
+            response = requests.get(content)
+            response.raise_for_status()
+            return StringIO(response.text)
+        elif url.scheme == "file":
+            path = url.path
+            return open(path)
+        else:
+            return open(content)
+
+    @staticmethod
+    def load_data(content):
+        """Load yaml file of openapi. Each pair is a document."""
+        data = []
+        file_path = content
+        data_content = []
+        with OpenAPILoader._get_file_content(content=content) as file:
+            yaml_data = yaml.load(file, Loader=yaml.Loader)
+            for i, (key, value) in enumerate(yaml_data.items()):
+                string_data = f"{key}: {value}"
+                meta_data = {"url": file_path, "row": i + 1}
+                data.append({"content": string_data, "meta_data": meta_data})
+                data_content.append(string_data)
+        doc_id = hashlib.sha256((content + ", ".join(data_content)).encode()).hexdigest()
+        return {"doc_id": doc_id, "data": data}

+ 2 - 0
embedchain/models/data_type.py

@@ -27,6 +27,7 @@ class IndirectDataType(Enum):
     IMAGES = "images"
     UNSTRUCTURED = "unstructured"
     JSON = "json"
+    OPENAPI = "openapi"
 
 
 class SpecialDataType(Enum):
@@ -53,3 +54,4 @@ class DataType(Enum):
     IMAGES = IndirectDataType.IMAGES.value
     UNSTRUCTURED = IndirectDataType.UNSTRUCTURED.value
     JSON = IndirectDataType.JSON.value
+    OPENAPI = IndirectDataType.OPENAPI.value

+ 48 - 0
embedchain/utils.py

@@ -115,6 +115,13 @@ def detect_datatype(source: Any) -> DataType:
     """
     from urllib.parse import urlparse
 
+    import requests
+    import yaml
+
+    def is_openapi_yaml(yaml_content):
+        # currently the following two fields are required in openapi spec yaml config
+        return "openapi" in yaml_content and "info" in yaml_content
+
     try:
         if not isinstance(source, str):
             raise ValueError("Source is not a string and thus cannot be a URL.")
@@ -155,6 +162,31 @@ def detect_datatype(source: Any) -> DataType:
             logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
             return DataType.DOCX
 
+        if url.path.endswith(".yaml"):
+            try:
+                response = requests.get(source)
+                response.raise_for_status()
+                try:
+                    yaml_content = yaml.safe_load(response.text)
+                except yaml.YAMLError as exc:
+                    logging.error(f"Error parsing YAML: {exc}")
+                    raise TypeError(f"Not a valid data type. Error loading YAML: {exc}")
+
+                if is_openapi_yaml(yaml_content):
+                    logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
+                    return DataType.OPENAPI
+                else:
+                    logging.error(
+                        f"Source of `{formatted_source}` does not contain all the required \
+                        fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
+                    )
+                    raise TypeError(
+                        "Not a valid data type. Check 'https://spec.openapis.org/oas/v3.1.0', \
+                        make sure you have all the required fields in YAML config data"
+                    )
+            except requests.exceptions.RequestException as e:
+                logging.error(f"Error fetching URL {formatted_source}: {e}")
+
         if url.path.endswith(".json"):
             logging.debug(f"Source of `{formatted_source}` detected as `json_file`.")
             return DataType.JSON
@@ -198,6 +230,22 @@ def detect_datatype(source: Any) -> DataType:
             logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
             return DataType.XML
 
+        if source.endswith(".yaml"):
+            with open(source, "r") as file:
+                yaml_content = yaml.safe_load(file)
+                if is_openapi_yaml(yaml_content):
+                    logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
+                    return DataType.OPENAPI
+                else:
+                    logging.error(
+                        f"Source of `{formatted_source}` does not contain all the required \
+                                  fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
+                    )
+                    raise ValueError(
+                        "Invalid YAML data. Check 'https://spec.openapis.org/oas/v3.1.0', \
+                        make sure to add all the required params"
+                    )
+
         if source.endswith(".json"):
             logging.debug(f"Source of `{formatted_source}` detected as `json`.")
             return DataType.JSON

+ 2 - 0
tests/chunkers/test_chunkers.py

@@ -3,6 +3,7 @@ from embedchain.chunkers.docx_file import DocxFileChunker
 from embedchain.chunkers.json import JSONChunker
 from embedchain.chunkers.mdx import MdxChunker
 from embedchain.chunkers.notion import NotionChunker
+from embedchain.chunkers.openapi import OpenAPIChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
 from embedchain.chunkers.qna_pair import QnaPairChunker
 from embedchain.chunkers.sitemap import SitemapChunker
@@ -29,6 +30,7 @@ chunker_common_config = {
     XmlChunker: {"chunk_size": 500, "chunk_overlap": 50, "length_function": len},
     YoutubeVideoChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
     JSONChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
+    OpenAPIChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
 }
 
 

+ 6 - 0
tests/embedchain/test_utils.py

@@ -39,6 +39,12 @@ class TestApp(unittest.TestCase):
     def test_detect_datatype_local_docx(self):
         self.assertEqual(detect_datatype("file:///home/user/document.docx"), DataType.DOCX)
 
+    def test_detect_data_type_json(self):
+        self.assertEqual(detect_datatype("https://www.example.com/data.json"), DataType.JSON)
+
+    def test_detect_data_type_local_json(self):
+        self.assertEqual(detect_datatype("file:///home/user/data.json"), DataType.JSON)
+
     @patch("os.path.isfile")
     def test_detect_datatype_regular_filesystem_docx(self, mock_isfile):
         with tempfile.NamedTemporaryFile(suffix=".docx", delete=True) as tmp:

+ 26 - 0
tests/loaders/test_openapi.py

@@ -0,0 +1,26 @@
+import pytest
+
+from embedchain.loaders.openapi import OpenAPILoader
+
+
+@pytest.fixture
+def openapi_loader():
+    return OpenAPILoader()
+
+
+def test_load_data(openapi_loader, mocker):
+    mocker.patch("builtins.open", mocker.mock_open(read_data="key1: value1\nkey2: value2"))
+
+    mocker.patch("hashlib.sha256", return_value=mocker.Mock(hexdigest=lambda: "mock_hash"))
+
+    file_path = "configs/openai_openapi.yaml"
+    result = openapi_loader.load_data(file_path)
+
+    expected_doc_id = "mock_hash"
+    expected_data = [
+        {"content": "key1: value1", "meta_data": {"url": file_path, "row": 1}},
+        {"content": "key2: value2", "meta_data": {"url": file_path, "row": 2}},
+    ]
+
+    assert result["doc_id"] == expected_doc_id
+    assert result["data"] == expected_data