clip_processor.py

try:
    from PIL import Image, UnidentifiedImageError
    from sentence_transformers import SentenceTransformer
except ImportError:
    raise ImportError(
        "Image support requires extra dependencies. Install with `pip install 'embedchain[images]'`"
    ) from None

MODEL_NAME = "clip-ViT-B-32"


class ClipProcessor:
    @staticmethod
    def load_model():
        """Load the CLIP model used for both image and text embeddings."""
        # Load the sentence-transformers CLIP checkpoint; it bundles the
        # image preprocessing needed by `encode`.
        model = SentenceTransformer(MODEL_NAME)
        return model

    @staticmethod
    def get_image_features(image_url, model):
        """
        Apply the CLIP model to compute the vector representation of the supplied image.
        """
        try:
            # Load the image from the given path.
            image = Image.open(image_url)
        except FileNotFoundError:
            raise FileNotFoundError("The supplied file does not exist")
        except UnidentifiedImageError:
            raise UnidentifiedImageError("The supplied file is not an image")

        image_features = model.encode(image)
        meta_data = {"url": image_url}
        return {"content": image_url, "embedding": image_features.tolist(), "meta_data": meta_data}

    @staticmethod
    def get_text_features(query):
        """
        Apply the CLIP model to compute the vector representation of the supplied text.
        """
        # Note: this loads the model on every call; callers that embed many
        # queries may prefer to reuse a model returned by `load_model()`.
        model = ClipProcessor.load_model()
        text_features = model.encode(query)
        return text_features.tolist()
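

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal example of how the
# two entry points above might be combined for an image/text similarity
# check. The file path "photo.jpg" and the query string are illustrative
# assumptions only; `util.cos_sim` is the standard sentence-transformers
# cosine-similarity helper.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from sentence_transformers import util

    model = ClipProcessor.load_model()

    # Embed an image; the returned dict carries the embedding plus metadata.
    doc = ClipProcessor.get_image_features("photo.jpg", model)

    # Embed a text query with the same CLIP model, then compare the two
    # vectors with cosine similarity (higher means a closer match).
    query_embedding = ClipProcessor.get_text_features("a photo of a dog")
    score = util.cos_sim(doc["embedding"], query_embedding)
    print(f"similarity: {score.item():.4f}")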