clip_processor.py

try:
    import clip
    import torch
    from PIL import Image, UnidentifiedImageError
except ImportError:
    raise ImportError(
        "Images requires extra dependencies. Install with `pip install 'embedchain[images]' git+https://github.com/openai/CLIP.git#a1d0717`"  # noqa: E501
    ) from None

MODEL_NAME = "ViT-B/32"

class ClipProcessor:
    @staticmethod
    def load_model():
        """Load the CLIP model and its image preprocessing pipeline."""
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # load model and image preprocessing
        model, preprocess = clip.load(MODEL_NAME, device=device, jit=False)
        return model, preprocess
    @staticmethod
    def get_image_features(image_url, model, preprocess):
        """
        Applies the CLIP model to evaluate the vector representation of the supplied image
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            # load image
            image = Image.open(image_url)
        except FileNotFoundError:
            raise FileNotFoundError("The supplied file does not exist")
        except UnidentifiedImageError:
            raise UnidentifiedImageError("The supplied file is not an image")

        # pre-process image
        processed_image = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = model.encode_image(processed_image)
            image_features /= image_features.norm(dim=-1, keepdim=True)

        image_features = image_features.cpu().detach().numpy().tolist()[0]
        meta_data = {"url": image_url}
        return {"content": image_url, "embedding": image_features, "meta_data": meta_data}
    @staticmethod
    def get_text_features(query):
        """
        Applies the CLIP model to evaluate the vector representation of the supplied text
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, preprocess = ClipProcessor.load_model()
        text = clip.tokenize(query).to(device)
        with torch.no_grad():
            text_features = model.encode_text(text)
            text_features /= text_features.norm(dim=-1, keepdim=True)

        return text_features.cpu().numpy().tolist()[0]
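
For reference, a minimal usage sketch of the class above. Assumptions: the extra dependencies named in the import guard are installed, and "photo.jpg" is only an illustrative local image path, not a file provided by the library.

# Load the model once and reuse it for every image.
model, preprocess = ClipProcessor.load_model()

# Embed an image; "photo.jpg" is a placeholder path for this example.
image_doc = ClipProcessor.get_image_features("photo.jpg", model, preprocess)
print(len(image_doc["embedding"]))  # 512-dimensional vector for ViT-B/32

# Embed a text query into the same vector space for similarity search.
query_embedding = ClipProcessor.get_text_features("a photo of a dog")
print(len(query_embedding))  # also 512 dimensions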