huzey committed on
Commit 86da6bf
1 Parent(s): b4fad70

update dataset

Files changed (2)
  1. app.py +131 -58
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,4 +1,10 @@
-import spaces
+# Author: Huzheng Yang
+# %%
+USE_SPACES = False
+
+if USE_SPACES:
+    import spaces
+
 import gradio as gr
 
 import torch
@@ -8,8 +14,11 @@ import time
 
 import gradio as gr
 
-from backbone import extract_features
-from ncut_pytorch import NCUT, rgb_from_tsne_3d, rgb_from_umap_3d
+if USE_SPACES:
+    from backbone import extract_features
+else:
+    from draft_gradio_backbone import extract_features
+from ncut_pytorch import NCUT, eigenvector_to_rgb
 
 
 def compute_ncut(
@@ -24,8 +33,17 @@ def compute_ncut(
     perplexity=150,
     n_neighbors=150,
     min_dist=0.1,
+    sampling_method="fps",
 ):
     logging_str = ""
+
+    num_nodes = np.prod(features.shape[:3])
+    if num_nodes / 2 < num_eig:
+        # raise gr.Error("Number of eigenvectors should be less than half the number of nodes.")
+        gr.Warning("Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.")
+        num_eig = num_nodes // 2 - 1
+        logging_str += f"Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.\n"
+
     start = time.time()
     eigvecs, eigvals = NCUT(
         num_eig=num_eig,
@@ -33,33 +51,24 @@ def compute_ncut(
         device="cuda" if torch.cuda.is_available() else "cpu",
         affinity_focal_gamma=affinity_focal_gamma,
         knn=knn_ncut,
+        sample_method=sampling_method,
     ).fit_transform(features.reshape(-1, features.shape[-1]))
     # print(f"NCUT time: {time.time() - start:.2f}s")
     logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
 
     start = time.time()
-    if embedding_method == "UMAP":
-        X_3d, rgb = rgb_from_umap_3d(
-            eigvecs,
-            n_neighbors=n_neighbors,
-            min_dist=min_dist,
-            device="cuda" if torch.cuda.is_available() else "cpu",
-        )
-        # print(f"UMAP time: {time.time() - start:.2f}s")
-        logging_str += f"UMAP time: {time.time() - start:.2f}s\n"
-    elif embedding_method == "t-SNE":
-        X_3d, rgb = rgb_from_tsne_3d(
-            eigvecs,
-            num_sample=num_sample_tsne,
-            perplexity=perplexity,
-            knn=knn_tsne,
-            device="cuda" if torch.cuda.is_available() else "cpu",
-        )
-        # print(f"t-SNE time: {time.time() - start:.2f}s")
-        logging_str += f"t-SNE time: {time.time() - start:.2f}s\n"
-    else:
-        raise ValueError(f"Embedding method {embedding_method} not supported.")
-
+    _, rgb = eigenvector_to_rgb(
+        eigvecs,
+        method=embedding_method,
+        num_sample=num_sample_tsne,
+        perplexity=perplexity,
+        n_neighbors=n_neighbors,
+        min_distance=min_dist,
+        knn=knn_tsne,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+    )
+    logging_str += f"{embedding_method} time: {time.time() - start:.2f}s\n"
+
     rgb = rgb.reshape(features.shape[:3] + (3,))
     return rgb, logging_str
 
@@ -76,7 +85,7 @@ def dont_use_too_much_green(image_rgb):
 
 def to_pil_images(images):
     return [
-        Image.fromarray((image * 255).cpu().numpy().astype(np.uint8)).resize((256, 256), Image.NEAREST)
+        Image.fromarray((image * 255).cpu().numpy().astype(np.uint8)).resize((256, 256), Image.Resampling.NEAREST)
         for image in images
     ]
 
@@ -103,11 +112,13 @@ def ncut_run(
     perplexity=500,
     n_neighbors=500,
     min_dist=0.1,
+    sampling_method="fps",
 ):
     logging_str = ""
     if perplexity >= num_sample_tsne or n_neighbors >= num_sample_tsne:
         # raise gr.Error("Perplexity must be less than the number of samples for t-SNE.")
-        gr.Warning("Perplexity/n_neighbors must be less than the number of samples.\n" f"Setting to {num_sample_tsne-1}.")
+        gr.Warning("Perplexity/n_neighbors must be less than the number of samples.\n" f"Setting Perplexity to {num_sample_tsne-1}.")
+        logging_str += f"Perplexity/n_neighbors must be less than the number of samples.\n" f"Setting Perplexity to {num_sample_tsne-1}.\n"
         perplexity = num_sample_tsne - 1
         n_neighbors = num_sample_tsne - 1
 
@@ -135,26 +146,48 @@ def ncut_run(
         perplexity=perplexity,
         n_neighbors=n_neighbors,
         min_dist=min_dist,
+        sampling_method=sampling_method,
     )
     logging_str += _logging_str
     rgb = dont_use_too_much_green(rgb)
     return to_pil_images(rgb), logging_str
 
-@spaces.GPU(duration=13)
-def quick_run(images, **kwargs):
-    return ncut_run(images, **kwargs)
+def _ncut_run(*args, **kwargs):
+    try:
+        return ncut_run(*args, **kwargs)
+    except Exception as e:
+        gr.Error(str(e))
+        return [], "Error: " + str(e)
 
-@spaces.GPU(duration=30)
-def long_run(images, **kwargs):
-    return ncut_run(images, **kwargs)
+if USE_SPACES:
+    @spaces.GPU(duration=13)
+    def quick_run(*args, **kwargs):
+        return _ncut_run(*args, **kwargs)
 
-@spaces.GPU(duration=60)
-def longer_run(images, **kwargs):
-    return ncut_run(images, **kwargs)
+    @spaces.GPU(duration=30)
+    def long_run(*args, **kwargs):
+        return _ncut_run(*args, **kwargs)
 
-@spaces.GPU(duration=120)
-def super_duper_long_run(images, **kwargs):
-    return ncut_run(images, **kwargs)
+    @spaces.GPU(duration=60)
+    def longer_run(*args, **kwargs):
+        return _ncut_run(*args, **kwargs)
+
+    @spaces.GPU(duration=120)
+    def super_duper_long_run(*args, **kwargs):
+        return _ncut_run(*args, **kwargs)
+
+if not USE_SPACES:
+    def quick_run(*args, **kwargs):
+        return _ncut_run(*args, **kwargs)
+
+    def long_run(*args, **kwargs):
+        return _ncut_run(*args, **kwargs)
+
+    def longer_run(*args, **kwargs):
+        return _ncut_run(*args, **kwargs)
+
+    def super_duper_long_run(*args, **kwargs):
+        return _ncut_run(*args, **kwargs)
 
 def run_fn(
     images,
@@ -171,11 +204,15 @@ def run_fn(
     perplexity=500,
     n_neighbors=500,
     min_dist=0.1,
+    sampling_method="fps",
 ):
     if images is None:
         gr.Warning("No images selected.")
         return [], "No images selected."
 
+    if sampling_method == "fps":
+        sampling_method = "farthest"
+
     kwargs = {
         "model_name": model_name,
         "layer": layer,
@@ -190,6 +227,7 @@ def run_fn(
         "perplexity": perplexity,
         "n_neighbors": n_neighbors,
         "min_dist": min_dist,
+        "sampling_method": sampling_method,
     }
     num_images = len(images)
     if num_images > 100:
@@ -216,34 +254,48 @@ with gr.Blocks() as demo:
             submit_button = gr.Button("🔴RUN", elem_id="submit_button")
             clear_images_button = gr.Button("🗑️Clear", elem_id='clear_button')
 
-            gr.Markdown('### Load Examples 👇')
-            load_images_button = gr.Button("Load", elem_id="load-images-button")
-            hide_button = gr.Button("Hide", elem_id="hide-button")
+            gr.Markdown('### Load from Cloud Dataset 👇')
+            load_images_button = gr.Button("Load Example", elem_id="load-images-button")
             example_gallery = gr.Gallery(value=example_items, label="Example Set A", show_label=False, columns=[3], rows=[2], object_fit="scale-down", height="200px", show_share_button=False, elem_id="example-gallery")
+            hide_button = gr.Button("Hide Example", elem_id="hide-button")
 
             hide_button.click(
                 fn=lambda: gr.update(visible=False),
                 outputs=example_gallery
             )
+
+            with gr.Accordion("➜ Load from dataset", open=True):
+                dataset_names = [
+                    'UCSC-VLAA/Recap-COCO-30K',
+                    'nateraw/pascal-voc-2012',
+                    'johnowhitaker/imagenette2-320',
+                    'JapanDegitalMaterial/Places_in_Japan',
+                    'Borismile/Anime-dataset',
+                ]
+                dataset_dropdown = gr.Dropdown(dataset_names, label="Dataset name", value="UCSC-VLAA/Recap-COCO-30K", elem_id="dataset")
+                num_images_slider = gr.Slider(1, 200, step=1, label="Number of images", value=9, elem_id="num_images")
+                random_seed_slider = gr.Number(0, label="Random seed", value=42, elem_id="random_seed")
+                load_dataset_button = gr.Button("Load Dataset", elem_id="load-dataset-button")
 
         with gr.Column(scale=5, min_width=200):
             gr.Markdown('### Output Images')
             output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto")
-            model_dropdown = gr.Dropdown(["SAM(sam_vit_b)", "MobileSAM", "DiNO(dinov2_vitb14_reg)", "CLIP(openai/clip-vit-base-patch16)"], label="Model", value="SAM(sam_vit_b)", elem_id="model_name")
-            layer_slider = gr.Slider(0, 11, step=1, label="Layer", value=11, elem_id="layer")
-            num_eig_slider = gr.Slider(1, 1000, step=1, label="Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more clusters')
-            affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper NCUT")
+            model_dropdown = gr.Dropdown(["SAM(sam_vit_b)", "MobileSAM", "DiNO(dinov2_vitb14_reg)", "CLIP(openai/clip-vit-base-patch16)", "MAE(vit_base)"], label="Backbone", value="SAM(sam_vit_b)", elem_id="model_name")
+            layer_slider = gr.Slider(0, 11, step=1, label="Backbone: Layer index", value=11, elem_id="layer")
+            node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
+            num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more clusters')
+            affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation")
 
-            with gr.Accordion("Additional Parameters", open=False):
-                node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Node type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
-                num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="num_sample (NCUT)", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
-                knn_ncut_slider = gr.Slider(1, 100, step=1, label="KNN (NCUT)", value=10, elem_id="knn_ncut", info="Nyström approximation")
-                embedding_method_dropdown = gr.Dropdown(["t-SNE", "UMAP"], label="Embedding method", value="t-SNE", elem_id="embedding_method")
-                num_sample_tsne_slider = gr.Slider(100, 1000, step=100, label="num_sample (t-SNE/UMAP)", value=300, elem_id="num_sample_tsne", info="Nyström approximation")
-                knn_tsne_slider = gr.Slider(1, 100, step=1, label="KNN (t-SNE/UMAP)", value=10, elem_id="knn_tsne", info="Nyström approximation")
-                perplexity_slider = gr.Slider(10, 500, step=10, label="Perplexity (t-SNE)", value=150, elem_id="perplexity")
-                n_neighbors_slider = gr.Slider(10, 500, step=10, label="n_neighbors (UMAP)", value=150, elem_id="n_neighbors")
-                min_dist_slider = gr.Slider(0.1, 1, step=0.1, label="min_dist (UMAP)", value=0.1, elem_id="min_dist")
+            with gr.Accordion(" Click to expand: more parameters", open=False):
+                num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="NCUT: num_sample", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
+                sampling_method_dropdown = gr.Dropdown(["fps", "random"], label="NCUT: Sampling method", value="fps", elem_id="sampling_method")
+                knn_ncut_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
+                embedding_method_dropdown = gr.Dropdown(["tsne_3d", "umap_3d", "umap_shpere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
+                num_sample_tsne_slider = gr.Slider(100, 1000, step=100, label="t-SNE/UMAP: num_sample", value=300, elem_id="num_sample_tsne", info="Nyström approximation")
+                knn_tsne_slider = gr.Slider(1, 100, step=1, label="t-SNE/UMAP: KNN", value=10, elem_id="knn_tsne", info="Nyström approximation")
+                perplexity_slider = gr.Slider(10, 500, step=10, label="t-SNE: Perplexity", value=150, elem_id="perplexity")
+                n_neighbors_slider = gr.Slider(10, 500, step=10, label="UMAP: n_neighbors", value=150, elem_id="n_neighbors")
+                min_dist_slider = gr.Slider(0.1, 1, step=0.1, label="UMAP: min_dist", value=0.1, elem_id="min_dist")
 
             # logging text box
             logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
@@ -253,19 +305,40 @@ with gr.Blocks() as demo:
 
     def empty_input_and_output():
         return [], []
+
+    def load_dataset_images(dataset_name, num_images=10, random_seed=42):
+        from datasets import load_dataset
+        try:
+            dataset = load_dataset(dataset_name)['train']
+        except Exception as e:
+            gr.Error(f"Error loading dataset {dataset_name}: {e}")
+            return None
+        if num_images > len(dataset):
+            num_images = len(dataset)
+        image_idx = np.random.RandomState(random_seed).choice(len(dataset), num_images, replace=False)
+        image_idx = image_idx.tolist()
+        images = [dataset[i]['image'] for i in image_idx]
+        return images
+
 
     load_images_button.click(load_default_images, outputs=[input_gallery, output_gallery])
     clear_images_button.click(empty_input_and_output, outputs=[input_gallery, output_gallery])
+    load_dataset_button.click(load_dataset_images, inputs=[dataset_dropdown, num_images_slider, random_seed_slider], outputs=[input_gallery])
    submit_button.click(
        run_fn,
        inputs=[
            input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
            affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
            embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
-            perplexity_slider, n_neighbors_slider, min_dist_slider
+            perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
        ],
        outputs=[output_gallery, logging_text]
    )
 
 
-demo.launch()
+if USE_SPACES:
+    demo.launch()
+else:
+    demo.launch(share=True)
+
+# %%
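Note (editor): the coloring path this commit switches to can be exercised on its own. The sketch below is a minimal, non-authoritative example: every call and argument name is copied from the diff above (ncut_pytorch's NCUT and eigenvector_to_rgb as used in compute_ncut), while the random tensor merely stands in for real backbone features from extract_features.

import torch
from ncut_pytorch import NCUT, eigenvector_to_rgb

# stand-in for backbone features: (n_images, height, width, channels)
features = torch.rand(3, 32, 32, 768)

eigvecs, eigvals = NCUT(
    num_eig=100,
    device="cuda" if torch.cuda.is_available() else "cpu",
    affinity_focal_gamma=0.5,
    knn=10,
    sample_method="farthest",  # the UI's "fps" choice is mapped to "farthest" in run_fn
).fit_transform(features.reshape(-1, features.shape[-1]))

# a single call with a method string replaces the old rgb_from_tsne_3d / rgb_from_umap_3d branches
_, rgb = eigenvector_to_rgb(
    eigvecs,
    method="tsne_3d",  # or "umap_3d", "umap_shpere", "tsne_2d", "umap_2d" per the new dropdown
    num_sample=300,
    perplexity=150,
    n_neighbors=150,
    min_distance=0.1,
    knn=10,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
rgb = rgb.reshape(features.shape[:3] + (3,))  # back to (n_images, height, width, 3)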
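Note (editor): the new "Load Dataset" button is backed by the load_dataset_images helper added above, which is also why requirements.txt below gains a datasets dependency. A minimal standalone sketch of the same sampling logic follows; the function name here is illustrative (in app.py it is load_dataset_images, wired to load_dataset_button), and the dataset name, count, and seed are the dropdown/slider defaults from the diff.

import numpy as np
from datasets import load_dataset

def sample_dataset_images(dataset_name="UCSC-VLAA/Recap-COCO-30K", num_images=9, random_seed=42):
    # mirror load_dataset_images: draw a reproducible random subset of the train split
    dataset = load_dataset(dataset_name)["train"]
    num_images = min(num_images, len(dataset))
    image_idx = np.random.RandomState(random_seed).choice(len(dataset), num_images, replace=False)
    return [dataset[int(i)]["image"] for i in image_idx]

images = sample_dataset_images()  # PIL images, the same objects the app feeds into input_gallery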
requirements.txt CHANGED
@@ -2,7 +2,7 @@ torch
 torchvision
 ncut-pytorch
 transformers
+datasets
 segment-anything @ git+https://github.com/facebookresearch/segment-anything.git
 mobile-sam @ git+https://github.com/ChaoningZhang/MobileSAM.git
-
 timm