# test_clip_processor.py — unit tests for embedchain.models.clip_processor.ClipProcessor

import os
import tempfile
import unittest
import urllib
import urllib.request

from PIL import Image

from embedchain.models.clip_processor import ClipProcessor


  7. class ClipProcessorTest(unittest.TestCase):
  8. def test_load_model(self):
  9. # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
  10. model, preprocess = ClipProcessor.load_model()
  11. # Assert that the model is not None.
  12. self.assertIsNotNone(model)
  13. # Assert that the preprocess is not None.
  14. self.assertIsNotNone(preprocess)
  15. def test_get_image_features(self):
  16. # Clone the image to a temporary folder.
  17. with tempfile.TemporaryDirectory() as tmp_dir:
  18. urllib.request.urlretrieve(
  19. 'https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg',
  20. "image.jpg")
  21. image = Image.open("image.jpg")
  22. image.save(os.path.join(tmp_dir, "image.jpg"))
  23. # Get the image features.
  24. model, preprocess = ClipProcessor.load_model()
  25. ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model, preprocess)
  26. # Delete the temporary file.
  27. os.remove(os.path.join(tmp_dir, "image.jpg"))
  28. # Assert that the test passes.
  29. self.assertTrue(True)
  30. def test_get_text_features(self):
  31. # Test that the `get_text_features()` method returns a list containing the text embedding.
  32. query = "This is a text query."
  33. model, preprocess = ClipProcessor.load_model()
  34. text_features = ClipProcessor.get_text_features(query)
  35. # Assert that the text embedding is not None.
  36. self.assertIsNotNone(text_features)
  37. # Assert that the text embedding is a list of floats.
  38. self.assertIsInstance(text_features, list)
  39. # Assert that the text embedding has the correct length.
  40. self.assertEqual(len(text_features), 512)