浏览代码

Feat/serialize deserialize (#508)

Co-authored-by: Taranjeet Singh <reachtotj@gmail.com>
cachho 1 年之前
父节点
当前提交
0d4ad07d7b
共有 42 个文件被更改,包括 345 次插入和 8 次删除
  1. 2 0
      embedchain/apps/App.py
  2. 2 0
      embedchain/apps/CustomApp.py
  3. 2 0
      embedchain/apps/OpenSourceApp.py
  4. 4 0
      embedchain/apps/PersonApp.py
  5. 4 1
      embedchain/bots/base.py
  6. 2 0
      embedchain/bots/poe.py
  7. 3 0
      embedchain/bots/whatsapp.py
  8. 2 1
      embedchain/chunkers/base_chunker.py
  9. 2 0
      embedchain/chunkers/docs_site.py
  10. 2 0
      embedchain/chunkers/docx_file.py
  11. 2 0
      embedchain/chunkers/notion.py
  12. 2 0
      embedchain/chunkers/pdf_file.py
  13. 2 0
      embedchain/chunkers/qna_pair.py
  14. 2 0
      embedchain/chunkers/text.py
  15. 2 0
      embedchain/chunkers/web_page.py
  16. 2 0
      embedchain/chunkers/youtube_video.py
  17. 4 0
      embedchain/config/AddConfig.py
  18. 4 1
      embedchain/config/BaseConfig.py
  19. 2 0
      embedchain/config/ChatConfig.py
  20. 2 0
      embedchain/config/QueryConfig.py
  21. 3 0
      embedchain/config/apps/AppConfig.py
  22. 2 1
      embedchain/config/apps/BaseAppConfig.py
  23. 2 0
      embedchain/config/apps/CustomAppConfig.py
  24. 3 0
      embedchain/config/apps/OpenSourceAppConfig.py
  25. 2 0
      embedchain/config/vectordbs/ElasticsearchDBConfig.py
  26. 2 1
      embedchain/data_formatter/data_formatter.py
  27. 2 1
      embedchain/embedchain.py
  28. 180 0
      embedchain/helper_classes/json_serializable.py
  29. 4 1
      embedchain/loaders/base_loader.py
  30. 2 0
      embedchain/loaders/docs_site_loader.py
  31. 2 0
      embedchain/loaders/docx_file.py
  32. 2 0
      embedchain/loaders/local_qna_pair.py
  33. 2 0
      embedchain/loaders/local_text.py
  34. 2 0
      embedchain/loaders/notion.py
  35. 2 0
      embedchain/loaders/pdf_file.py
  36. 2 0
      embedchain/loaders/sitemap.py
  37. 2 0
      embedchain/loaders/web_page.py
  38. 2 0
      embedchain/loaders/youtube_video.py
  39. 4 1
      embedchain/vectordb/base_vector_db.py
  40. 2 0
      embedchain/vectordb/chroma_db.py
  41. 2 0
      embedchain/vectordb/elasticsearch_db.py
  42. 70 0
      tests/helper_classes/test_json_serializable.py

+ 2 - 0
embedchain/apps/App.py

@@ -4,8 +4,10 @@ import openai
 
 from embedchain.config import AppConfig, ChatConfig
 from embedchain.embedchain import EmbedChain
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 
+@register_deserializable
 class App(EmbedChain):
     """
     The EmbedChain app.

+ 2 - 0
embedchain/apps/CustomApp.py

@@ -5,9 +5,11 @@ from langchain.schema import BaseMessage
 
 from embedchain.config import ChatConfig, CustomAppConfig
 from embedchain.embedchain import EmbedChain
+from embedchain.helper_classes.json_serializable import register_deserializable
 from embedchain.models import Providers
 
 
+@register_deserializable
 class CustomApp(EmbedChain):
     """
     The custom EmbedChain app.

+ 2 - 0
embedchain/apps/OpenSourceApp.py

@@ -3,10 +3,12 @@ from typing import Iterable, Union, Optional
 
 from embedchain.config import ChatConfig, OpenSourceAppConfig
 from embedchain.embedchain import EmbedChain
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 gpt4all_model = None
 
 
+@register_deserializable
 class OpenSourceApp(EmbedChain):
     """
     The OpenSource app.

+ 4 - 0
embedchain/apps/PersonApp.py

@@ -6,8 +6,10 @@ from embedchain.config import ChatConfig, QueryConfig
 from embedchain.config.apps.BaseAppConfig import BaseAppConfig
 from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
                                            DEFAULT_PROMPT_WITH_HISTORY)
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 
+@register_deserializable
 class EmbedChainPersonApp:
     """
     Base class to create a person bot.
@@ -50,6 +52,7 @@ class EmbedChainPersonApp:
         return config
 
 
+@register_deserializable
 class PersonApp(EmbedChainPersonApp, App):
     """
     The Person app.
@@ -65,6 +68,7 @@ class PersonApp(EmbedChainPersonApp, App):
         return super().chat(input_query, config, dry_run)
 
 
+@register_deserializable
 class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
     """
     The Person app.

+ 4 - 1
embedchain/bots/base.py

@@ -1,9 +1,12 @@
 from embedchain import CustomApp
 from embedchain.config import AddConfig, CustomAppConfig, QueryConfig
+from embedchain.helper_classes.json_serializable import (
+    JSONSerializable, register_deserializable)
 from embedchain.models import EmbeddingFunctions, Providers
 
 
-class BaseBot:
+@register_deserializable
+class BaseBot(JSONSerializable):
     def __init__(self, app_config=None):
         if app_config is None:
             app_config = CustomAppConfig(embedding_fn=EmbeddingFunctions.OPENAI, provider=Providers.OPENAI)

+ 2 - 0
embedchain/bots/poe.py

@@ -6,10 +6,12 @@ from typing import List, Optional
 from fastapi_poe import PoeBot, run
 
 from embedchain.config import QueryConfig
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 from .base import BaseBot
 
 
+@register_deserializable
 class EcPoeBot(BaseBot, PoeBot):
     def __init__(self):
         self.history_length = 5

+ 3 - 0
embedchain/bots/whatsapp.py

@@ -6,9 +6,12 @@ import sys
 from flask import Flask, request
 from twilio.twiml.messaging_response import MessagingResponse
 
+from embedchain.helper_classes.json_serializable import register_deserializable
+
 from .base import BaseBot
 
 
+@register_deserializable
 class WhatsAppBot(BaseBot):
     def __init__(self):
         super().__init__()

+ 2 - 1
embedchain/chunkers/base_chunker.py

@@ -1,9 +1,10 @@
 import hashlib
 
+from embedchain.helper_classes.json_serializable import JSONSerializable
 from embedchain.models.data_type import DataType
 
 
-class BaseChunker:
+class BaseChunker(JSONSerializable):
     def __init__(self, text_splitter):
         """Initialize the chunker."""
         self.text_splitter = text_splitter

+ 2 - 0
embedchain/chunkers/docs_site.py

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.AddConfig import ChunkerConfig
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 
+@register_deserializable
 class DocsSiteChunker(BaseChunker):
     """Chunker for code docs site."""
 

+ 2 - 0
embedchain/chunkers/docx_file.py

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.AddConfig import ChunkerConfig
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 
+@register_deserializable
 class DocxFileChunker(BaseChunker):
     """Chunker for .docx file."""
 

+ 2 - 0
embedchain/chunkers/notion.py

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.AddConfig import ChunkerConfig
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 
+@register_deserializable
 class NotionChunker(BaseChunker):
     """Chunker for notion."""
 

+ 2 - 0
embedchain/chunkers/pdf_file.py

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.AddConfig import ChunkerConfig
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 
+@register_deserializable
 class PdfFileChunker(BaseChunker):
     """Chunker for PDF file."""
 

+ 2 - 0
embedchain/chunkers/qna_pair.py

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.AddConfig import ChunkerConfig
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 
+@register_deserializable
 class QnaPairChunker(BaseChunker):
     """Chunker for QnA pair."""
 

+ 2 - 0
embedchain/chunkers/text.py

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.AddConfig import ChunkerConfig
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 
+@register_deserializable
 class TextChunker(BaseChunker):
     """Chunker for text."""
 

+ 2 - 0
embedchain/chunkers/web_page.py

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.AddConfig import ChunkerConfig
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 
+@register_deserializable
 class WebPageChunker(BaseChunker):
     """Chunker for web page."""
 

+ 2 - 0
embedchain/chunkers/youtube_video.py

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.AddConfig import ChunkerConfig
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 
+@register_deserializable
 class YoutubeVideoChunker(BaseChunker):
     """Chunker for Youtube video."""
 

+ 4 - 0
embedchain/config/AddConfig.py

@@ -1,8 +1,10 @@
 from typing import Callable, Optional
 
 from embedchain.config.BaseConfig import BaseConfig
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 
+@register_deserializable
 class ChunkerConfig(BaseConfig):
     """
     Config for the chunker used in `add` method
@@ -19,6 +21,7 @@ class ChunkerConfig(BaseConfig):
         self.length_function = length_function if length_function else len
 
 
+@register_deserializable
 class LoaderConfig(BaseConfig):
     """
     Config for the chunker used in `add` method
@@ -28,6 +31,7 @@ class LoaderConfig(BaseConfig):
         pass
 
 
+@register_deserializable
 class AddConfig(BaseConfig):
     """
     Config for the `add` method.

+ 4 - 1
embedchain/config/BaseConfig.py

@@ -1,4 +1,7 @@
-class BaseConfig:
+from embedchain.helper_classes.json_serializable import JSONSerializable
+
+
+class BaseConfig(JSONSerializable):
     """
     Base config.
     """

+ 2 - 0
embedchain/config/ChatConfig.py

@@ -2,6 +2,7 @@ from string import Template
 from typing import Optional
 
 from embedchain.config.QueryConfig import QueryConfig
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 DEFAULT_PROMPT = """
   You are a chatbot having a conversation with a human. You are given chat
@@ -20,6 +21,7 @@ DEFAULT_PROMPT = """
 DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
 
 
+@register_deserializable
 class ChatConfig(QueryConfig):
     """
     Config for the `chat` method, inherits from `QueryConfig`.

+ 2 - 0
embedchain/config/QueryConfig.py

@@ -3,6 +3,7 @@ from string import Template
 from typing import Optional
 
 from embedchain.config.BaseConfig import BaseConfig
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 DEFAULT_PROMPT = """
   Use the following pieces of context to answer the query at the end.
@@ -48,6 +49,7 @@ context_re = re.compile(r"\$\{*context\}*")
 history_re = re.compile(r"\$\{*history\}*")
 
 
+@register_deserializable
 class QueryConfig(BaseConfig):
     """
     Config for the `query` method.

+ 3 - 0
embedchain/config/apps/AppConfig.py

@@ -9,9 +9,12 @@ except RuntimeError:
     use_pysqlite3()
     from chromadb.utils import embedding_functions
 
+from embedchain.helper_classes.json_serializable import register_deserializable
+
 from .BaseAppConfig import BaseAppConfig
 
 
+@register_deserializable
 class AppConfig(BaseAppConfig):
     """
     Config to initialize an embedchain custom `App` instance, with extra config options.

+ 2 - 1
embedchain/config/apps/BaseAppConfig.py

@@ -2,10 +2,11 @@ import logging
 
 from embedchain.config.BaseConfig import BaseConfig
 from embedchain.config.vectordbs import ElasticsearchDBConfig
+from embedchain.helper_classes.json_serializable import JSONSerializable
 from embedchain.models import VectorDatabases, VectorDimensions
 
 
-class BaseAppConfig(BaseConfig):
+class BaseAppConfig(BaseConfig, JSONSerializable):
     """
     Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
     """

+ 2 - 0
embedchain/config/apps/CustomAppConfig.py

@@ -4,6 +4,7 @@ from chromadb.api.types import Documents, Embeddings
 from dotenv import load_dotenv
 
 from embedchain.config.vectordbs import ElasticsearchDBConfig
+from embedchain.helper_classes.json_serializable import register_deserializable
 from embedchain.models import (EmbeddingFunctions, Providers, VectorDatabases,
                                VectorDimensions)
 
@@ -12,6 +13,7 @@ from .BaseAppConfig import BaseAppConfig
 load_dotenv()
 
 
+@register_deserializable
 class CustomAppConfig(BaseAppConfig):
     """
     Config to initialize an embedchain custom `App` instance, with extra config options.

+ 3 - 0
embedchain/config/apps/OpenSourceAppConfig.py

@@ -2,9 +2,12 @@ from typing import Optional
 
 from chromadb.utils import embedding_functions
 
+from embedchain.helper_classes.json_serializable import register_deserializable
+
 from .BaseAppConfig import BaseAppConfig
 
 
+@register_deserializable
 class OpenSourceAppConfig(BaseAppConfig):
     """
     Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options.

+ 2 - 0
embedchain/config/vectordbs/ElasticsearchDBConfig.py

@@ -1,8 +1,10 @@
 from typing import Dict, List, Union
 
 from embedchain.config.BaseConfig import BaseConfig
+from embedchain.helper_classes.json_serializable import register_deserializable
 
 
+@register_deserializable
 class ElasticsearchDBConfig(BaseConfig):
     """
     Config to initialize an elasticsearch client.

+ 2 - 1
embedchain/data_formatter/data_formatter.py

@@ -7,6 +7,7 @@ from embedchain.chunkers.text import TextChunker
 from embedchain.chunkers.web_page import WebPageChunker
 from embedchain.chunkers.youtube_video import YoutubeVideoChunker
 from embedchain.config import AddConfig
+from embedchain.helper_classes.json_serializable import JSONSerializable
 from embedchain.loaders.docs_site_loader import DocsSiteLoader
 from embedchain.loaders.docx_file import DocxFileLoader
 from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
@@ -18,7 +19,7 @@ from embedchain.loaders.youtube_video import YoutubeVideoLoader
 from embedchain.models.data_type import DataType
 
 
-class DataFormatter:
+class DataFormatter(JSONSerializable):
     """
     DataFormatter is an internal utility class which abstracts the mapping for
     loaders and chunkers to the data_type entered by the user in their

+ 2 - 1
embedchain/embedchain.py

@@ -19,6 +19,7 @@ from embedchain.config import AddConfig, ChatConfig, QueryConfig
 from embedchain.config.apps.BaseAppConfig import BaseAppConfig
 from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.data_formatter import DataFormatter
+from embedchain.helper_classes.json_serializable import JSONSerializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.models.data_type import DataType
 from embedchain.utils import detect_datatype
@@ -32,7 +33,7 @@ CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
 CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
 
 
-class EmbedChain:
+class EmbedChain(JSONSerializable):
     def __init__(self, config: BaseAppConfig, system_prompt: Optional[str] = None):
         """
         Initializes the EmbedChain instance, sets up a vector DB client and

+ 180 - 0
embedchain/helper_classes/json_serializable.py

@@ -0,0 +1,180 @@
+import json
+import logging
+from typing import Any, Dict, Type, TypeVar, Union
+
+T = TypeVar("T", bound="JSONSerializable")
+
+# NOTE: Through inheritance, all of our classes should be children of JSONSerializable. (highest level)
+# NOTE: The @register_deserializable decorator should be added to all user facing child classes. (lowest level)
+
+
+def register_deserializable(cls: Type[T]) -> Type[T]:
+    """
+    Class decorator that whitelists `cls` for deserialization.
+
+    The class is added to the registry kept on JSONSerializable and is then
+    handed back unchanged, so decorating a class does not alter the class
+    itself. JSONSerializable's decoder only rebuilds objects whose class
+    name matches a registered class; other class names are rejected.
+
+    Example:
+        @register_deserializable
+        class ChildClass(JSONSerializable):
+            def __init__(self, ...):
+                # initialization logic
+
+    Args:
+        cls (Type): The class to whitelist.
+
+    Returns:
+        Type: `cls` itself, after it has been registered.
+    """
+    registry_owner = JSONSerializable
+    registry_owner.register_class_as_deserializable(cls)
+    return cls
+
+
+class JSONSerializable:
+    """
+    Base class that adds JSON serialization and deserialization to an object.
+
+    An instance is encoded from its `__dict__`; attributes whose values cannot
+    be represented in JSON are dropped from the output. Decoding only rebuilds
+    objects of classes that were registered through
+    `register_class_as_deserializable`. The serialized form can also be saved
+    to a file and loaded back.
+    """
+
+    _deserializable_classes = set()  # Contains classes that are whitelisted for deserialization.
+
+    def serialize(self) -> str:
+        """
+        Serialize the object to a JSON-formatted string.
+
+        Any error raised while encoding is logged and swallowed.
+
+        Returns:
+            str: A JSON string representation of the object, or "{}" if
+                serialization failed.
+        """
+        try:
+            return json.dumps(self, default=self._auto_encoder, ensure_ascii=False)
+        except Exception as e:
+            logging.error(f"Serialization error: {e}")
+            return "{}"
+
+    @classmethod
+    def deserialize(cls, json_str: str) -> Any:
+        """
+        Deserialize a JSON-formatted string to an object.
+        If decoding fails for any reason (malformed JSON, a class that is not
+        registered, ...), the error is logged and a default instance created
+        with `cls()` is returned instead.
+        Note: This *returns* an instance, it's not automatically loaded on the calling class.
+
+        Example:
+            app = App.deserialize(json_str)
+
+        Args:
+            json_str (str): A JSON string representation of an object.
+
+        Returns:
+            Object: The deserialized object.
+        """
+        try:
+            return json.loads(json_str, object_hook=cls._auto_decoder)
+        except Exception as e:
+            logging.error(f"Deserialization error: {e}")
+            # Return a default instance in case of failure
+            return cls()
+
+    @staticmethod
+    def _auto_encoder(obj: Any) -> Union[Dict[str, Any], None]:
+        """
+        Automatically encode an object for JSON serialization.
+
+        Args:
+            obj (Object): The object to be encoded.
+
+        Returns:
+            dict: A copy of the object's `__dict__` restricted to JSON
+                compatible values, plus a "__class__" entry holding the
+                object's class name.
+
+        Raises:
+            TypeError: If `obj` has no `__dict__` to encode.
+        """
+        if hasattr(obj, "__dict__"):
+            dct = obj.__dict__.copy()
+            # Iterate over a snapshot of the items because keys may be deleted below.
+            for key, value in list(dct.items()):
+                try:
+                    # Recursive: If the value is an instance of a subclass of JSONSerializable,
+                    # serialize it using the JSONSerializable serialize method.
+                    if isinstance(value, JSONSerializable):
+                        serialized_value = value.serialize()
+                        # Parse the string back so the value is nested as a JSON
+                        # object instead of being embedded as an escaped string.
+                        dct[key] = json.loads(serialized_value)
+                    else:
+                        json.dumps(value)  # Try to serialize the value.
+                except TypeError:
+                    del dct[key]  # If it fails, remove the key-value pair from the dictionary.
+
+            dct["__class__"] = obj.__class__.__name__
+            return dct
+        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
+
+    @classmethod
+    def _auto_decoder(cls, dct: Dict[str, Any]) -> Any:
+        """
+        Automatically decode a dictionary to an object during JSON deserialization.
+
+        The target object is created with `__new__`, so its `__init__` is not
+        run; the decoded attributes are set on it directly.
+
+        Args:
+            dct (dict): The dictionary representation of an object.
+
+        Returns:
+            Object: The decoded object or the original dictionary if decoding is not possible.
+
+        Raises:
+            AttributeError: If no registry of allowed classes is available.
+            KeyError: If the class named in the dictionary is not registered.
+        """
+        class_name = dct.pop("__class__", None)
+        if not class_name:
+            # Plain JSON object without a class marker: leave it as a dict.
+            return dct
+        if not hasattr(cls, "_deserializable_classes"):  # Additional safety check
+            raise AttributeError(f"`{class_name}` has no registry of allowed deserializations.")
+        # Single pass over the registry: find the whitelisted class with this name.
+        target_class = next((cl for cl in cls._deserializable_classes if cl.__name__ == class_name), None)
+        if target_class is None:
+            raise KeyError(f"Deserialization of class `{class_name}` is not allowed.")
+        obj = target_class.__new__(target_class)
+        for key, value in dct.items():
+            # Only a missing (null) value falls back to the class-level default.
+            # Falsy values such as False, 0 or "" are real data and must be kept.
+            if value is None:
+                value = getattr(target_class, key, None)
+            setattr(obj, key, value)
+        return obj
+
+    def save_to_file(self, filename: str) -> None:
+        """
+        Save the serialized object to a file.
+
+        Args:
+            filename (str): The path to the file where the object should be saved.
+        """
+        with open(filename, "w", encoding="utf-8") as f:
+            f.write(self.serialize())
+
+    @classmethod
+    def load_from_file(cls, filename: str) -> Any:
+        """
+        Load and deserialize an object from a file.
+
+        Args:
+            filename (str): The path to the file from which the object should be loaded.
+
+        Returns:
+            Object: The deserialized object.
+        """
+        with open(filename, "r", encoding="utf-8") as f:
+            json_str = f.read()
+        return cls.deserialize(json_str)
+
+    @classmethod
+    def register_class_as_deserializable(cls, target_class: Type["JSONSerializable"]) -> None:
+        """
+        Register a class as deserializable. This is a classmethod and globally shared.
+
+        This method adds the target class to the set of classes that
+        can be deserialized. This is a security measure to ensure only
+        whitelisted classes are deserialized.
+
+        Args:
+            target_class (Type): The class to be registered.
+        """
+        cls._deserializable_classes.add(target_class)

+ 4 - 1
embedchain/loaders/base_loader.py

@@ -1,4 +1,7 @@
-class BaseLoader:
+from embedchain.helper_classes.json_serializable import JSONSerializable
+
+
+class BaseLoader(JSONSerializable):
     def __init__(self):
         pass
 

+ 2 - 0
embedchain/loaders/docs_site_loader.py

@@ -4,9 +4,11 @@ from urllib.parse import urljoin, urlparse
 import requests
 from bs4 import BeautifulSoup
 
+from embedchain.helper_classes.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 
 
+@register_deserializable
 class DocsSiteLoader(BaseLoader):
     def __init__(self):
         self.visited_links = set()

+ 2 - 0
embedchain/loaders/docx_file.py

@@ -1,8 +1,10 @@
 from langchain.document_loaders import Docx2txtLoader
 
+from embedchain.helper_classes.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 
 
+@register_deserializable
 class DocxFileLoader(BaseLoader):
     def load_data(self, url):
         """Load data from a .docx file."""

+ 2 - 0
embedchain/loaders/local_qna_pair.py

@@ -1,6 +1,8 @@
+from embedchain.helper_classes.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 
 
+@register_deserializable
 class LocalQnaPairLoader(BaseLoader):
     def load_data(self, content):
         """Load data from a local QnA pair."""

+ 2 - 0
embedchain/loaders/local_text.py

@@ -1,6 +1,8 @@
+from embedchain.helper_classes.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 
 
+@register_deserializable
 class LocalTextLoader(BaseLoader):
     def load_data(self, content):
         """Load data from a local text file."""

+ 2 - 0
embedchain/loaders/notion.py

@@ -7,10 +7,12 @@ except ImportError:
     raise ImportError("Notion requires extra dependencies. Install with `pip install embedchain[community]`") from None
 
 
+from embedchain.helper_classes.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 
 
+@register_deserializable
 class NotionLoader(BaseLoader):
     def load_data(self, source):
         """Load data from a PDF file."""

+ 2 - 0
embedchain/loaders/pdf_file.py

@@ -1,9 +1,11 @@
 from langchain.document_loaders import PyPDFLoader
 
+from embedchain.helper_classes.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 
 
+@register_deserializable
 class PdfFileLoader(BaseLoader):
     def load_data(self, url):
         """Load data from a PDF file."""

+ 2 - 0
embedchain/loaders/sitemap.py

@@ -4,11 +4,13 @@ import requests
 from bs4 import BeautifulSoup
 from bs4.builder import ParserRejectedMarkup
 
+from embedchain.helper_classes.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.web_page import WebPageLoader
 from embedchain.utils import is_readable
 
 
+@register_deserializable
 class SitemapLoader(BaseLoader):
     def load_data(self, sitemap_url):
         """

+ 2 - 0
embedchain/loaders/web_page.py

@@ -3,10 +3,12 @@ import logging
 import requests
 from bs4 import BeautifulSoup
 
+from embedchain.helper_classes.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 
 
+@register_deserializable
 class WebPageLoader(BaseLoader):
     def load_data(self, url):
         """Load data from a web page."""

+ 2 - 0
embedchain/loaders/youtube_video.py

@@ -1,9 +1,11 @@
 from langchain.document_loaders import YoutubeLoader
 
+from embedchain.helper_classes.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 
 
+@register_deserializable
 class YoutubeVideoLoader(BaseLoader):
     def load_data(self, url):
         """Load data from a Youtube video."""

+ 4 - 1
embedchain/vectordb/base_vector_db.py

@@ -1,4 +1,7 @@
-class BaseVectorDB:
+from embedchain.helper_classes.json_serializable import JSONSerializable
+
+
+class BaseVectorDB(JSONSerializable):
     """Base class for vector database."""
 
     def __init__(self):

+ 2 - 0
embedchain/vectordb/chroma_db.py

@@ -14,9 +14,11 @@ except RuntimeError:
 
 from chromadb.config import Settings
 
+from embedchain.helper_classes.json_serializable import register_deserializable
 from embedchain.vectordb.base_vector_db import BaseVectorDB
 
 
+@register_deserializable
 class ChromaDB(BaseVectorDB):
     """Vector database using ChromaDB."""
 

+ 2 - 0
embedchain/vectordb/elasticsearch_db.py

@@ -9,10 +9,12 @@ except ImportError:
     ) from None
 
 from embedchain.config import ElasticsearchDBConfig
+from embedchain.helper_classes.json_serializable import register_deserializable
 from embedchain.models.VectorDimensions import VectorDimensions
 from embedchain.vectordb.base_vector_db import BaseVectorDB
 
 
+@register_deserializable
 class ElasticsearchDB(BaseVectorDB):
     def __init__(
         self,

+ 70 - 0
tests/helper_classes/test_json_serializable.py

@@ -0,0 +1,70 @@
+import random
+import unittest
+
+from embedchain import App
+from embedchain.config import AppConfig
+from embedchain.helper_classes.json_serializable import (
+    JSONSerializable, register_deserializable)
+
+
+class TestJsonSerializable(unittest.TestCase):
+    """Round-trip tests for JSONSerializable and its deserialization registry."""
+
+    def test_base_function(self):
+        """A registered class keeps its attributes across serialize/deserialize."""
+
+        @register_deserializable
+        class TestClass(JSONSerializable):
+            def __init__(self):
+                self.rng = random.random()
+
+        source = TestClass()
+        payload = source.serialize()
+
+        # Negative check: a freshly constructed instance draws its own random number.
+        fresh = TestClass()
+        self.assertNotEqual(source.rng, fresh.rng)
+
+        # Deserializing through an instance restores the original random number.
+        restored: TestClass = TestClass().deserialize(payload)
+        self.assertEqual(source.rng, restored.rng)
+        self.assertIsInstance(restored, TestClass)
+
+        # Deserializing through the class itself gives the same result.
+        restored_via_class: TestClass = TestClass.deserialize(payload)
+        self.assertEqual(source.rng, restored_via_class.rng)
+
+    # TODO: There's no reason it shouldn't work, but serialization to and from file should be tested too.
+
+    def test_registration_required(self):
+        """An unregistered class is not decoded; a default instance is returned instead."""
+
+        class SecondTestClass(JSONSerializable):
+            def __init__(self):
+                self.default = True
+
+        instance = SecondTestClass()
+        # Move the attribute away from the value set in __init__.
+        instance.default = False
+        payload = instance.serialize()
+
+        # The class is not registered yet. Errors are swallowed by deserialize,
+        # so this does not fail but hands back a default instance.
+        fallback: SecondTestClass = SecondTestClass().deserialize(payload)
+        self.assertTrue(fallback.default)
+
+        # After registration the very same payload is decoded.
+        SecondTestClass.register_class_as_deserializable(SecondTestClass)
+        decoded: SecondTestClass = SecondTestClass().deserialize(payload)
+        self.assertFalse(decoded.default)
+
+    def test_recursive(self):
+        """The config nested under a real app survives the round trip."""
+        random_id = str(random.random())
+        # The config object ends up under app.config.
+        original_app = App(config=AppConfig(id=random_id))
+        # The nested config has to be encoded recursively for config.id to be restored.
+        payload = original_app.serialize()
+        new_app: App = App.deserialize(payload)
+        # The id of the new app is the same as the first one.
+        self.assertEqual(random_id, new_app.config.id)
+        # TODO: test deeper recursion