Sfoglia il codice sorgente

fix: serialize non serializable (#589)

cachho 1 anno fa
parent
commit
1864f4cb38

+ 17 - 0
embedchain/helper/json_serializable.py

@@ -1,5 +1,6 @@
 import json
 import json
 import logging
 import logging
+from string import Template
 from typing import Any, Dict, Type, TypeVar, Union
 from typing import Any, Dict, Type, TypeVar, Union
 
 
 T = TypeVar("T", bound="JSONSerializable")
 T = TypeVar("T", bound="JSONSerializable")
@@ -105,6 +106,16 @@ class JSONSerializable:
                         serialized_value = value.serialize()
                         serialized_value = value.serialize()
                         # The value is stored as a serialized string.
                         # The value is stored as a serialized string.
                         dct[key] = json.loads(serialized_value)
                         dct[key] = json.loads(serialized_value)
+                    # Custom rules (subclass is not json serializable by default)
+                    elif isinstance(value, Template):
+                        dct[key] = {"__type__": "Template", "data": value.template}
+                    # Future custom types we can follow a similar pattern
+                    # elif isinstance(value, SomeOtherType):
+                    #     dct[key] = {
+                    #         "__type__": "SomeOtherType",
+                    #         "data": value.some_method()
+                    #     }
+                    # NOTE: Keep in mind that this logic needs to be applied to the decoder too.
                     else:
                     else:
                         json.dumps(value)  # Try to serialize the value.
                         json.dumps(value)  # Try to serialize the value.
                 except TypeError:
                 except TypeError:
@@ -135,6 +146,12 @@ class JSONSerializable:
             if target_class:
             if target_class:
                 obj = target_class.__new__(target_class)
                 obj = target_class.__new__(target_class)
                 for key, value in dct.items():
                 for key, value in dct.items():
+                    if isinstance(value, dict) and "__type__" in value:
+                        if value["__type__"] == "Template":
+                            value = Template(value["data"])
+                        # For future custom types we can follow a similar pattern
+                        # elif value["__type__"] == "SomeOtherType":
+                        #     value = SomeOtherType.some_constructor(value["data"])
                     default_value = getattr(target_class, key, None)
                     default_value = getattr(target_class, key, None)
                     setattr(obj, key, value or default_value)
                     setattr(obj, key, value or default_value)
                 return obj
                 return obj

+ 10 - 1
tests/helper_classes/test_json_serializable.py

@@ -1,8 +1,9 @@
 import random
 import random
 import unittest
 import unittest
+from string import Template
 
 
 from embedchain import App
 from embedchain import App
-from embedchain.config import AppConfig
+from embedchain.config import AppConfig, BaseLlmConfig
 from embedchain.helper.json_serializable import (JSONSerializable,
 from embedchain.helper.json_serializable import (JSONSerializable,
                                                  register_deserializable)
                                                  register_deserializable)
 
 
@@ -69,3 +70,11 @@ class TestJsonSerializable(unittest.TestCase):
         self.assertEqual(random_id, new_app.config.id)
         self.assertEqual(random_id, new_app.config.id)
         # We have proven that a nested class (app.config) can be serialized and deserialized just the same.
         # We have proven that a nested class (app.config) can be serialized and deserialized just the same.
         # TODO: test deeper recursion
         # TODO: test deeper recursion
+
+    def test_special_subclasses(self):
+        """Test special subclasses that are not serializable by default."""
+        # Template
+        config = BaseLlmConfig(template=Template("My custom template with $query, $context and $history."))
+        s = config.serialize()
+        new_config: BaseLlmConfig = BaseLlmConfig.deserialize(s)
+        self.assertEqual(config.template.template, new_config.template.template)