|
@@ -1,29 +1,23 @@
|
|
|
-import tempfile
|
|
|
-import unittest
|
|
|
import os
|
|
|
+import tempfile
|
|
|
import urllib
|
|
|
+
|
|
|
from PIL import Image
|
|
|
-from embedchain.models.clip_processor import ClipProcessor
|
|
|
|
|
|
+from embedchain.models.clip_processor import ClipProcessor
|
|
|
|
|
|
-class ClipProcessorTest(unittest.TestCase):
|
|
|
|
|
|
+class TestClipProcessor:
|
|
|
def test_load_model(self):
|
|
|
# Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
|
|
|
model, preprocess = ClipProcessor.load_model()
|
|
|
-
|
|
|
- # Assert that the model is not None.
|
|
|
- self.assertIsNotNone(model)
|
|
|
-
|
|
|
- # Assert that the preprocess is not None.
|
|
|
- self.assertIsNotNone(preprocess)
|
|
|
+ assert model is not None
|
|
|
+ assert preprocess is not None
|
|
|
|
|
|
def test_get_image_features(self):
|
|
|
# Clone the image to a temporary folder.
|
|
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
|
- urllib.request.urlretrieve(
|
|
|
- 'https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg',
|
|
|
- "image.jpg")
|
|
|
+ urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
|
|
|
|
|
|
image = Image.open("image.jpg")
|
|
|
image.save(os.path.join(tmp_dir, "image.jpg"))
|
|
@@ -35,9 +29,6 @@ class ClipProcessorTest(unittest.TestCase):
|
|
|
# Delete the temporary file.
|
|
|
os.remove(os.path.join(tmp_dir, "image.jpg"))
|
|
|
|
|
|
- # Assert that the test passes.
|
|
|
- self.assertTrue(True)
|
|
|
-
|
|
|
def test_get_text_features(self):
|
|
|
# Test that the `get_text_features()` method returns a list containing the text embedding.
|
|
|
query = "This is a text query."
|
|
@@ -46,10 +37,10 @@ class ClipProcessorTest(unittest.TestCase):
|
|
|
text_features = ClipProcessor.get_text_features(query)
|
|
|
|
|
|
# Assert that the text embedding is not None.
|
|
|
- self.assertIsNotNone(text_features)
|
|
|
+ assert text_features is not None
|
|
|
|
|
|
# Assert that the text embedding is a list of floats.
|
|
|
- self.assertIsInstance(text_features, list)
|
|
|
+ assert isinstance(text_features, list)
|
|
|
|
|
|
# Assert that the text embedding has the correct length.
|
|
|
- self.assertEqual(len(text_features), 512)
|
|
|
+ assert len(text_features) == 512
|