clip_processor.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. try:
  2. import clip
  3. import torch
  4. from PIL import Image, UnidentifiedImageError
  5. except ImportError:
  6. raise ImportError("Images requires extra dependencies. Install with `pip install embedchain[images]`") from None
  # Identifier of the CLIP checkpoint handed to ``clip.load``.
  MODEL_NAME: str = "ViT-B/32"
  class ClipProcessor:
      """Static helpers that embed images and text with a CLIP model.

      The model and its image transform come from ``load_model`` and are passed
      explicitly to the feature functions. Each function returns an
      L2-normalised embedding as a plain Python list of floats.
      """

      @staticmethod
      def _get_device():
          """Return ``"cuda"`` when torch reports CUDA is available, else ``"cpu"``."""
          return "cuda" if torch.cuda.is_available() else "cpu"

      @staticmethod
      def load_model():
          """Load the CLIP model named by ``MODEL_NAME`` and its image transform.

          Returns:
              A ``(model, preprocess)`` tuple from ``clip.load``. The model is
              placed on the device chosen by ``_get_device``.
          """
          device = ClipProcessor._get_device()
          # jit=False asks clip.load for the non-TorchScript model.
          model, preprocess = clip.load(MODEL_NAME, device=device, jit=False)
          return model, preprocess

      @staticmethod
      def get_image_features(image_url, model, preprocess):
          """Compute the L2-normalised CLIP embedding of an image file.

          Args:
              image_url: Passed directly to ``PIL.Image.open``, so it must be a
                  filesystem path or file object. Remote URLs are not fetched here.
              model: CLIP model, e.g. the first element returned by ``load_model``.
              preprocess: Image transform paired with ``model``.

          Returns:
              A dict with ``"content"`` (the supplied ``image_url``),
              ``"embedding"`` (list of floats) and ``"meta_data"``
              (``{"url": image_url}``).

          Raises:
              FileNotFoundError: if ``image_url`` does not exist.
              UnidentifiedImageError: if PIL cannot identify the file as an image.
          """
          device = ClipProcessor._get_device()
          try:
              image = Image.open(image_url)
          except FileNotFoundError as err:
              raise FileNotFoundError("The supplied file does not exist") from err
          except UnidentifiedImageError as err:
              raise UnidentifiedImageError("The supplied file is not an image") from err

          # Image.open is lazy and keeps the underlying file open. Run the transform
          # while it is open, then let the context manager release the handle.
          with image:
              processed_image = preprocess(image).unsqueeze(0).to(device)

          with torch.no_grad():
              image_features = model.encode_image(processed_image)
              # Scale each row to unit length so dot products are cosine similarities.
              image_features /= image_features.norm(dim=-1, keepdim=True)

          embedding = image_features.cpu().numpy().tolist()[0]
          meta_data = {"url": image_url}
          return {"content": image_url, "embedding": embedding, "meta_data": meta_data}

      @staticmethod
      def get_text_features(query, model=None):
          """Compute the L2-normalised CLIP embedding of a text query.

          Args:
              query: Text passed to ``clip.tokenize``.
              model: Optional CLIP model to reuse. It must live on the device
                  chosen by ``_get_device``, as models from ``load_model`` do.
                  When omitted, a fresh model is loaded via ``load_model`` on
                  every call, which is expensive.

          Returns:
              The embedding of the first tokenized text, as a list of floats.
          """
          device = ClipProcessor._get_device()
          if model is None:
              model, _ = ClipProcessor.load_model()
          text = clip.tokenize(query).to(device)
          with torch.no_grad():
              text_features = model.encode_text(text)
              # Scale each row to unit length so dot products are cosine similarities.
              text_features /= text_features.norm(dim=-1, keepdim=True)
          return text_features.cpu().numpy().tolist()[0]