# json_serializable.py
import json
import logging
from typing import Any, ClassVar, Dict, Set, Type, TypeVar, Union
# Type variable bound to JSONSerializable; used so that decorator/registration
# signatures can say "returns the same class it was given".
T = TypeVar("T", bound="JSONSerializable")
# NOTE: Through inheritance, all of our classes should be subclasses of JSONSerializable (the root of the hierarchy).
# NOTE: The @register_deserializable decorator should be applied to every user-facing subclass (the leaves of the hierarchy).
  7. def register_deserializable(cls: Type[T]) -> Type[T]:
  8. """
  9. A class decorator to register a class as deserializable.
  10. When a class is decorated with @register_deserializable, it becomes
  11. a part of the set of classes that the JSONSerializable class can
  12. deserialize.
  13. Deserialization is in essence loading attributes from a json file.
  14. This decorator is a security measure put in place to make sure that
  15. you don't load attributes that were initially part of another class.
  16. Example:
  17. @register_deserializable
  18. class ChildClass(JSONSerializable):
  19. def __init__(self, ...):
  20. # initialization logic
  21. Args:
  22. cls (Type): The class to be registered.
  23. Returns:
  24. Type: The same class, after registration.
  25. """
  26. JSONSerializable.register_class_as_deserializable(cls)
  27. return cls
  28. class JSONSerializable:
  29. """
  30. A class to represent a JSON serializable object.
  31. This class provides methods to serialize and deserialize objects,
  32. as well as save serialized objects to a file and load them back.
  33. """
  34. _deserializable_classes = set() # Contains classes that are whitelisted for deserialization.
  35. def serialize(self) -> str:
  36. """
  37. Serialize the object to a JSON-formatted string.
  38. Returns:
  39. str: A JSON string representation of the object.
  40. """
  41. try:
  42. return json.dumps(self, default=self._auto_encoder, ensure_ascii=False)
  43. except Exception as e:
  44. logging.error(f"Serialization error: {e}")
  45. return "{}"
  46. @classmethod
  47. def deserialize(cls, json_str: str) -> Any:
  48. """
  49. Deserialize a JSON-formatted string to an object.
  50. If it fails, a default class is returned instead.
  51. Note: This *returns* an instance, it's not automatically loaded on the calling class.
  52. Example:
  53. app = App.deserialize(json_str)
  54. Args:
  55. json_str (str): A JSON string representation of an object.
  56. Returns:
  57. Object: The deserialized object.
  58. """
  59. try:
  60. return json.loads(json_str, object_hook=cls._auto_decoder)
  61. except Exception as e:
  62. logging.error(f"Deserialization error: {e}")
  63. # Return a default instance in case of failure
  64. return cls()
  65. @staticmethod
  66. def _auto_encoder(obj: Any) -> Union[Dict[str, Any], None]:
  67. """
  68. Automatically encode an object for JSON serialization.
  69. Args:
  70. obj (Object): The object to be encoded.
  71. Returns:
  72. dict: A dictionary representation of the object.
  73. """
  74. if hasattr(obj, "__dict__"):
  75. dct = obj.__dict__.copy()
  76. for key, value in list(
  77. dct.items()
  78. ): # We use list() to get a copy of items to avoid dictionary size change during iteration.
  79. try:
  80. # Recursive: If the value is an instance of a subclass of JSONSerializable,
  81. # serialize it using the JSONSerializable serialize method.
  82. if isinstance(value, JSONSerializable):
  83. serialized_value = value.serialize()
  84. # The value is stored as a serialized string.
  85. dct[key] = json.loads(serialized_value)
  86. else:
  87. json.dumps(value) # Try to serialize the value.
  88. except TypeError:
  89. del dct[key] # If it fails, remove the key-value pair from the dictionary.
  90. dct["__class__"] = obj.__class__.__name__
  91. return dct
  92. raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
  93. @classmethod
  94. def _auto_decoder(cls, dct: Dict[str, Any]) -> Any:
  95. """
  96. Automatically decode a dictionary to an object during JSON deserialization.
  97. Args:
  98. dct (dict): The dictionary representation of an object.
  99. Returns:
  100. Object: The decoded object or the original dictionary if decoding is not possible.
  101. """
  102. class_name = dct.pop("__class__", None)
  103. if class_name:
  104. if not hasattr(cls, "_deserializable_classes"): # Additional safety check
  105. raise AttributeError(f"`{class_name}` has no registry of allowed deserializations.")
  106. if class_name not in {cl.__name__ for cl in cls._deserializable_classes}:
  107. raise KeyError(f"Deserialization of class `{class_name}` is not allowed.")
  108. target_class = next((cl for cl in cls._deserializable_classes if cl.__name__ == class_name), None)
  109. if target_class:
  110. obj = target_class.__new__(target_class)
  111. for key, value in dct.items():
  112. default_value = getattr(target_class, key, None)
  113. setattr(obj, key, value or default_value)
  114. return obj
  115. return dct
  116. def save_to_file(self, filename: str) -> None:
  117. """
  118. Save the serialized object to a file.
  119. Args:
  120. filename (str): The path to the file where the object should be saved.
  121. """
  122. with open(filename, "w", encoding="utf-8") as f:
  123. f.write(self.serialize())
  124. @classmethod
  125. def load_from_file(cls, filename: str) -> Any:
  126. """
  127. Load and deserialize an object from a file.
  128. Args:
  129. filename (str): The path to the file from which the object should be loaded.
  130. Returns:
  131. Object: The deserialized object.
  132. """
  133. with open(filename, "r", encoding="utf-8") as f:
  134. json_str = f.read()
  135. return cls.deserialize(json_str)
  136. @classmethod
  137. def register_class_as_deserializable(cls, target_class: Type[T]) -> None:
  138. """
  139. Register a class as deserializable. This is a classmethod and globally shared.
  140. This method adds the target class to the set of classes that
  141. can be deserialized. This is a security measure to ensure only
  142. whitelisted classes are deserialized.
  143. Args:
  144. target_class (Type): The class to be registered.
  145. """
  146. cls._deserializable_classes.add(target_class)