egrace479 commited on
Commit
16df1e7
·
1 Parent(s): 1ca4d24

Add sample image return functionality to app

Browse files

pulls sample image from EOL subset of TreeOfLife-10M

Files changed (1) hide show
  1. app.py +87 -44
app.py CHANGED
@@ -6,12 +6,14 @@ import logging
6
 
7
  import gradio as gr
8
  import numpy as np
 
9
  import torch
10
  import torch.nn.functional as F
11
  from open_clip import create_model, get_tokenizer
12
  from torchvision import transforms
13
 
14
  from templates import openai_imagenet_template
 
15
 
16
  log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
17
  logging.basicConfig(level=logging.INFO, format=log_format)
@@ -19,6 +21,12 @@ logger = logging.getLogger()
19
 
20
  hf_token = os.getenv("HF_TOKEN")
21
 
 
 
 
 
 
 
22
  model_str = "hf-hub:imageomics/bioclip"
23
  tokenizer_str = "ViT-B-16"
24
 
@@ -123,12 +131,14 @@ def format_name(taxon, common):
123
 
124
 
125
  @torch.no_grad()
126
- def open_domain_classification(img, rank: int) -> dict[str, float]:
127
  """
128
  Predicts from the entire tree of life.
129
  If targeting a higher rank than species, then this function predicts among all
130
  species, then sums up species-level probabilities for the given rank.
131
  """
 
 
132
  img = preprocess_img(img).to(device)
133
  img_features = model.encode_image(img.unsqueeze(0))
134
  img_features = F.normalize(img_features, dim=-1)
@@ -136,21 +146,36 @@ def open_domain_classification(img, rank: int) -> dict[str, float]:
136
  logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
137
  probs = F.softmax(logits, dim=0)
138
 
139
- # If predicting species, no need to sum probabilities.
140
  if rank + 1 == len(ranks):
141
  topk = probs.topk(k)
142
- return {
143
  format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
144
  }
 
 
 
 
 
 
 
145
 
146
- # Sum up by the rank
147
  output = collections.defaultdict(float)
148
  for i in torch.nonzero(probs > min_prob).squeeze():
149
  output[" ".join(txt_names[i][0][: rank + 1])] += probs[i]
150
 
151
  topk_names = heapq.nlargest(k, output, key=output.get)
 
 
 
 
 
 
 
 
152
 
153
- return {name: output[name] for name in topk_names}
 
 
154
 
155
 
156
  def change_output(choice):
@@ -179,9 +204,22 @@ if __name__ == "__main__":
179
  status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"
180
 
181
  with gr.Blocks() as app:
