ryanramos committed
Commit b66e6ca · 1 Parent(s): 36c4b4b

Update app.py

Files changed (1)
  1. app.py +120 -14
app.py CHANGED
@@ -1,14 +1,3 @@
-import torch
-from torch.utils.data import DataLoader
-from datasets import load_dataset, IterableDataset
-import evaluate
-# from sklearn.linear_model import LogisticRegression
-import webdataset as wds
-from tqdm.auto import tqdm
-import lovely_tensors as lt
-
-lt.monkey_patch()
-
 from transformers import CLIPModel, CLIPProcessor
 
 MODEL_ID = 'openai/clip-vit-base-patch32' #@param {'type': 'string'}
@@ -112,16 +101,133 @@ processor = CLIPProcessor.from_pretrained(MODEL_ID)
 
 pipe = pipeline("zero-shot-image-classification", model=model, feature_extractor=processor.image_processor, tokenizer=processor.tokenizer)
 
+cifar_templates = [
+    'a photo of a {}.',
+    'a blurry photo of a {}.',
+    'a black and white photo of a {}.',
+    'a low contrast photo of a {}.',
+    'a high contrast photo of a {}.',
+    'a bad photo of a {}.',
+    'a good photo of a {}.',
+    'a photo of a small {}.',
+    'a photo of a big {}.',
+    'a photo of the {}.',
+    'a blurry photo of the {}.',
+    'a black and white photo of the {}.',
+    'a low contrast photo of the {}.',
+    'a high contrast photo of the {}.',
+    'a bad photo of the {}.',
+    'a good photo of the {}.',
+    'a photo of the small {}.',
+    'a photo of the big {}.',
+]
+
+imagenet_templates = [
+    'a bad photo of a {}.',
+    'a photo of many {}.',
+    'a sculpture of a {}.',
+    'a photo of the hard to see {}.',
+    'a low resolution photo of the {}.',
+    'a rendering of a {}.',
+    'graffiti of a {}.',
+    'a bad photo of the {}.',
+    'a cropped photo of the {}.',
+    'a tattoo of a {}.',
+    'the embroidered {}.',
+    'a photo of a hard to see {}.',
+    'a bright photo of a {}.',
+    'a photo of a clean {}.',
+    'a photo of a dirty {}.',
+    'a dark photo of the {}.',
+    'a drawing of a {}.',
+    'a photo of my {}.',
+    'the plastic {}.',
+    'a photo of the cool {}.',
+    'a close-up photo of a {}.',
+    'a black and white photo of the {}.',
+    'a painting of the {}.',
+    'a painting of a {}.',
+    'a pixelated photo of the {}.',
+    'a sculpture of the {}.',
+    'a bright photo of the {}.',
+    'a cropped photo of a {}.',
+    'a plastic {}.',
+    'a photo of the dirty {}.',
+    'a jpeg corrupted photo of a {}.',
+    'a blurry photo of the {}.',
+    'a photo of the {}.',
+    'a good photo of the {}.',
+    'a rendering of the {}.',
+    'a {} in a video game.',
+    'a photo of one {}.',
+    'a doodle of a {}.',
+    'a close-up photo of the {}.',
+    'a photo of a {}.',
+    'the origami {}.',
+    'the {} in a video game.',
+    'a sketch of a {}.',
+    'a doodle of the {}.',
+    'a origami {}.',
+    'a low resolution photo of a {}.',
+    'the toy {}.',
+    'a rendition of the {}.',
+    'a photo of the clean {}.',
+    'a photo of a large {}.',
+    'a rendition of a {}.',
+    'a photo of a nice {}.',
+    'a photo of a weird {}.',
+    'a blurry photo of a {}.',
+    'a cartoon {}.',
+    'art of a {}.',
+    'a sketch of the {}.',
+    'a embroidered {}.',
+    'a pixelated photo of a {}.',
+    'itap of the {}.',
+    'a jpeg corrupted photo of the {}.',
+    'a good photo of a {}.',
+    'a plushie {}.',
+    'a photo of the nice {}.',
+    'a photo of the small {}.',
+    'a photo of the weird {}.',
+    'the cartoon {}.',
+    'art of the {}.',
+    'a drawing of the {}.',
+    'a photo of the large {}.',
+    'a black and white photo of a {}.',
+    'the plushie {}.',
+    'a dark photo of a {}.',
+    'itap of a {}.',
+    'graffiti of the {}.',
+    'a toy {}.',
+    'itap of my {}.',
+    'a photo of a cool {}.',
+    'a photo of a small {}.',
+    'a tattoo of the {}.',
+]
+
+dashcam_templates = [
+    'a dashcam recording of {}.',
+    'a picture of {}.',
+    'a recording of {}.'
+]
+
+stl10_templates = [
+    'a photo of a {}.',
+    'a photo of the {}.',
+]
+
+oxfordpets_templates = [
+    'a photo of a {}, a type of pet.',
+]
+
 def predict(image, texts):
     texts = texts.split(', ')
     out = pipe(image, candidate_labels=texts)
     return {d['label']: d['score'] for d in out}
 
-import gradio as gr
-
 demo = gr.Interface(
     fn=predict,
-    inputs=[gr.Image(type='pil'), gr.Textbox(label='comma separated labels')],
+    inputs=[gr.Image(type='pil'), gr.Textbox(label='comma separated labels'), gr.Dropdown(['CIFAR', 'ImageNet', 'STL-10', 'Oxford Pets', 'Dashcam'], label='text templates')],
     outputs='label',
 )
 
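Note on the change above: the Interface now has three inputs, but predict still accepts only (image, texts), so Gradio will pass the dropdown's value as a third argument and fail; the new template lists are also never read, and import gradio as gr is removed without a visible replacement. Below is a minimal sketch of how the dropdown could be wired into predict. The TEMPLATE_SETS mapping and the score-averaging loop are assumptions, not part of this commit; hypothesis_template is the zero-shot pipeline's stock parameter for formatting candidate labels.

import gradio as gr  # re-adds the import this commit drops, unless it appears in the unshown lines

# Hypothetical mapping from the dropdown choices to the template lists
# added above; the name TEMPLATE_SETS is not part of the commit.
TEMPLATE_SETS = {
    'CIFAR': cifar_templates,
    'ImageNet': imagenet_templates,
    'STL-10': stl10_templates,
    'Oxford Pets': oxfordpets_templates,
    'Dashcam': dashcam_templates,
}

def predict(image, texts, template_set):
    labels = texts.split(', ')
    templates = TEMPLATE_SETS[template_set]
    # One pipeline call per template, averaging each label's score across
    # templates. This is score-level ensembling; the CLIP paper averages
    # text embeddings before the softmax, so treat this as an approximation.
    scores = {label: 0.0 for label in labels}
    for template in templates:
        for d in pipe(image, candidate_labels=labels, hypothesis_template=template):
            scores[d['label']] += d['score'] / len(templates)
    return scores

Gradio passes one argument per input component, so the Dropdown's string arrives as template_set, and the returned dict still matches the 'label' output.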