json_serializable.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import json
  2. import logging
  3. from string import Template
  4. from typing import Any, Type, TypeVar, Union
# Type variable used by the `register_deserializable` decorator; bound to JSONSerializable
# so the decorator returns the same class type it was given.
T = TypeVar("T", bound="JSONSerializable")

# NOTE: Through inheritance, all of our classes should be children of JSONSerializable (the highest level).
# NOTE: The @register_deserializable decorator should be added to every user-facing child class (the lowest level).

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  9. def register_deserializable(cls: Type[T]) -> Type[T]:
  10. """
  11. A class decorator to register a class as deserializable.
  12. When a class is decorated with @register_deserializable, it becomes
  13. a part of the set of classes that the JSONSerializable class can
  14. deserialize.
  15. Deserialization is in essence loading attributes from a json file.
  16. This decorator is a security measure put in place to make sure that
  17. you don't load attributes that were initially part of another class.
  18. Example:
  19. @register_deserializable
  20. class ChildClass(JSONSerializable):
  21. def __init__(self, ...):
  22. # initialization logic
  23. Args:
  24. cls (Type): The class to be registered.
  25. Returns:
  26. Type: The same class, after registration.
  27. """
  28. JSONSerializable._register_class_as_deserializable(cls)
  29. return cls
  30. class JSONSerializable:
  31. """
  32. A class to represent a JSON serializable object.
  33. This class provides methods to serialize and deserialize objects,
  34. as well as to save serialized objects to a file and load them back.
  35. """
  36. _deserializable_classes = set() # Contains classes that are whitelisted for deserialization.
  37. def serialize(self) -> str:
  38. """
  39. Serialize the object to a JSON-formatted string.
  40. Returns:
  41. str: A JSON string representation of the object.
  42. """
  43. try:
  44. return json.dumps(self, default=self._auto_encoder, ensure_ascii=False)
  45. except Exception as e:
  46. logger.error(f"Serialization error: {e}")
  47. return "{}"
  48. @classmethod
  49. def deserialize(cls, json_str: str) -> Any:
  50. """
  51. Deserialize a JSON-formatted string to an object.
  52. If it fails, a default class is returned instead.
  53. Note: This *returns* an instance, it's not automatically loaded on the calling class.
  54. Example:
  55. app = App.deserialize(json_str)
  56. Args:
  57. json_str (str): A JSON string representation of an object.
  58. Returns:
  59. Object: The deserialized object.
  60. """
  61. try:
  62. return json.loads(json_str, object_hook=cls._auto_decoder)
  63. except Exception as e:
  64. logger.error(f"Deserialization error: {e}")
  65. # Return a default instance in case of failure
  66. return cls()
  67. @staticmethod
  68. def _auto_encoder(obj: Any) -> Union[dict[str, Any], None]:
  69. """
  70. Automatically encode an object for JSON serialization.
  71. Args:
  72. obj (Object): The object to be encoded.
  73. Returns:
  74. dict: A dictionary representation of the object.
  75. """
  76. if hasattr(obj, "__dict__"):
  77. dct = obj.__dict__.copy()
  78. for key, value in list(
  79. dct.items()
  80. ): # We use list() to get a copy of items to avoid dictionary size change during iteration.
  81. try:
  82. # Recursive: If the value is an instance of a subclass of JSONSerializable,
  83. # serialize it using the JSONSerializable serialize method.
  84. if isinstance(value, JSONSerializable):
  85. serialized_value = value.serialize()
  86. # The value is stored as a serialized string.
  87. dct[key] = json.loads(serialized_value)
  88. # Custom rules (subclass is not json serializable by default)
  89. elif isinstance(value, Template):
  90. dct[key] = {"__type__": "Template", "data": value.template}
  91. # Future custom types we can follow a similar pattern
  92. # elif isinstance(value, SomeOtherType):
  93. # dct[key] = {
  94. # "__type__": "SomeOtherType",
  95. # "data": value.some_method()
  96. # }
  97. # NOTE: Keep in mind that this logic needs to be applied to the decoder too.
  98. else:
  99. json.dumps(value) # Try to serialize the value.
  100. except TypeError:
  101. del dct[key] # If it fails, remove the key-value pair from the dictionary.
  102. dct["__class__"] = obj.__class__.__name__
  103. return dct
  104. raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
  105. @classmethod
  106. def _auto_decoder(cls, dct: dict[str, Any]) -> Any:
  107. """
  108. Automatically decode a dictionary to an object during JSON deserialization.
  109. Args:
  110. dct (dict): The dictionary representation of an object.
  111. Returns:
  112. Object: The decoded object or the original dictionary if decoding is not possible.
  113. """
  114. class_name = dct.pop("__class__", None)
  115. if class_name:
  116. if not hasattr(cls, "_deserializable_classes"): # Additional safety check
  117. raise AttributeError(f"`{class_name}` has no registry of allowed deserializations.")
  118. if class_name not in {cl.__name__ for cl in cls._deserializable_classes}:
  119. raise KeyError(f"Deserialization of class `{class_name}` is not allowed.")
  120. target_class = next((cl for cl in cls._deserializable_classes if cl.__name__ == class_name), None)
  121. if target_class:
  122. obj = target_class.__new__(target_class)
  123. for key, value in dct.items():
  124. if isinstance(value, dict) and "__type__" in value:
  125. if value["__type__"] == "Template":
  126. value = Template(value["data"])
  127. # For future custom types we can follow a similar pattern
  128. # elif value["__type__"] == "SomeOtherType":
  129. # value = SomeOtherType.some_constructor(value["data"])
  130. default_value = getattr(target_class, key, None)
  131. setattr(obj, key, value or default_value)
  132. return obj
  133. return dct
  134. def save_to_file(self, filename: str) -> None:
  135. """
  136. Save the serialized object to a file.
  137. Args:
  138. filename (str): The path to the file where the object should be saved.
  139. """
  140. with open(filename, "w", encoding="utf-8") as f:
  141. f.write(self.serialize())
  142. @classmethod
  143. def load_from_file(cls, filename: str) -> Any:
  144. """
  145. Load and deserialize an object from a file.
  146. Args:
  147. filename (str): The path to the file from which the object should be loaded.
  148. Returns:
  149. Object: The deserialized object.
  150. """
  151. with open(filename, "r", encoding="utf-8") as f:
  152. json_str = f.read()
  153. return cls.deserialize(json_str)
  154. @classmethod
  155. def _register_class_as_deserializable(cls, target_class: Type[T]) -> None:
  156. """
  157. Register a class as deserializable. This is a classmethod and globally shared.
  158. This method adds the target class to the set of classes that
  159. can be deserialized. This is a security measure to ensure only
  160. whitelisted classes are deserialized.
  161. Args:
  162. target_class (Type): The class to be registered.
  163. """
  164. cls._deserializable_classes.add(target_class)