test_clip_processor.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import os
  2. import tempfile
  3. import urllib
  4. from PIL import Image
  5. from embedchain.models.clip_processor import ClipProcessor
  class TestClipProcessor:
      """Smoke tests for ``ClipProcessor``: model loading, image features and text features."""

      def test_load_model(self):
          """``load_model()`` returns a model and an image preprocessor, neither being ``None``."""
          model, preprocess = ClipProcessor.load_model()
          assert model is not None
          assert preprocess is not None

      def test_get_image_features(self):
          """``get_image_features()`` runs without raising on a real JPEG file.

          The sample image is fetched over HTTPS, so this test needs network
          access. Every file it creates lives inside a temporary directory that
          is removed when the ``with`` block exits, so nothing is left behind in
          the current working directory.
          """
          # A bare `import urllib` does not guarantee that the `request`
          # submodule is loaded; import it explicitly where it is used.
          import urllib.request

          with tempfile.TemporaryDirectory() as tmp_dir:
              downloaded_path = os.path.join(tmp_dir, "downloaded.jpg")
              image_path = os.path.join(tmp_dir, "image.jpg")

              # Download into the temporary directory rather than the CWD.
              urllib.request.urlretrieve(
                  "https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg",
                  downloaded_path,
              )

              # Round-trip the download through PIL (as the original test did);
              # the context manager closes the file handle so the directory
              # cleanup cannot be blocked by an open file.
              with Image.open(downloaded_path) as image:
                  image.save(image_path)

              # Get the image features. Only successful completion is checked.
              model, preprocess = ClipProcessor.load_model()
              ClipProcessor.get_image_features(image_path, model, preprocess)
              # No explicit os.remove(): TemporaryDirectory deletes its contents.

      def test_get_text_features(self):
          """``get_text_features()`` returns a 512-element list for a text query."""
          query = "This is a text query."

          # NOTE(review): the loaded model is not passed to get_text_features();
          # the call is kept in case load_model() primes state it relies on --
          # confirm against ClipProcessor and drop it if unnecessary.
          ClipProcessor.load_model()

          text_features = ClipProcessor.get_text_features(query)

          # The embedding must exist, ...
          assert text_features is not None
          # ... be a plain list, ...
          assert isinstance(text_features, list)
          # ... and have the expected dimensionality.
          assert len(text_features) == 512