test_json_serializable.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import random
  2. import unittest
  3. from string import Template
  4. from embedchain import App
  5. from embedchain.config import AppConfig, BaseLlmConfig
  6. from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
  class TestJsonSerializable(unittest.TestCase):
      """Test JSON serialization and deserialization of ``JSONSerializable`` subclasses.

      Covers the basic round trip, the requirement that a class be registered
      before it can be deserialized, nested serializable objects (``App`` and
      its ``config``), and special-cased values such as ``string.Template``.
      """

      def test_base_function(self):
          """Test that the base premise of serialization and deserialization is working."""

          @register_deserializable
          class TestClass(JSONSerializable):
              def __init__(self):
                  # A random value makes it easy to tell whether state survived the round trip.
                  self.rng = random.random()

          original_class = TestClass()
          serial = original_class.serialize()

          # Negative test to show that a new instance does not have the same random number.
          negative_test_class = TestClass()
          self.assertNotEqual(original_class.rng, negative_test_class.rng)

          # Test to show that a deserialized instance has the same random number.
          positive_test_class: TestClass = TestClass().deserialize(serial)
          self.assertEqual(original_class.rng, positive_test_class.rng)
          self.assertIsInstance(positive_test_class, TestClass)

          # Test that it works when called on the class itself (no instance) too.
          positive_test_class = TestClass.deserialize(serial)
          self.assertEqual(original_class.rng, positive_test_class.rng)
          self.assertIsInstance(positive_test_class, TestClass)

      # TODO: There's no reason it shouldn't work, but serialization to and from file should be tested too.

      def test_registration_required(self):
          """Test that registration is required, and that without registration the default class is returned."""

          class SecondTestClass(JSONSerializable):
              def __init__(self):
                  self.default = True

          instance = SecondTestClass()
          # Move the instance away from its default state so a successful round trip is observable.
          instance.default = False
          serial = instance.serialize()

          # Deserialize. Due to the way errors are handled, it will not fail but return a default instance.
          restored: SecondTestClass = SecondTestClass().deserialize(serial)
          self.assertTrue(restored.default)

          # If we register and try again with the same serial, it should work.
          # NOTE(review): this registration is never undone. If the registry is process-wide and
          # keyed by class name, re-running this test in the same process could make the first
          # assertion above fail — confirm against JSONSerializable's registry implementation.
          SecondTestClass._register_class_as_deserializable(SecondTestClass)
          restored = SecondTestClass().deserialize(serial)
          self.assertFalse(restored.default)

      def test_recursive(self):
          """Test recursiveness with the real app."""
          random_id = str(random.random())
          config = AppConfig(id=random_id, collect_metrics=False)
          # The config object is stored under app.config, so the app serializes it as a nested object.
          app = App(config=config)
          s = app.serialize()
          new_app: App = App.deserialize(s)
          # The id of the new app is the same as the first one.
          self.assertEqual(random_id, new_app.config.id)
          # We have proven that a nested class (app.config) can be serialized and deserialized just the same.
          # TODO: test deeper recursion

      def test_special_subclasses(self):
          """Test special subclasses that are not serializable by default."""
          # string.Template is not JSON-serializable on its own, so it needs special handling.
          template = Template("My custom template with $query, $context and $history.")
          config = BaseLlmConfig(template=template)
          s = config.serialize()
          new_config: BaseLlmConfig = BaseLlmConfig.deserialize(s)
          self.assertEqual(config.template.template, new_config.template.template)