test_json_serializable.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import random
  2. import unittest
  3. from string import Template
  4. from embedchain import App
  5. from embedchain.config import AppConfig, BaseLlmConfig
  6. from embedchain.helper.json_serializable import (JSONSerializable,
  7. register_deserializable)
  8. class TestJsonSerializable(unittest.TestCase):
  9. """Test that the datatype detection is working, based on the input."""
  10. def test_base_function(self):
  11. """Test that the base premise of serialization and deserealization is working"""
  12. @register_deserializable
  13. class TestClass(JSONSerializable):
  14. def __init__(self):
  15. self.rng = random.random()
  16. original_class = TestClass()
  17. serial = original_class.serialize()
  18. # Negative test to show that a new class does not have the same random number.
  19. negative_test_class = TestClass()
  20. self.assertNotEqual(original_class.rng, negative_test_class.rng)
  21. # Test to show that a deserialized class has the same random number.
  22. positive_test_class: TestClass = TestClass().deserialize(serial)
  23. self.assertEqual(original_class.rng, positive_test_class.rng)
  24. self.assertTrue(isinstance(positive_test_class, TestClass))
  25. # Test that it works as a static method too.
  26. positive_test_class: TestClass = TestClass.deserialize(serial)
  27. self.assertEqual(original_class.rng, positive_test_class.rng)
  28. # TODO: There's no reason it shouldn't work, but serialization to and from file should be tested too.
  29. def test_registration_required(self):
  30. """Test that registration is required, and that without registration the default class is returned."""
  31. class SecondTestClass(JSONSerializable):
  32. def __init__(self):
  33. self.default = True
  34. app = SecondTestClass()
  35. # Make not default
  36. app.default = False
  37. # Serialize
  38. serial = app.serialize()
  39. # Deserialize. Due to the way errors are handled, it will not fail but return a default class.
  40. app: SecondTestClass = SecondTestClass().deserialize(serial)
  41. self.assertTrue(app.default)
  42. # If we register and try again with the same serial, it should work
  43. SecondTestClass.register_class_as_deserializable(SecondTestClass)
  44. app: SecondTestClass = SecondTestClass().deserialize(serial)
  45. self.assertFalse(app.default)
  46. def test_recursive(self):
  47. """Test recursiveness with the real app"""
  48. random_id = str(random.random())
  49. config = AppConfig(id=random_id, collect_metrics=False)
  50. # config class is set under app.config.
  51. app = App(config=config)
  52. # w/o recursion it would just be <embedchain.config.apps.OpenSourceAppConfig.OpenSourceAppConfig object at x>
  53. s = app.serialize()
  54. new_app: App = App.deserialize(s)
  55. # The id of the new app is the same as the first one.
  56. self.assertEqual(random_id, new_app.config.id)
  57. # We have proven that a nested class (app.config) can be serialized and deserialized just the same.
  58. # TODO: test deeper recursion
  59. def test_special_subclasses(self):
  60. """Test special subclasses that are not serializable by default."""
  61. # Template
  62. config = BaseLlmConfig(template=Template("My custom template with $query, $context and $history."))
  63. s = config.serialize()
  64. new_config: BaseLlmConfig = BaseLlmConfig.deserialize(s)
  65. self.assertEqual(config.template.template, new_config.template.template)