jiaqingj committed on
Commit e8accfd · 1 Parent(s): d241223

Upload clip.py

Files changed (1)
  1. clip.py +146 -0
clip.py ADDED
@@ -0,0 +1,146 @@
import torch
from torch import nn
from PIL import Image

class CLIP(nn.Module):
    def __init__(self, model_name):
        super(CLIP, self).__init__()
        # model_name: e.g. openai/clip-vit-base-patch32
        print('Initializing CLIP model...')
        from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
        self.model = CLIPModel.from_pretrained(model_name)
        self.model.eval()
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.cuda_has_been_checked = False
        print('CLIP model initialized.')

    def check_cuda(self):
        # Cache whether the model sits on a CUDA device and which one.
        self.cuda_available = next(self.model.parameters()).is_cuda
        self.device = next(self.model.parameters()).get_device()
        if self.cuda_available:
            print('Cuda is available.')
        else:
            print('Cuda is not available.')
        print('Device is {}'.format(self.device))

    @torch.no_grad()
    def compute_image_representation_from_image_path(self, image_path):
        # image_path: the path of the image
        image = Image.open(image_path)
        return self.compute_image_representation_from_image_instance(image)

    @torch.no_grad()
    def compute_image_representation_from_image_instance(self, image):
        # image: a PIL.Image instance
        if not self.cuda_has_been_checked:
            self.check_cuda()
            self.cuda_has_been_checked = True
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs['pixel_values']
        if self.cuda_available:
            pixel_values = pixel_values.cuda(self.device)
        visual_outputs = self.model.vision_model(pixel_values=pixel_values)
        image_embeds = visual_outputs[1]  # pooled output
        image_embeds = self.model.visual_projection(image_embeds)  # [1 x embed_dim]
        return image_embeds

    @torch.no_grad()
    def compute_text_representation(self, text_list):
        # text_list: a list of text
        if not self.cuda_has_been_checked:
            self.check_cuda()
            self.cuda_has_been_checked = True
        # self.tokenizer.max_len_single_sentence + 2 = 77, CLIP's maximum text length
        text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt",
                                     max_length=self.tokenizer.max_len_single_sentence + 2,
                                     truncation=True)
        input_ids, attention_mask = text_inputs['input_ids'], text_inputs['attention_mask']
        if self.cuda_available:
            input_ids = input_ids.cuda(self.device)
            attention_mask = attention_mask.cuda(self.device)
        text_outputs = self.model.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        text_embeds = text_outputs[1]  # pooled output
        text_embeds = self.model.text_projection(text_embeds)
        return text_embeds  # [len(text_list) x embed_dim]

    def compute_image_text_similarity_via_embeddings(self, image_embeds, text_embeds):
        '''
        image_embeds: 1 x embed_dim
        text_embeds: len(text_list) x embed_dim
        '''
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        logit_scale = self.model.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.T
        # returns (softmax probabilities, cosine similarities), each 1 x len(text_list)
        return logits_per_image.softmax(dim=1), logits_per_image / logit_scale

    def compute_image_text_similarity_via_raw_text(self, image_embeds, text_list):
        text_embeds = self.compute_text_representation(text_list)
        return self.compute_image_text_similarity_via_embeddings(image_embeds, text_embeds)

    ### -------------------- functions for building index ---------------------- ###
    @torch.no_grad()
    def compute_batch_index_image_features(self, image_list):
        '''
        image_list: a list of image instances
        '''
        if not self.cuda_has_been_checked:
            self.check_cuda()
            self.cuda_has_been_checked = True
        inputs = self.processor(images=image_list, return_tensors="pt")
        pixel_values = inputs['pixel_values']
        if self.cuda_available:
            pixel_values = pixel_values.cuda(self.device)
        visual_outputs = self.model.vision_model(pixel_values=pixel_values)
        image_embeds = visual_outputs[1]  # pooled output
        image_embeds = self.model.visual_projection(image_embeds)
        return image_embeds  # [len(image_list) x embed_dim]

    @torch.no_grad()
    def compute_batch_index_text_representation(self, text_list):
        # text_list: a list of text
        if not self.cuda_has_been_checked:
            self.check_cuda()
            self.cuda_has_been_checked = True
        text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt",
                                     max_length=self.tokenizer.max_len_single_sentence + 2,
                                     truncation=True)
        input_ids, attention_mask = text_inputs['input_ids'], text_inputs['attention_mask']
        if self.cuda_available:
            input_ids = input_ids.cuda(self.device)
            attention_mask = attention_mask.cuda(self.device)
        text_outputs = self.model.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        text_embeds = text_outputs[1]  # pooled output
        text_embeds = self.model.text_projection(text_embeds)
        return text_embeds  # [len(text_list) x embed_dim]
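For reference, a minimal usage sketch of the class above (not part of the commit): the checkpoint name openai/clip-vit-base-patch32 comes from the comment in __init__, while the image path and candidate captions below are hypothetical.

# Minimal usage sketch, assuming the openai/clip-vit-base-patch32 checkpoint
# mentioned in __init__; 'example.jpg' and the captions are hypothetical.
from clip import CLIP

clip = CLIP('openai/clip-vit-base-patch32')
# clip = clip.cuda(0)  # optional: move the model to a GPU before the first call

# Rank candidate captions for one image.
image_embeds = clip.compute_image_representation_from_image_path('example.jpg')
probs, sims = clip.compute_image_text_similarity_via_raw_text(
    image_embeds, ['a photo of a dog', 'a photo of a cat'])
print(probs)  # softmax over the two captions, shape [1 x 2]

# Build projected text features for an index.
text_embeds = clip.compute_batch_index_text_representation(
    ['a photo of a dog', 'a photo of a cat'])  # [2 x embed_dim]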