Yahia battach committed on
Commit
016de46
1 Parent(s): 7272ff8

edit app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -17
app.py CHANGED
@@ -129,6 +129,53 @@ def format_name(taxon, common):
129
  return f"{taxon} ({common})"
130
 
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  @torch.no_grad()
133
  def open_domain_classification(img, rank: int, return_all=False):
134
  """
@@ -136,7 +183,6 @@ def open_domain_classification(img, rank: int, return_all=False):
136
  If targeting a higher rank than species, then this function predicts among all
137
  species, then sums up species-level probabilities for the given rank.
138
  """
139
-
140
  logger.info(f"Starting open domain classification for rank: {rank}")
141
  img = preprocess_img(img).to(device)
142
  img_features = model.encode_image(img.unsqueeze(0))
@@ -148,15 +194,13 @@ def open_domain_classification(img, rank: int, return_all=False):
148
  if rank + 1 == len(ranks):
149
  topk = probs.topk(k)
150
  prediction_dict = {
151
- format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
152
  }
153
  logger.info(f"Top K predictions: {prediction_dict}")
154
- top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
155
- logger.info(f"Top prediction name: {top_prediction_name}")
156
- sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
157
  if return_all:
158
- return prediction_dict, sample_img, taxon_url
159
- return prediction_dict
160
 
161
  output = collections.defaultdict(float)
162
  for i in torch.nonzero(probs > min_prob).squeeze():
@@ -165,18 +209,11 @@ def open_domain_classification(img, rank: int, return_all=False):
165
  topk_names = heapq.nlargest(k, output, key=output.get)
166
  prediction_dict = {name: output[name] for name in topk_names}
167
  logger.info(f"Top K names for output: {topk_names}")
168
- logger.info(f"Prediction dictionary: {prediction_dict}")
169
-
170
- top_prediction_name = topk_names[0]
171
- logger.info(f"Top prediction name: {top_prediction_name}")
172
- sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
173
- logger.info(f"Sample image and taxon URL: {sample_img}, {taxon_url}")
174
-
175
  if return_all:
176
- return prediction_dict, sample_img, taxon_url
177
  return prediction_dict
178
 
179
-
180
  def change_output(choice):
181
  return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
182
 
@@ -310,12 +347,19 @@ if __name__ == "__main__":
310
  fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
311
  )
312
 
 
 
 
 
 
 
313
  open_domain_btn.click(
314
- fn=lambda img, rank: open_domain_classification(img, rank, return_all=True),
315
  inputs=[img_input, rank_dropdown],
316
  outputs=[open_domain_output],
317
  )
318
 
 
319
  zero_shot_btn.click(
320
  fn=zero_shot_classification,
321
  inputs=[img_input_zs, classes_txt],
 
129
  return f"{taxon} ({common})"
130
 
131
 
132
+ # @torch.no_grad()
133
+ # def open_domain_classification(img, rank: int, return_all=False):
134
+ # """
135
+ # Predicts from the entire tree of life.
136
+ # If targeting a higher rank than species, then this function predicts among all
137
+ # species, then sums up species-level probabilities for the given rank.
138
+ # """
139
+
140
+ # logger.info(f"Starting open domain classification for rank: {rank}")
141
+ # img = preprocess_img(img).to(device)
142
+ # img_features = model.encode_image(img.unsqueeze(0))
143
+ # img_features = F.normalize(img_features, dim=-1)
144
+
145
+ # logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
146
+ # probs = F.softmax(logits, dim=0)
147
+
148
+ # if rank + 1 == len(ranks):
149
+ # topk = probs.topk(k)
150
+ # prediction_dict = {
151
+ # format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
152
+ # }
153
+ # logger.info(f"Top K predictions: {prediction_dict}")
154
+ # top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
155
+ # logger.info(f"Top prediction name: {top_prediction_name}")
156
+ # sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
157
+ # if return_all:
158
+ # return prediction_dict, sample_img, taxon_url
159
+ # return prediction_dict
160
+
161
+ # output = collections.defaultdict(float)
162
+ # for i in torch.nonzero(probs > min_prob).squeeze():
163
+ # output[" ".join(txt_names[i][0][: rank + 1])] += probs[i]
164
+
165
+ # topk_names = heapq.nlargest(k, output, key=output.get)
166
+ # prediction_dict = {name: output[name] for name in topk_names}
167
+ # logger.info(f"Top K names for output: {topk_names}")
168
+ # logger.info(f"Prediction dictionary: {prediction_dict}")
169
+
170
+ # top_prediction_name = topk_names[0]
171
+ # logger.info(f"Top prediction name: {top_prediction_name}")
172
+ # sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
173
+ # logger.info(f"Sample image and taxon URL: {sample_img}, {taxon_url}")
174
+
175
+ # if return_all:
176
+ # return prediction_dict, sample_img, taxon_url
177
+ # return prediction_dict
178
+
179
  @torch.no_grad()
180
  def open_domain_classification(img, rank: int, return_all=False):
181
  """
 
183
  If targeting a higher rank than species, then this function predicts among all
184
  species, then sums up species-level probabilities for the given rank.
185
  """
 
186
  logger.info(f"Starting open domain classification for rank: {rank}")
187
  img = preprocess_img(img).to(device)
188
  img_features = model.encode_image(img.unsqueeze(0))
 
194
  if rank + 1 == len(ranks):
195
  topk = probs.topk(k)
196
  prediction_dict = {
197
+ format_name(*txt_names[i]): prob.item() for i, prob in zip(topk.indices, topk.values)
198
  }
199
  logger.info(f"Top K predictions: {prediction_dict}")
200
+
 
 
201
  if return_all:
202
+ return prediction_dict, None, None # Return dummy None values for unused parts
203
+ return prediction_dict # Only return the dictionary for the Label component
204
 
205
  output = collections.defaultdict(float)
206
  for i in torch.nonzero(probs > min_prob).squeeze():
 
209
  topk_names = heapq.nlargest(k, output, key=output.get)
210
  prediction_dict = {name: output[name] for name in topk_names}
211
  logger.info(f"Top K names for output: {topk_names}")
212
+
 
 
 
 
 
 
213
  if return_all:
214
+ return prediction_dict, None, None
215
  return prediction_dict
216
 
 
217
  def change_output(choice):
218
  return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
219
 
 
347
  fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
348
  )
349
 
350
+ # open_domain_btn.click(
351
+ # fn=lambda img, rank: open_domain_classification(img, rank, return_all=True),
352
+ # inputs=[img_input, rank_dropdown],
353
+ # outputs=[open_domain_output],
354
+ # )
355
+
356
  open_domain_btn.click(
357
+ fn=lambda img, rank: open_domain_classification(img, rank, return_all=False),
358
  inputs=[img_input, rank_dropdown],
359
  outputs=[open_domain_output],
360
  )
361
 
362
+
363
  zero_shot_btn.click(
364
  fn=zero_shot_classification,
365
  inputs=[img_input_zs, classes_txt],