egrace479 committed
Commit e592678
1 Parent(s): e04fcaa

Sample image return (#8)


- Add a "to learn more" URL at the bottom of the app linking to GitHub and the TreeOfLife-10M dataset (6f2766a9fbe1c26c96934dfe41049e05f7ae2cae)
- Update requirements for sample image return (9e5fdea1a9b91ad590f961bf1a2cc028d4a6d161)
- Add components for sample image return (1ca4d249a8b99b33200cf026163194de98e167cd)
- Add sample image return functionality to app (16df1e71bfc77b4a496b760a2161d3242177f62a)

.gitattributes CHANGED
@@ -33,7 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-
 *.json filter=lfs diff=lfs merge=lfs -text
 *.jpeg filter=lfs diff=lfs merge=lfs -text
 *.png filter=lfs diff=lfs merge=lfs -text
+components/metadata.csv filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -6,12 +6,14 @@ import logging
 
 import gradio as gr
 import numpy as np
+import polars as pl
 import torch
 import torch.nn.functional as F
 from open_clip import create_model, get_tokenizer
 from torchvision import transforms
 
 from templates import openai_imagenet_template
+from components.query import get_sample
 
 log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=log_format)
@@ -19,6 +21,12 @@ logger = logging.getLogger()
 
 hf_token = os.getenv("HF_TOKEN")
 
+# For sample images
+METADATA_PATH = "components/metadata.csv"
+# Read page IDs as integers and filter out the smaller ablation-duplicated training split
+metadata_df = pl.read_csv(METADATA_PATH, low_memory=False)
+metadata_df = metadata_df.with_columns(pl.col("eol_page_id").cast(pl.Int64))
+
 model_str = "hf-hub:imageomics/bioclip"
 tokenizer_str = "ViT-B-16"
 
@@ -123,12 +131,14 @@ def format_name(taxon, common):
 
 
 @torch.no_grad()
-def open_domain_classification(img, rank: int) -> dict[str, float]:
+def open_domain_classification(img, rank: int, return_all=False):
     """
     Predicts from the entire tree of life.
     If targeting a higher rank than species, then this function predicts among all
     species, then sums up species-level probabilities for the given rank.
     """
+
+    logger.info(f"Starting open domain classification for rank: {rank}")
     img = preprocess_img(img).to(device)
     img_features = model.encode_image(img.unsqueeze(0))
     img_features = F.normalize(img_features, dim=-1)
@@ -136,21 +146,36 @@ def open_domain_classification(img, rank: int) -> dict[str, float]:
     logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
     probs = F.softmax(logits, dim=0)
 
-    # If predicting species, no need to sum probabilities.
     if rank + 1 == len(ranks):
         topk = probs.topk(k)
-        return {
+        prediction_dict = {
            format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
         }
+        logger.info(f"Top K predictions: {prediction_dict}")
+        top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
+        logger.info(f"Top prediction name: {top_prediction_name}")
+        sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
+        if return_all:
+            return prediction_dict, sample_img, taxon_url
+        return prediction_dict
 
-    # Sum up by the rank
     output = collections.defaultdict(float)
     for i in torch.nonzero(probs > min_prob).squeeze():
         output[" ".join(txt_names[i][0][: rank + 1])] += probs[i]
 
     topk_names = heapq.nlargest(k, output, key=output.get)
+    prediction_dict = {name: output[name] for name in topk_names}
+    logger.info(f"Top K names for output: {topk_names}")
+    logger.info(f"Prediction dictionary: {prediction_dict}")
+
+    top_prediction_name = topk_names[0]
+    logger.info(f"Top prediction name: {top_prediction_name}")
+    sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
+    logger.info(f"Sample image and taxon URL: {sample_img}, {taxon_url}")
 
-    return {name: output[name] for name in topk_names}
+    if return_all:
+        return prediction_dict, sample_img, taxon_url
+    return prediction_dict
 
 
 def change_output(choice):
@@ -179,9 +204,22 @@ if __name__ == "__main__":
    status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"
 
    with gr.Blocks() as app:
-        img_input = gr.Image()
-
+
        with gr.Tab("Open-Ended"):
+            with gr.Row(variant="panel", elem_id="images_panel"):
+                with gr.Column():
+                    img_input = gr.Image(height=400, sources=["upload"])
+
+                with gr.Column():
+                    # Display a sample image of the top predicted taxon
+                    sample_img = gr.Image(label="Sample Image of Predicted Taxon",
+                                          height=400,
+                                          show_download_button=False)
+
+                    taxon_url = gr.HTML(label="More Information",
+                                        elem_id="url"
+                                        )
+
            with gr.Row():
                with gr.Column():
                    rank_dropdown = gr.Dropdown(
@@ -199,32 +237,36 @@ if __name__ == "__main__":
                        show_label=True,
                        value=None,
                    )
-                    # open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
+                    # open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
 
            with gr.Row():
                gr.Examples(
                    examples=open_domain_examples,
                    inputs=[img_input, rank_dropdown],
                    cache_examples=True,
-                    fn=open_domain_classification,
+                    fn=lambda img, rank: open_domain_classification(img, rank, return_all=False),
                    outputs=[open_domain_output],
                )
-
-            # open_domain_callback = gr.HuggingFaceDatasetSaver(
-            #     hf_token, "imageomics/bioclip-demo-open-domain-mistakes", private=True
-            # )
-            # open_domain_callback.setup(
-            #     [img_input, rank_dropdown, open_domain_output],
-            #     flagging_dir="logs/flagged",
-            # )
-            # open_domain_flag_btn.click(
-            #     lambda *args: open_domain_callback.flag(args),
-            #     [img_input, rank_dropdown, open_domain_output],
-            #     None,
-            #     preprocess=False,
-            # )
-
+            '''
+            # Flagging Code
+            open_domain_callback = gr.HuggingFaceDatasetSaver(
+                hf_token, "bioclip-demo-open-domain-mistakes", private=True
+            )
+            open_domain_callback.setup(
+                [img_input, rank_dropdown, open_domain_output],
+                flagging_dir="bioclip-demo-open-domain-mistakes/logs/flagged",
+            )
+            open_domain_flag_btn.click(
+                lambda *args: open_domain_callback.flag(args),
+                [img_input, rank_dropdown, open_domain_output],
+                None,
+                preprocess=False,
+            )
+            '''
        with gr.Tab("Zero-Shot"):
+            with gr.Row():
+                img_input_zs = gr.Image(height=400, sources=["upload"])
+
            with gr.Row():
                with gr.Column():
                    classes_txt = gr.Textbox(
@@ -240,45 +282,56 @@ if __name__ == "__main__":
                    zero_shot_output = gr.Label(
                        num_top_classes=k, label="Prediction", show_label=True
                    )
-                    # zero_shot_flag_btn = gr.Button("Flag Mistake", variant="primary")
+                    # zero_shot_flag_btn = gr.Button("Flag Mistake", variant="primary")
 
            with gr.Row():
                gr.Examples(
                    examples=zero_shot_examples,
-                    inputs=[img_input, classes_txt],
+                    inputs=[img_input_zs, classes_txt],
                    cache_examples=True,
                    fn=zero_shot_classification,
                    outputs=[zero_shot_output],
                )
-
-            # zero_shot_callback = gr.HuggingFaceDatasetSaver(
-            #     hf_token, "imageomics/bioclip-demo-zero-shot-mistakes", private=True
-            # )
-            # zero_shot_callback.setup(
-            #     [img_input, zero_shot_output], flagging_dir="logs/flagged"
-            # )
-            # zero_shot_flag_btn.click(
-            #     lambda *args: zero_shot_callback.flag(args),
-            #     [img_input, zero_shot_output],
-            #     None,
-            #     preprocess=False,
-            # )
-
+            '''
+            # Flagging Code
+            zero_shot_callback = gr.HuggingFaceDatasetSaver(
+                hf_token, "bioclip-demo-zero-shot-mistakes", private=True
+            )
+            zero_shot_callback.setup(
+                [img_input, zero_shot_output], flagging_dir="bioclip-demo-zero-shot-mistakes/logs/flagged"
+            )
+            zero_shot_flag_btn.click(
+                lambda *args: zero_shot_callback.flag(args),
+                [img_input, zero_shot_output],
+                None,
+                preprocess=False,
+            )
+            '''
        rank_dropdown.change(
            fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
        )
 
        open_domain_btn.click(
-            fn=open_domain_classification,
+            fn=lambda img, rank: open_domain_classification(img, rank, return_all=True),
            inputs=[img_input, rank_dropdown],
-            outputs=[open_domain_output],
+            outputs=[open_domain_output, sample_img, taxon_url],
        )
 
        zero_shot_btn.click(
            fn=zero_shot_classification,
-            inputs=[img_input, classes_txt],
+            inputs=[img_input_zs, classes_txt],
            outputs=zero_shot_output,
        )
+
+        # Footer linking to the model and data from the app page.
+        gr.Markdown(
+            """
+            For more information on the [BioCLIP Model](https://huggingface.co/imageomics/bioclip) creation, see our [BioCLIP Project GitHub](https://github.com/Imageomics/bioclip), and
+            for easier integration of BioCLIP, check out [pybioclip](https://github.com/Imageomics/pybioclip).
+
+            To learn more about the data, check out our [TreeOfLife-10M Dataset](https://huggingface.co/datasets/imageomics/TreeOfLife-10M).
+            """
+        )
 
    app.queue(max_size=20)
-    app.launch()
+    app.launch(share=True)
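Taken together, these hunks give open_domain_classification two call paths: the cached gr.Examples only consume the prediction dict (return_all=False), while the live button also feeds the new sample_img and taxon_url components (return_all=True). A minimal sketch of that contract, as a hypothetical caller inside the app.py context (the image path and rank value are illustrative, not from the commit):

    from PIL import Image

    img = Image.open("uploaded_photo.jpg")  # hypothetical upload
    rank = 5                                # index into ranks (e.g., genus)

    # Cached-examples path: label dict only
    predictions = open_domain_classification(img, rank, return_all=False)

    # Live-button path: label dict plus sample image and EOL link HTML,
    # matching outputs=[open_domain_output, sample_img, taxon_url]
    predictions, sample, eol_html = open_domain_classification(img, rank, return_all=True)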
components/metadata.csv ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d8576f6ca106f35387506369a70df01fb92192a740c3b5da2a12ad8303976aad
+size 233934143
components/metadata_readme.md ADDED
@@ -0,0 +1,11 @@
+---
+title: Bioclip Demo
+emoji: 🐘
+colorFrom: indigo
+colorTo: purple
+sdk: gradio
+sdk_version: 4.36.1
+app_file: app.py
+pinned: false
+license: mit
+---
components/query.py ADDED
@@ -0,0 +1,116 @@
+import io
+import boto3
+import requests
+import numpy as np
+import polars as pl
+from PIL import Image
+from botocore.config import Config
+import logging
+
+logger = logging.getLogger(__name__)
+
+# S3 for sample images
+my_config = Config(
+    region_name='us-east-1'
+)
+s3_client = boto3.client('s3', config=my_config)
+
+# Set base URL for EOL pages for more info
+EOL_URL = "https://eol.org/pages/"
+RANKS = ["kingdom", "phylum", "class", "order", "family", "genus", "species"]
+
+def get_sample(df, pred_taxon, rank):
+    '''
+    Function to retrieve a sample image of the predicted taxon and an EOL page link for more info.
+
+    Parameters:
+    -----------
+    df : DataFrame
+        DataFrame with all sample images listed and their filepaths (in "file_path" column).
+    pred_taxon : str
+        Predicted taxon of the uploaded image.
+    rank : int
+        Index of rank in RANKS chosen for prediction.
+
+    Returns:
+    --------
+    img : PIL.Image
+        Sample image of predicted taxon for display.
+    eol_page : str
+        HTML snippet linking to the EOL page for the taxon (may be a lower rank, e.g., a species sample).
+    '''
+    logger.info(f"Getting sample for taxon: {pred_taxon} at rank: {rank}")
+    try:
+        filepath, eol_page_id, full_name, is_exact = get_sample_data(df, pred_taxon, rank)
+    except Exception as e:
+        logger.error(f"Error retrieving sample data: {e}")
+        return None, f"We encountered the following error trying to retrieve a sample image: {e}."
+    if filepath is None:
+        logger.warning(f"No sample image found for taxon: {pred_taxon}")
+        return None, f"Sorry, our EOL images do not include {pred_taxon}."
+
+    # Get sample image of selected individual
+    try:
+        img_src = s3_client.generate_presigned_url('get_object',
+                                                   Params={'Bucket': 'treeoflife-10m-sample-images',
+                                                           'Key': filepath}
+                                                   )
+        img_resp = requests.get(img_src)
+        img = Image.open(io.BytesIO(img_resp.content))
+        full_eol_url = EOL_URL + eol_page_id
+        if is_exact:
+            eol_page = f"<p>Check out the EOL entry for {pred_taxon} to learn more: <a href={full_eol_url} target='_blank'>{full_eol_url}</a>.</p>"
+        else:
+            eol_page = f"<p>Check out an example EOL entry within {pred_taxon} to learn more: {full_name} <a href={full_eol_url} target='_blank'>{full_eol_url}</a>.</p>"
+        logger.info(f"Successfully retrieved sample image and EOL page for {pred_taxon}")
+        return img, eol_page
+    except Exception as e:
+        logger.error(f"Error retrieving sample image: {e}")
+        return None, f"We encountered the following error trying to retrieve a sample image: {e}."
+
+def get_sample_data(df, pred_taxon, rank):
+    '''
+    Function to randomly select a sample individual of the given taxon and provide its associated EOL page ID.
+
+    Parameters:
+    -----------
+    df : DataFrame
+        DataFrame with all sample images listed and their filepaths (in "file_path" column).
+    pred_taxon : str
+        Predicted taxon of the uploaded image.
+    rank : int
+        Index of rank in RANKS chosen for prediction.
+
+    Returns:
+    --------
+    filepath : str
+        Filepath of selected sample image for predicted taxon.
+    eol_page_id : str
+        EOL page ID associated with predicted taxon for more information.
+    full_name : str
+        Full taxonomic name of the selected sample.
+    is_exact : bool
+        Flag indicating if the match is exact (i.e., with empty lower ranks).
+    '''
+    for idx in range(rank + 1):
+        taxon = RANKS[idx]
+        target_taxon = pred_taxon.split(" ")[idx]
+        df = df.filter(pl.col(taxon) == target_taxon)
+
+    if df.shape[0] == 0:
+        return None, np.nan, "", False
+
+    # First, try to find entries with empty lower ranks
+    exact_df = df
+    for lower_rank in RANKS[rank + 1:]:
+        exact_df = exact_df.filter((pl.col(lower_rank).is_null()) | (pl.col(lower_rank) == ""))
+
+    if exact_df.shape[0] > 0:
+        df_filtered = exact_df.sample()
+        full_name = " ".join(df_filtered.select(RANKS[:rank+1]).row(0))
+        return df_filtered["file_path"][0], df_filtered["eol_page_id"].cast(pl.String)[0], full_name, True
+
+    # If no exact matches, return any entry with the specified rank
+    df_filtered = df.sample()
+    full_name = " ".join(df_filtered.select(RANKS[:rank+1]).row(0)) + " " + " ".join(df_filtered.select(RANKS[rank+1:]).row(0))
+    return df_filtered["file_path"][0], df_filtered["eol_page_id"].cast(pl.String)[0], full_name, False
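As a sanity check on the filtering logic above, here is a hedged sketch (not part of the commit) that runs get_sample_data against a toy metadata frame; the column names mirror those components/query.py expects, and all values are made up:

    import polars as pl
    from components.query import get_sample_data

    toy_df = pl.DataFrame({
        "kingdom": ["Animalia", "Animalia"],
        "phylum": ["Chordata", "Chordata"],
        "class": ["Mammalia", "Mammalia"],
        "order": ["Proboscidea", "Proboscidea"],
        "family": ["Elephantidae", "Elephantidae"],
        "genus": ["Loxodonta", "Loxodonta"],
        "species": ["", "africana"],          # first row: genus-level entry with empty species
        "file_path": ["Loxodonta/a.jpg", "Loxodonta/b.jpg"],
        "eol_page_id": [12345, 67890],
    })

    # rank=5 targets genus; the row with an empty species column satisfies
    # the "empty lower ranks" filter, so is_exact comes back True.
    filepath, page_id, full_name, is_exact = get_sample_data(
        toy_df, "Animalia Chordata Mammalia Proboscidea Elephantidae Loxodonta", 5
    )

With only the second row present, the fallback branch would fire instead, returning is_exact=False and appending the species to full_name.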
components/sync_samples_to_s3.bash ADDED
@@ -0,0 +1,34 @@
+#!/bin/bash
+
+<<COMMENT
+    Usage:
+        bash sync_samples_to_s3.bash <BASE_DIR>
+
+    Dependencies:
+        - awscli (https://aws.amazon.com/cli/)
+    Credentials to export as environment variables:
+        - AWS_ACCESS_KEY_ID
+        - AWS_SECRET_ACCESS_KEY
+COMMENT
+
+# Check if a valid directory is provided as an argument
+if [ -z "$1" ]; then
+    echo "Usage: $0 <BASE_DIR>"
+    exit 1
+fi
+
+if [ ! -d "$1" ]; then
+    echo "Error: $1 is not a valid directory"
+    exit 1
+fi
+
+BASE_DIR="$1"
+S3_BUCKET="s3://treeoflife-10m-sample-images"
+
+# Loop through all directories and sync them to S3
+for dir in "$BASE_DIR"/*; do
+    if [ -d "$dir" ]; then
+        dir_name=$(basename "$dir")
+        aws s3 sync "$dir" "$S3_BUCKET/$dir_name/"
+    fi
+done
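For reference, a hedged Python sketch of the same per-directory upload using boto3 rather than the AWS CLI; the bucket name comes from the script above, "sample_images" is a hypothetical local layout, and unlike `aws s3 sync` this re-uploads every file unconditionally:

    import os
    import boto3

    s3 = boto3.client("s3")
    BASE_DIR = "sample_images"              # hypothetical <BASE_DIR>
    BUCKET = "treeoflife-10m-sample-images"

    for dir_name in os.listdir(BASE_DIR):
        dir_path = os.path.join(BASE_DIR, dir_name)
        if not os.path.isdir(dir_path):
            continue
        for root, _, files in os.walk(dir_path):
            for fname in files:
                local_path = os.path.join(root, fname)
                # Mirrors `aws s3 sync "$dir" "$S3_BUCKET/$dir_name/"`:
                # keys are paths relative to BASE_DIR, so they start with dir_name/
                key = os.path.relpath(local_path, BASE_DIR).replace(os.sep, "/")
                s3.upload_file(local_path, BUCKET, key)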
requirements.txt CHANGED
@@ -1,4 +1,7 @@
 open_clip_torch
 torchvision
 torch
-gradio
+gradio
+polars
+pillow
+boto3