patrickramos committed • b991b4f
1 Parent(s): b66e6ca

Update app.py

app.py CHANGED
@@ -1,20 +1,14 @@
-from transformers import CLIPModel, CLIPProcessor
-
-MODEL_ID = 'openai/clip-vit-base-patch32' #@param {'type': 'string'}
-LOAD_IN_8BIT = False #@param {'type': 'boolean'}
-BATCH_SIZE = 1024 #@param {'type': 'integer'}
-REVISION = '' #@param {'type': 'string'}
-REVISION = None if not REVISION else REVISION
-
-from transformers import CLIPConfig
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file
-
 import os
 from huggingface_hub import login
 
 login(os.environ['hf_token'])
 
+
+from transformers import CLIPConfig, CLIPModel
+from torch import nn
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+
 def load_distillclip(model_id, revision=None):
     ckpt_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", revision=revision)
     config = CLIPConfig.from_pretrained(model_id)
@@ -27,34 +21,21 @@ def load_distillclip(model_id, revision=None):
         bias=True,
     )
     model.vision_model.pre_layrnorm = nn.Identity()
-    # model.vision_model.post_layernorm = nn.Identity()
     print(model.load_state_dict({k.removeprefix('student.'): v for k, v in load_file(ckpt_path).items()}))
-    # model.load_state_dict(load_file(ckpt_path))
     return model
+
 
+import torch
 from torch import nn
-from accelerate import init_empty_weights, infer_auto_device_map
-from transformers import CLIPModel, CLIPProcessor
 from einops import reduce
+from tqdm.auto import tqdm
 
 class ZeroShotCLIP(nn.Module):
-    def __init__(self,
+    def __init__(self, model=None, processor=None, classes=[], templates=[], load_in_8bit=False):
         super().__init__()
 
-        self.
-
-        self.model = model.eval()
-        self.processor = processor
-        else:
-            if load_in_8bit:
-                with init_empty_weights():
-                    dummy = CLIPModel.from_pretrained(model_id)
-                device_map = infer_auto_device_map(dummy)
-                del dummy
-                self.model = CLIPModel.from_pretrained(model_id, load_in_8bit=True, device_map=device_map)
-            else:
-                self.model = CLIPModel.from_pretrained(model_id).eval()
-            self.processor = CLIPProcessor.from_pretrained(model_id)
+        self.model = model.eval()
+        self.processor = processor
         self.classes = classes
         self.templates = templates
         self._init_weights()
@@ -63,8 +44,6 @@ class ZeroShotCLIP(nn.Module):
     def _init_weights(self):
         self.model.eval()
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        if not self.load_in_8bit:
-            self.model = self.model.to(device)
         weights = []
         for classname in tqdm(self.classes):
             prompts = [template.format(classname) for template in self.templates]
@@ -76,159 +55,54 @@ class ZeroShotCLIP(nn.Module):
             weights.append(embeddings)
         weights = torch.stack(weights)
         self.register_buffer('weights', weights)
-        if not self.load_in_8bit:
-            self.model = self.model.cpu()
 
     @torch.no_grad()
     def forward(self, pixel_values):
         x = self.model.get_image_features(pixel_values=pixel_values)
         x /= x.norm(dim=-1, keepdim=True)
-        return x.mm(self.weights.t())
+        return x.mm(self.weights.t()) * 100.00000762939453
 
     def preprocess_and_forward(self, x):
-        x = self.processor(images=x)
-        return self(x)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    'a high contrast photo of the {}.',
-    'a bad photo of the {}.',
-    'a good photo of the {}.',
-    'a photo of the small {}.',
-    'a photo of the big {}.',
-]
-
-imagenet_templates = [
-    'a bad photo of a {}.',
-    'a photo of many {}.',
-    'a sculpture of a {}.',
-    'a photo of the hard to see {}.',
-    'a low resolution photo of the {}.',
-    'a rendering of a {}.',
-    'graffiti of a {}.',
-    'a bad photo of the {}.',
-    'a cropped photo of the {}.',
-    'a tattoo of a {}.',
-    'the embroidered {}.',
-    'a photo of a hard to see {}.',
-    'a bright photo of a {}.',
-    'a photo of a clean {}.',
-    'a photo of a dirty {}.',
-    'a dark photo of the {}.',
-    'a drawing of a {}.',
-    'a photo of my {}.',
-    'the plastic {}.',
-    'a photo of the cool {}.',
-    'a close-up photo of a {}.',
-    'a black and white photo of the {}.',
-    'a painting of the {}.',
-    'a painting of a {}.',
-    'a pixelated photo of the {}.',
-    'a sculpture of the {}.',
-    'a bright photo of the {}.',
-    'a cropped photo of a {}.',
-    'a plastic {}.',
-    'a photo of the dirty {}.',
-    'a jpeg corrupted photo of a {}.',
-    'a blurry photo of the {}.',
-    'a photo of the {}.',
-    'a good photo of the {}.',
-    'a rendering of the {}.',
-    'a {} in a video game.',
-    'a photo of one {}.',
-    'a doodle of a {}.',
-    'a close-up photo of the {}.',
-    'a photo of a {}.',
-    'the origami {}.',
-    'the {} in a video game.',
-    'a sketch of a {}.',
-    'a doodle of the {}.',
-    'a origami {}.',
-    'a low resolution photo of a {}.',
-    'the toy {}.',
-    'a rendition of the {}.',
-    'a photo of the clean {}.',
-    'a photo of a large {}.',
-    'a rendition of a {}.',
-    'a photo of a nice {}.',
-    'a photo of a weird {}.',
-    'a blurry photo of a {}.',
-    'a cartoon {}.',
-    'art of a {}.',
-    'a sketch of the {}.',
-    'a embroidered {}.',
-    'a pixelated photo of a {}.',
-    'itap of the {}.',
-    'a jpeg corrupted photo of the {}.',
-    'a good photo of a {}.',
-    'a plushie {}.',
-    'a photo of the nice {}.',
-    'a photo of the small {}.',
-    'a photo of the weird {}.',
-    'the cartoon {}.',
-    'art of the {}.',
-    'a drawing of the {}.',
-    'a photo of the large {}.',
-    'a black and white photo of a {}.',
-    'the plushie {}.',
-    'a dark photo of a {}.',
-    'itap of a {}.',
-    'graffiti of the {}.',
-    'a toy {}.',
-    'itap of my {}.',
-    'a photo of a cool {}.',
-    'a photo of a small {}.',
-    'a tattoo of the {}.',
-]
-
-dashcam_templates = [
-    'a dashcam recording of {}.',
-    'a picture of {}.',
-    'a recording of {}.'
-]
-
-stl10_templates = [
-    'a photo of a {}.',
-    'a photo of the {}.',
-]
-
-oxfordpets_templates = [
-    'a photo of a {}, a type of pet.',
-]
-
-def predict(image, texts):
-    texts = texts.split(', ')
-    out = pipe(image, candidate_labels=texts)
-    return {d['label']: d['score'] for d in out}
+        x = self.processor(images=x, return_tensors='pt')
+        return self(x['pixel_values'])
+
+
+from transformers import CLIPProcessor
+
+model = load_distillclip('Ramos-Ramos/distillclip')
+processor = CLIPProcessor.from_pretrained('Ramos-Ramos/distillclip')
+
+
+def infer(image, classes, templates):
+    classes = [label.strip() for label in classes.split(',')]
+    print(classes)
+    templates = [template.strip() for template in templates.split(';')]
+    print(templates)
+    clip = ZeroShotCLIP(model=model, processor=processor, classes=classes, templates=templates)
+    preds = clip.preprocess_and_forward(image).softmax(dim=1).flatten()
+    return {label: score.item() for label, score in zip(classes, preds)}
+
+
+import gradio as gr
+
+title = 'DistillCLIP'
+description = 'Zero-shot image classification demo with DistillCLIP'
+article = '''DistillCLIP is a distilled version of [CLIP-ViT/B-32](https://huggingface.co/openai/clip-vit-base-patch32).
+
+Please refer to the [DistillCLIP model card](https://huggingface.co/Ramos-Ramos/distillclip) for more details on DistillCLIP.
+
+Note: As multiplying logits by a temperature prior to the softmax can better distinguish final scores, we multiply DistillCLIP's text-image similarity scores by the teacher CLIP's temperature.'''
 
 demo = gr.Interface(
-
-
-
+    fn=infer,
+    inputs=[
+        gr.Image(label='Image', type='pil'),
+        gr.Textbox(label='Classes', placeholder='cat, truck', info='Classes for classification. Separate classes with commas.'),
+        gr.Textbox(label='Prompt/s', placeholder='a photo of a {}.; a blurry photo of a {}.', info='Prompt templates. Use "{}" as placeholder for class. Separate prompts with semi-colons.')
+    ],
+    outputs=gr.Label(label='Class scores'),
+    title=title,
+    description=description,
+    article=article
 )
-
-demo.launch(debug=True, share=True)
+demo.launch()
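
The note at the end of the commit's article string is the reason for the `* 100.00000762939453` factor added to `forward`: the image features are L2-normalized there, the class weights are normally normalized the same way, so the raw scores are cosine similarities in [-1, 1], and a softmax over such a narrow range is nearly uniform. Per the note, the constant is the teacher CLIP's temperature, which spreads the logits out before the softmax. A minimal sketch of the effect (the similarity values below are made up for illustration):

import torch

# hypothetical cosine similarities of one image against three class prompts
sims = torch.tensor([[0.28, 0.25, 0.21]])

print(sims.softmax(dim=1))             # nearly uniform, roughly [0.34, 0.33, 0.32]
print((sims * 100.0).softmax(dim=1))   # sharply peaked on the first class, roughly [0.95, 0.05, 0.00]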
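Most of `_init_weights` falls outside the hunks above; only the loop over classes and the final `torch.stack`/`register_buffer` appear as context lines. For reference, here is a sketch of the standard CLIP prompt-ensembling recipe such a method typically follows (embed every template for a class, average, renormalize). This is an illustration under those assumptions, with a hypothetical helper name, not the Space's exact code; `model` and `processor` are assumed to be a CLIPModel/CLIPProcessor pair.

import torch

@torch.no_grad()
def build_zero_shot_weights(model, processor, classes, templates):
    # one averaged, renormalized text embedding per class
    weights = []
    for classname in classes:
        prompts = [template.format(classname) for template in templates]
        inputs = processor(text=prompts, return_tensors='pt', padding=True)
        embeddings = model.get_text_features(**inputs)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # normalize each prompt embedding
        class_embedding = embeddings.mean(dim=0)                         # average over templates
        class_embedding = class_embedding / class_embedding.norm()       # renormalize the ensemble
        weights.append(class_embedding)
    return torch.stack(weights)

The resulting (num_classes, embed_dim) matrix plays the role of `self.weights`, which `forward` multiplies the normalized image features against via `x.mm(self.weights.t())`.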