Bläddra i källkod

Added Clip dependency (#778)

Rupesh Bansal 1 år sedan
förälder
incheckning
19a9141c2d

+ 3 - 3
embedchain/loaders/images.py

@@ -16,15 +16,15 @@ class ImagesLoader(BaseLoader):
         # load model and image preprocessing
         from embedchain.models.clip_processor import ClipProcessor
 
-        model, preprocess = ClipProcessor.load_model()
+        model = ClipProcessor.load_model()
         if os.path.isfile(image_url):
-            data = [ClipProcessor.get_image_features(image_url, model, preprocess)]
+            data = [ClipProcessor.get_image_features(image_url, model)]
         else:
             data = []
             for filename in os.listdir(image_url):
                 filepath = os.path.join(image_url, filename)
                 try:
-                    data.append(ClipProcessor.get_image_features(filepath, model, preprocess))
+                    data.append(ClipProcessor.get_image_features(filepath, model))
                 except Exception as e:
                     # Log the file that was not loaded
                     logging.exception("Failed to load the file {}. Exception {}".format(filepath, e))

+ 11 - 27
embedchain/models/clip_processor.py

@@ -1,31 +1,27 @@
 try:
-    import clip
-    import torch
     from PIL import Image, UnidentifiedImageError
+    from sentence_transformers import SentenceTransformer
 except ImportError:
     raise ImportError(
-        "Images requires extra dependencies. Install with `pip install 'embedchain[images]' git+https://github.com/openai/CLIP.git#a1d0717`"  # noqa: E501
+        "Images requires extra dependencies. Install with `pip install 'embedchain[images]'"
     ) from None
 
-MODEL_NAME = "ViT-B/32"
+MODEL_NAME = "clip-ViT-B-32"
 
 
 class ClipProcessor:
     @staticmethod
     def load_model():
         """Load data from a director of images."""
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-
         # load model and image preprocessing
-        model, preprocess = clip.load(MODEL_NAME, device=device, jit=False)
-        return model, preprocess
+        model = SentenceTransformer(MODEL_NAME)
+        return model
 
     @staticmethod
-    def get_image_features(image_url, model, preprocess):
+    def get_image_features(image_url, model):
         """
         Applies the CLIP model to evaluate the vector representation of the supplied image
         """
-        device = "cuda" if torch.cuda.is_available() else "cpu"
         try:
             # load image
             image = Image.open(image_url)
@@ -34,27 +30,15 @@ class ClipProcessor:
         except UnidentifiedImageError:
             raise UnidentifiedImageError("The supplied file is not an image`")
 
-        # pre-process image
-        processed_image = preprocess(image).unsqueeze(0).to(device)
-        with torch.no_grad():
-            image_features = model.encode_image(processed_image)
-            image_features /= image_features.norm(dim=-1, keepdim=True)
-
-        image_features = image_features.cpu().detach().numpy().tolist()[0]
+        image_features = model.encode(image)
         meta_data = {"url": image_url}
-        return {"content": image_url, "embedding": image_features, "meta_data": meta_data}
+        return {"content": image_url, "embedding": image_features.tolist(), "meta_data": meta_data}
 
     @staticmethod
     def get_text_features(query):
         """
         Applies the CLIP model to evaluate the vector representation of the supplied text
         """
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        model, preprocess = ClipProcessor.load_model()
-        text = clip.tokenize(query).to(device)
-        with torch.no_grad():
-            text_features = model.encode_text(text)
-            text_features /= text_features.norm(dim=-1, keepdim=True)
-
-        return text_features.cpu().numpy().tolist()[0]
+        model = ClipProcessor.load_model()
+        text_features = model.encode(query)
+        return text_features.tolist()

+ 2 - 1
embedchain/utils.py

@@ -128,7 +128,8 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)
 
     if url:
-        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import \
+            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
 
         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")

+ 31 - 39
tests/models/test_clip_processor.py

@@ -1,51 +1,43 @@
-# import os
-# import tempfile
-# import urllib
+import os
+import tempfile
+import urllib
 
-# import pytest
-# from PIL import Image
+from PIL import Image
 
-# TODO: Uncomment after fixing clip dependency issue
-# from embedchain.models.clip_processor import ClipProcessor
+from embedchain.models.clip_processor import ClipProcessor
 
 
-# class TestClipProcessor:
-#     @pytest.mark.xfail(reason="This test is failing because of the missing CLIP dependency.")
-#     def test_load_model(self):
-#         # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
-#         model, preprocess = ClipProcessor.load_model()
-#         assert model is not None
-#         assert preprocess is not None
+class TestClipProcessor:
+    def test_load_model(self):
+        # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
+        model = ClipProcessor.load_model()
+        assert model is not None
 
-#     @pytest.mark.xfail(reason="This test is failing because of the missing CLIP dependency.")
-#     def test_get_image_features(self):
-#         # Clone the image to a temporary folder.
-#         with tempfile.TemporaryDirectory() as tmp_dir:
-#             urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
+    def test_get_image_features(self):
+        # Clone the image to a temporary folder.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
 
-#             image = Image.open("image.jpg")
-#             image.save(os.path.join(tmp_dir, "image.jpg"))
+            image = Image.open("image.jpg")
+            image.save(os.path.join(tmp_dir, "image.jpg"))
 
-#             # Get the image features.
-#             model, preprocess = ClipProcessor.load_model()
-#             ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model, preprocess)
+            # Get the image features.
+            model = ClipProcessor.load_model()
+            ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model)
 
-#             # Delete the temporary file.
-#             os.remove(os.path.join(tmp_dir, "image.jpg"))
+            # Delete the temporary file.
+            os.remove(os.path.join(tmp_dir, "image.jpg"))
 
-#     @pytest.mark.xfail(reason="This test is failing because of the missing CLIP dependency.")
-#     def test_get_text_features(self):
-#         # Test that the `get_text_features()` method returns a list containing the text embedding.
-#         query = "This is a text query."
-#         model, preprocess = ClipProcessor.load_model()
+    def test_get_text_features(self):
+        # Test that the `get_text_features()` method returns a list containing the text embedding.
+        query = "This is a text query."
+        text_features = ClipProcessor.get_text_features(query)
 
-#         text_features = ClipProcessor.get_text_features(query)
+        # Assert that the text embedding is not None.
+        assert text_features is not None
 
-#         # Assert that the text embedding is not None.
-#         assert text_features is not None
+        # Assert that the text embedding is a list of floats.
+        assert isinstance(text_features, list)
 
-#         # Assert that the text embedding is a list of floats.
-#         assert isinstance(text_features, list)
-
-#         # Assert that the text embedding has the correct length.
-#         assert len(text_features) == 512
+        # Assert that the text embedding has the correct length.
+        assert len(text_features) == 512