try:
    import clip
    import torch
    from PIL import Image, UnidentifiedImageError
except ImportError:
    raise ImportError("Images requires extra dependencies. Install with `pip install embedchain[images]`") from None

MODEL_NAME = "ViT-B/32"

class ClipProcessor:
    @staticmethod
    def load_model():
        """Load the CLIP model and its image preprocessing transform."""
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # load the model and the matching image preprocessing transform
        model, preprocess = clip.load(MODEL_NAME, device=device, jit=False)
        return model, preprocess

    @staticmethod
    def get_image_features(image_url, model, preprocess):
        """
        Applies the CLIP model to compute the vector representation of the supplied image.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            # load image
            image = Image.open(image_url)
        except FileNotFoundError:
            raise FileNotFoundError("The supplied file does not exist")
        except UnidentifiedImageError:
            raise UnidentifiedImageError("The supplied file is not an image")

        # pre-process the image and add a batch dimension
        processed_image = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = model.encode_image(processed_image)
            # normalize to unit length so a dot product gives cosine similarity
            image_features /= image_features.norm(dim=-1, keepdim=True)

        image_features = image_features.cpu().detach().numpy().tolist()[0]
        meta_data = {"url": image_url}
        return {"content": image_url, "embedding": image_features, "meta_data": meta_data}

    @staticmethod
    def get_text_features(query):
        """
        Applies the CLIP model to compute the vector representation of the supplied text.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # the preprocessing transform is only needed for images, so discard it here
        model, _ = ClipProcessor.load_model()

        text = clip.tokenize(query).to(device)
        with torch.no_grad():
            text_features = model.encode_text(text)
            text_features /= text_features.norm(dim=-1, keepdim=True)

        return text_features.cpu().numpy().tolist()[0]
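

# A minimal usage sketch (illustrative only, not part of embedchain itself): it embeds a
# local image and a text query with the same CLIP model, so the two unit-normalized
# vectors can be compared with a plain dot product. "photo.jpg" is a placeholder path.
if __name__ == "__main__":
    model, preprocess = ClipProcessor.load_model()

    image_data = ClipProcessor.get_image_features("photo.jpg", model, preprocess)
    query_embedding = ClipProcessor.get_text_features("a photo of a dog")

    # both embeddings are unit-normalized, so the dot product equals cosine similarity
    similarity = sum(i * q for i, q in zip(image_data["embedding"], query_embedding))
    print(f"cosine similarity: {similarity:.4f}")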