clip_processor.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
# torch, clip and Pillow are imported together; if any of them is missing the
# ImportError is replaced by one that tells the user how to install them.
try:
    import torch
    import clip
    from PIL import Image, UnidentifiedImageError
except ImportError:
    # `from None` hides the original ImportError from the traceback.
    raise ImportError("Images requires extra dependencies. Install with `pip install embedchain[images]`") from None

# Name of the CLIP checkpoint that `ClipProcessor.load_model` passes to `clip.load`.
MODEL_NAME = "ViT-B/32"
  8. class ClipProcessor:
  9. @staticmethod
  10. def load_model():
  11. """Load data from a director of images."""
  12. device = "cuda" if torch.cuda.is_available() else "cpu"
  13. # load model and image preprocessing
  14. model, preprocess = clip.load(MODEL_NAME, device=device, jit=False)
  15. return model, preprocess
  16. @staticmethod
  17. def get_image_features(image_url, model, preprocess):
  18. """
  19. Applies the CLIP model to evaluate the vector representation of the supplied image
  20. """
  21. device = "cuda" if torch.cuda.is_available() else "cpu"
  22. try:
  23. # load image
  24. image = Image.open(image_url)
  25. except FileNotFoundError:
  26. raise FileNotFoundError("The supplied file does not exist`")
  27. except UnidentifiedImageError:
  28. raise UnidentifiedImageError("The supplied file is not an image`")
  29. # pre-process image
  30. processed_image = preprocess(image).unsqueeze(0).to(device)
  31. with torch.no_grad():
  32. image_features = model.encode_image(processed_image)
  33. image_features /= image_features.norm(dim=-1, keepdim=True)
  34. image_features = image_features.cpu().detach().numpy().tolist()[0]
  35. meta_data = {
  36. "url": image_url
  37. }
  38. return {
  39. "content": image_url,
  40. "embedding": image_features,
  41. "meta_data": meta_data
  42. }
  43. @staticmethod
  44. def get_text_features(query):
  45. """
  46. Applies the CLIP model to evaluate the vector representation of the supplied text
  47. """
  48. device = "cuda" if torch.cuda.is_available() else "cpu"
  49. model, preprocess = ClipProcessor.load_model()
  50. text = clip.tokenize(query).to(device)
  51. with torch.no_grad():
  52. text_features = model.encode_text(text)
  53. text_features /= text_features.norm(dim=-1, keepdim=True)
  54. return text_features.cpu().numpy().tolist()[0]