182
- img_input = gr.Image()
183
-
184
  with gr.Tab("Open-Ended"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  with gr.Row():
186
  with gr.Column():
187
  rank_dropdown = gr.Dropdown(
@@ -199,32 +237,36 @@ if __name__ == "__main__":
199
  show_label=True,
200
  value=None,
201
  )
202
- # open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
203
 
204
  with gr.Row():
205
  gr.Examples(
206
  examples=open_domain_examples,
207
  inputs=[img_input, rank_dropdown],
208
  cache_examples=True,
209
- fn=open_domain_classification,
210
  outputs=[open_domain_output],
211
  )
212
-
213
- # open_domain_callback = gr.HuggingFaceDatasetSaver(
214
- # hf_token, "imageomics/bioclip-demo-open-domain-mistakes", private=True
215
- # )
216
- # open_domain_callback.setup(
217
- # [img_input, rank_dropdown, open_domain_output],
218
- # flagging_dir="logs/flagged",
219
- # )
220
- # open_domain_flag_btn.click(
221
- # lambda *args: open_domain_callback.flag(args),
222
- # [img_input, rank_dropdown, open_domain_output],
223
- # None,
224
- # preprocess=False,
225
- # )
226
-
 
227
  with gr.Tab("Zero-Shot"):
 
 
 
228
  with gr.Row():
229
  with gr.Column():
230
  classes_txt = gr.Textbox(
@@ -240,43 +282,44 @@ if __name__ == "__main__":
240
  zero_shot_output = gr.Label(
241
  num_top_classes=k, label="Prediction", show_label=True
242
  )
243
- # zero_shot_flag_btn = gr.Button("Flag Mistake", variant="primary")
244
 
245
  with gr.Row():
246
  gr.Examples(
247
  examples=zero_shot_examples,
248
- inputs=[img_input, classes_txt],
249
  cache_examples=True,
250
  fn=zero_shot_classification,
251
  outputs=[zero_shot_output],
252
  )
253
-
254
- # zero_shot_callback = gr.HuggingFaceDatasetSaver(
255
- # hf_token, "imageomics/bioclip-demo-zero-shot-mistakes", private=True
256
- # )
257
- # zero_shot_callback.setup(
258
- # [img_input, zero_shot_output], flagging_dir="logs/flagged"
259
- # )
260
- # zero_shot_flag_btn.click(
261
- # lambda *args: zero_shot_callback.flag(args),
262
- # [img_input, zero_shot_output],
263
- # None,
264
- # preprocess=False,
265
- # )
266
-
 
267
  rank_dropdown.change(
268
  fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
269
  )
270
 
271
  open_domain_btn.click(
272
- fn=open_domain_classification,
273
  inputs=[img_input, rank_dropdown],
274
- outputs=[open_domain_output],
275
  )
276
 
277
  zero_shot_btn.click(
278
  fn=zero_shot_classification,
279
- inputs=[img_input, classes_txt],
280
  outputs=zero_shot_output,
281
  )
282
 
@@ -291,4 +334,4 @@ if __name__ == "__main__":
291
  )
292
 
293
  app.queue(max_size=20)
294
- app.launch()
 
6
 
7
  import gradio as gr
8
  import numpy as np
9
+ import polars as pl
10
  import torch
11
  import torch.nn.functional as F
12
  from open_clip import create_model, get_tokenizer
13
  from torchvision import transforms
14
 
15
  from templates import openai_imagenet_template
16
+ from components.query import get_sample
17
 
18
  log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
19
  logging.basicConfig(level=logging.INFO, format=log_format)
 
21
 
22
  hf_token = os.getenv("HF_TOKEN")
23
 
24
+ # For sample images
25
+ METADATA_PATH = "components/metadata.csv"
26
+ # Read page ID as int and filter out smaller ablation duplicated training split
27
+ metadata_df = pl.read_csv(METADATA_PATH, low_memory = False)
28
+ metadata_df = metadata_df.with_columns(pl.col("eol_page_id").cast(pl.Int64))
29
+
30
  model_str = "hf-hub:imageomics/bioclip"
31
  tokenizer_str = "ViT-B-16"
32
 
 
131
 
132
 
133
  @torch.no_grad()
134
+ def open_domain_classification(img, rank: int, return_all=False):
135
  """
136
  Predicts from the entire tree of life.
137
  If targeting a higher rank than species, then this function predicts among all
138
  species, then sums up species-level probabilities for the given rank.
139
  """
140
+
141
+ logger.info(f"Starting open domain classification for rank: {rank}")
142
  img = preprocess_img(img).to(device)
143
  img_features = model.encode_image(img.unsqueeze(0))
144
  img_features = F.normalize(img_features, dim=-1)
 
146
  logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
147
  probs = F.softmax(logits, dim=0)
148
 
 
149
  if rank + 1 == len(ranks):
150
  topk = probs.topk(k)
151
+ prediction_dict = {
152
  format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
153
  }
154
+ logger.info(f"Top K predictions: {prediction_dict}")
155
+ top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
156
+ logger.info(f"Top prediction name: {top_prediction_name}")
157
+ sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
158
+ if return_all:
159
+ return prediction_dict, sample_img, taxon_url
160
+ return prediction_dict
161
 
 
162
  output = collections.defaultdict(float)
163
  for i in torch.nonzero(probs > min_prob).squeeze():
164
  output[" ".join(txt_names[i][0][: rank + 1])] += probs[i]
165
 
166
  topk_names = heapq.nlargest(k, output, key=output.get)
167
+ prediction_dict = {name: output[name] for name in topk_names}
168
+ logger.info(f"Top K names for output: {topk_names}")
169
+ logger.info(f"Prediction dictionary: {prediction_dict}")
170
+
171
+ top_prediction_name = topk_names[0]
172
+ logger.info(f"Top prediction name: {top_prediction_name}")
173
+ sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
174
+ logger.info(f"Sample image and taxon URL: {sample_img}, {taxon_url}")
175
 
176
+ if return_all:
177
+ return prediction_dict, sample_img, taxon_url
178
+ return prediction_dict
179
 
180
 
181
  def change_output(choice):
 
204
  status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"
205
 
206
  with gr.Blocks() as app:
207
+
 
208
  with gr.Tab("Open-Ended"):
209
+ with gr.Row(variant = "panel", elem_id = "images_panel"):
210
+ with gr.Column():
211
+ img_input = gr.Image(height = 400, sources=["upload"])
212
+
213
+ with gr.Column():
214
+ # display sample image of top predicted taxon
215
+ sample_img = gr.Image(label = "Sample Image of Predicted Taxon",
216
+ height = 400,
217
+ show_download_button = False)
218
+
219
+ taxon_url = gr.HTML(label = "More Information",
220
+ elem_id = "url"
221
+ )
222
+
223
  with gr.Row():
224
  with gr.Column():
225
  rank_dropdown = gr.Dropdown(
 
237
  show_label=True,
238
  value=None,
239
  )
240
+ # open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
241
 
242
  with gr.Row():
243
  gr.Examples(
244
  examples=open_domain_examples,
245
  inputs=[img_input, rank_dropdown],
246
  cache_examples=True,
247
+ fn=lambda img, rank: open_domain_classification(img, rank, return_all=False),
248
  outputs=[open_domain_output],
249
  )
250
+ '''
251
+ # Flagging Code
252
+ open_domain_callback = gr.HuggingFaceDatasetSaver(
253
+ hf_token, "bioclip-demo-open-domain-mistakes", private=True
254
+ )
255
+ open_domain_callback.setup(
256
+ [img_input, rank_dropdown, open_domain_output],
257
+ flagging_dir="bioclip-demo-open-domain-mistakes/logs/flagged",
258
+ )
259
+ open_domain_flag_btn.click(
260
+ lambda *args: open_domain_callback.flag(args),
261
+ [img_input, rank_dropdown, open_domain_output],
262
+ None,
263
+ preprocess=False,
264
+ )
265
+ '''
266
  with gr.Tab("Zero-Shot"):
267
+ with gr.Row():
268
+ img_input_zs = gr.Image(height = 400, sources=["upload"])
269
+
270
  with gr.Row():
271
  with gr.Column():
272
  classes_txt = gr.Textbox(
 
282
  zero_shot_output = gr.Label(
283
  num_top_classes=k, label="Prediction", show_label=True
284
  )
285
+ # zero_shot_flag_btn = gr.Button("Flag Mistake", variant="primary")
286
 
287
  with gr.Row():
288
  gr.Examples(
289
  examples=zero_shot_examples,
290
+ inputs=[img_input_zs, classes_txt],
291
  cache_examples=True,
292
  fn=zero_shot_classification,
293
  outputs=[zero_shot_output],
294
  )
295
+ '''
296
+ # Flagging Code
297
+ zero_shot_callback = gr.HuggingFaceDatasetSaver(
298
+ hf_token, "bioclip-demo-zero-shot-mistakes", private=True
299
+ )
300
+ zero_shot_callback.setup(
301
+ [img_input, zero_shot_output], flagging_dir="bioclip-demo-zero-shot-mistakes/logs/flagged"
302
+ )
303
+ zero_shot_flag_btn.click(
304
+ lambda *args: zero_shot_callback.flag(args),
305
+ [img_input, zero_shot_output],
306
+ None,
307
+ preprocess=False,
308
+ )
309
+ '''
310
  rank_dropdown.change(
311
  fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
312
  )
313
 
314
  open_domain_btn.click(
315
+ fn=lambda img, rank: open_domain_classification(img, rank, return_all=True),
316
  inputs=[img_input, rank_dropdown],
317
+ outputs=[open_domain_output, sample_img, taxon_url],
318
  )
319
 
320
  zero_shot_btn.click(
321
  fn=zero_shot_classification,
322
+ inputs=[img_input_zs, classes_txt],
323
  outputs=zero_shot_output,
324
  )
325
 
 
334
  )
335
 
336
  app.queue(max_size=20)
337
+ app.launch(share=True)