test_clip_processor.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # import os
  2. # import tempfile
  3. # import urllib
  4. # import pytest
  5. # from PIL import Image
  6. # TODO: Uncomment after fixing the CLIP dependency issue.
  7. # from embedchain.models.clip_processor import ClipProcessor
  8. # class TestClipProcessor:
  9. # @pytest.mark.xfail(reason="This test is failing because of the missing CLIP dependency.")
  10. # def test_load_model(self):
  11. # # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
  12. # model, preprocess = ClipProcessor.load_model()
  13. # assert model is not None
  14. # assert preprocess is not None
  15. # @pytest.mark.xfail(reason="This test is failing because of the missing CLIP dependency.")
  16. # def test_get_image_features(self):
  17. # # Download the example image and save a copy of it in a temporary folder.
  18. # with tempfile.TemporaryDirectory() as tmp_dir:
  19. # urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
  20. # image = Image.open("image.jpg")
  21. # image.save(os.path.join(tmp_dir, "image.jpg"))
  22. # # Get the image features.
  23. # model, preprocess = ClipProcessor.load_model()
  24. # ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model, preprocess)
  25. # # Delete the temporary file.
  26. # os.remove(os.path.join(tmp_dir, "image.jpg"))
  27. # @pytest.mark.xfail(reason="This test is failing because of the missing CLIP dependency.")
  28. # def test_get_text_features(self):
  29. # # Test that the `get_text_features()` method returns a list containing the text embedding.
  30. # query = "This is a text query."
  31. # model, preprocess = ClipProcessor.load_model()
  32. # text_features = ClipProcessor.get_text_features(query)
  33. # # Assert that the text embedding is not None.
  34. # assert text_features is not None
  35. # # Assert that the text embedding is a list of floats.
  36. # assert isinstance(text_features, list)
  37. # # Assert that the text embedding has the correct length.
  38. # assert len(text_features) == 512