from typing import List

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

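# Dimensionality of the shared CLIP embedding space for the ViT-B/32 checkpoint.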
MODEL_DIM = 512


class ClipWrapper:
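    """Thin wrapper around the Hugging Face CLIP ViT-B/32 checkpoint that embeds
    images and texts into the shared CLIP space as L2-normalized vectors."""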
    def __init__(self):
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    def images2vec(self, images: List[Image.Image]) -> torch.Tensor:
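        """Embed a batch of PIL images as L2-normalized CLIP image vectors."""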
        inputs = self.processor(images=images, return_tensors="pt")
        with torch.no_grad():
            model_inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            vision_outputs = self.model.vision_model(**model_inputs)
            # pooler_output is the layer-normalized CLS token embedding; project it
            # into the shared CLIP embedding space.
            clip_vectors = self.model.visual_projection(vision_outputs.pooler_output)
        return clip_vectors / clip_vectors.norm(dim=-1, keepdim=True)

    def texts2vec(self, texts: List[str]) -> torch.Tensor:
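        """Embed a batch of strings as L2-normalized CLIP text vectors."""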
        inputs = self.processor(text=texts, return_tensors="pt", padding=True)
        with torch.no_grad():
            model_inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            text_outputs = self.model.text_model(**model_inputs)
            # pooler_output is the hidden state at the end-of-sequence token; project
            # it into the shared CLIP embedding space.
            text_vectors = self.model.text_projection(text_outputs.pooler_output)
        return text_vectors / text_vectors.norm(dim=-1, keepdim=True)
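

# A minimal usage sketch (not part of the original file); "cat.jpg" is a
# hypothetical sample path. Because both helpers return L2-normalized vectors,
# cosine similarity reduces to a plain dot product.
if __name__ == "__main__":
    clip = ClipWrapper()
    image_vecs = clip.images2vec([Image.open("cat.jpg")])  # hypothetical image file
    text_vecs = clip.texts2vec(["a photo of a cat", "a photo of a dog"])
    similarities = text_vecs @ image_vecs.T  # shape: (num_texts, num_images)
    print(similarities)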