Spaces: Runtime error

Update app.py

app.py CHANGED
@@ -1,14 +1,3 @@
-import torch
-from torch.utils.data import DataLoader
-from datasets import load_dataset, IterableDataset
-import evaluate
-# from sklearn.linear_model import LogisticRegression
-import webdataset as wds
-from tqdm.auto import tqdm
-import lovely_tensors as lt
-
-lt.monkey_patch()
-
 from transformers import CLIPModel, CLIPProcessor
 
 MODEL_ID = 'openai/clip-vit-base-patch32' #@param {'type': 'string'}
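For reference, a minimal sketch of the model setup these surviving imports feed into, with pipeline assumed to be imported from transformers alongside the CLIP classes (the second hunk's context line shows only the processor construction, and the pipe line shows how its parts are reused):

from transformers import CLIPModel, CLIPProcessor, pipeline

MODEL_ID = 'openai/clip-vit-base-patch32'

# Load the CLIP checkpoint and its paired processor once at startup;
# the processor carries the image processor and tokenizer the pipe reuses.
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)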
@@ -112,16 +101,133 @@ processor = CLIPProcessor.from_pretrained(MODEL_ID)
 
 pipe = pipeline("zero-shot-image-classification", model=model, feature_extractor=processor.image_processor, tokenizer=processor.tokenizer)
 
+cifar_templates = [
+    'a photo of a {}.',
+    'a blurry photo of a {}.',
+    'a black and white photo of a {}.',
+    'a low contrast photo of a {}.',
+    'a high contrast photo of a {}.',
+    'a bad photo of a {}.',
+    'a good photo of a {}.',
+    'a photo of a small {}.',
+    'a photo of a big {}.',
+    'a photo of the {}.',
+    'a blurry photo of the {}.',
+    'a black and white photo of the {}.',
+    'a low contrast photo of the {}.',
+    'a high contrast photo of the {}.',
+    'a bad photo of the {}.',
+    'a good photo of the {}.',
+    'a photo of the small {}.',
+    'a photo of the big {}.',
+]
+
+imagenet_templates = [
+    'a bad photo of a {}.',
+    'a photo of many {}.',
+    'a sculpture of a {}.',
+    'a photo of the hard to see {}.',
+    'a low resolution photo of the {}.',
+    'a rendering of a {}.',
+    'graffiti of a {}.',
+    'a bad photo of the {}.',
+    'a cropped photo of the {}.',
+    'a tattoo of a {}.',
+    'the embroidered {}.',
+    'a photo of a hard to see {}.',
+    'a bright photo of a {}.',
+    'a photo of a clean {}.',
+    'a photo of a dirty {}.',
+    'a dark photo of the {}.',
+    'a drawing of a {}.',
+    'a photo of my {}.',
+    'the plastic {}.',
+    'a photo of the cool {}.',
+    'a close-up photo of a {}.',
+    'a black and white photo of the {}.',
+    'a painting of the {}.',
+    'a painting of a {}.',
+    'a pixelated photo of the {}.',
+    'a sculpture of the {}.',
+    'a bright photo of the {}.',
+    'a cropped photo of a {}.',
+    'a plastic {}.',
+    'a photo of the dirty {}.',
+    'a jpeg corrupted photo of a {}.',
+    'a blurry photo of the {}.',
+    'a photo of the {}.',
+    'a good photo of the {}.',
+    'a rendering of the {}.',
+    'a {} in a video game.',
+    'a photo of one {}.',
+    'a doodle of a {}.',
+    'a close-up photo of the {}.',
+    'a photo of a {}.',
+    'the origami {}.',
+    'the {} in a video game.',
+    'a sketch of a {}.',
+    'a doodle of the {}.',
+    'a origami {}.',
+    'a low resolution photo of a {}.',
+    'the toy {}.',
+    'a rendition of the {}.',
+    'a photo of the clean {}.',
+    'a photo of a large {}.',
+    'a rendition of a {}.',
+    'a photo of a nice {}.',
+    'a photo of a weird {}.',
+    'a blurry photo of a {}.',
+    'a cartoon {}.',
+    'art of a {}.',
+    'a sketch of the {}.',
+    'a embroidered {}.',
+    'a pixelated photo of a {}.',
+    'itap of the {}.',
+    'a jpeg corrupted photo of the {}.',
+    'a good photo of a {}.',
+    'a plushie {}.',
+    'a photo of the nice {}.',
+    'a photo of the small {}.',
+    'a photo of the weird {}.',
+    'the cartoon {}.',
+    'art of the {}.',
+    'a drawing of the {}.',
+    'a photo of the large {}.',
+    'a black and white photo of a {}.',
+    'the plushie {}.',
+    'a dark photo of a {}.',
+    'itap of a {}.',
+    'graffiti of the {}.',
+    'a toy {}.',
+    'itap of my {}.',
+    'a photo of a cool {}.',
+    'a photo of a small {}.',
+    'a tattoo of the {}.',
+]
+
+dashcam_templates = [
+    'a dashcam recording of {}.',
+    'a picture of {}.',
+    'a recording of {}.'
+]
+
+stl10_templates = [
+    'a photo of a {}.',
+    'a photo of the {}.',
+]
+
+oxfordpets_templates = [
+    'a photo of a {}, a type of pet.',
+]
+
 def predict(image, texts):
     texts = texts.split(', ')
     out = pipe(image, candidate_labels=texts)
     return {d['label']: d['score'] for d in out}
 
-import gradio as gr
-
 demo = gr.Interface(
     fn=predict,
-    inputs=[gr.Image(type='pil'), gr.Textbox(label='comma separated labels')],
+    inputs=[gr.Image(type='pil'), gr.Textbox(label='comma separated labels'), gr.Dropwdown(['CIFAR', 'ImageNet','STL-10', 'Oxford Pets', 'Dashcam'], label='text templates')],
     outputs='label',
 )
 
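The Space's Runtime error status is consistent with three problems visible in the new code: gr.Dropwdown looks like a typo for gr.Dropdown, the hunk deletes import gradio as gr without a visible replacement while gr.Interface is still called, and predict still accepts two arguments while the Interface now supplies three inputs. The added template lists are also never used by predict. Below is a minimal sketch of how these pieces could fit together, not the author's actual fix: it assumes the model, pipe, and template definitions above, reinstates the gradio import, corrects the Dropdown name, threads the dropdown choice into predict via a hypothetical template_sets mapping, and averages pipeline scores over each template via the pipeline's hypothesis_template argument as a simple stand-in for CLIP-style prompt ensembling. The closing demo.launch() is likewise an assumption; the diff does not show it.

import gradio as gr  # reinstated: the commit removed this import but gr is still used below

# Hypothetical helper mapping the dropdown choices to the template lists
# defined earlier in app.py.
template_sets = {
    'CIFAR': cifar_templates,
    'ImageNet': imagenet_templates,
    'STL-10': stl10_templates,
    'Oxford Pets': oxfordpets_templates,
    'Dashcam': dashcam_templates,
}

def predict(image, texts, template_set):
    labels = texts.split(', ')
    templates = template_sets[template_set]
    # Score the labels once per template and average the per-label scores,
    # so every prompt in the chosen set contributes to the final ranking.
    scores = {label: 0.0 for label in labels}
    for template in templates:
        out = pipe(image, candidate_labels=labels, hypothesis_template=template)
        for d in out:
            scores[d['label']] += d['score'] / len(templates)
    return scores

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type='pil'),
        gr.Textbox(label='comma separated labels'),
        gr.Dropdown(['CIFAR', 'ImageNet', 'STL-10', 'Oxford Pets', 'Dashcam'],
                    label='text templates'),
    ],
    outputs='label',
)

demo.launch()  # assumed entry point; not shown in the diff

Averaging per-template scores keeps the existing pipeline untouched; true prompt ensembling in the CLIP style averages text embeddings per class before the softmax, which would require calling the model directly rather than going through the pipeline.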