Bachmann Roman Christian committed
Commit a6ebf2a · 1 Parent(s): 3b49518
Files changed (1): app.py (+11 -15)
app.py CHANGED
@@ -26,6 +26,7 @@ from mpl_toolkits.axes_grid1 import ImageGrid
 from tqdm import tqdm
 import random
 from functools import partial
+import time
 
 # import some common detectron2 utilities
 from detectron2 import model_zoo
@@ -290,7 +291,7 @@ def plot_predictions(input_dict, preds, masks, image_size=224):
     plt.close()
 
 
-def inference(img, num_rgb, num_depth, num_semseg, seed, perform_sampling, alphas, num_tokens):
+def inference(img, num_tokens, perform_sampling, num_rgb, num_depth, num_semseg, seed):
     im = Image.open(img)
 
     # Center crop and resize RGB
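An aside on the signature reorder above: gr.Interface passes widget values to fn positionally, in the order the inputs list is declared, so the new parameter order has to mirror the reordered widget list in the last hunk below. A minimal sketch of that contract, using the same legacy gr.inputs API as the app; the demo function is hypothetical, not part of the commit:

import gradio as gr

# Hypothetical two-widget demo: values arrive positionally, in
# declaration order, so reordering `inputs` requires reordering
# the parameters of `demo` to match.
def demo(num_tokens, perform_sampling):
    return f"tokens={num_tokens}, random={perform_sampling}"

gr.Interface(
    fn=demo,
    inputs=[
        gr.inputs.Slider(label='Number of input tokens', default=98, step=1, minimum=0, maximum=588),
        gr.inputs.Checkbox(label='Random mode', default=False),
    ],
    outputs='text',
).launch()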
@@ -324,21 +325,22 @@ def inference(img, num_rgb, num_depth, num_semseg, seed, perform_sampling, alpha
     input_dict = {k: v.to(device) for k,v in input_dict.items()}
 
 
-    torch.manual_seed(int(seed)) # change seed to resample new mask
-
     if perform_sampling:
         # Randomly sample masks
 
-        alphas = min(10000.0, max(0.00001, float(alphas))) # Clamp alphas to reasonable range
+        torch.manual_seed(int(time.time())) # Random mode is random
 
         preds, masks = multimae.forward(
             input_dict,
             mask_inputs=True, # True if forward pass should sample random masks
             num_encoded_tokens=num_tokens,
-            alphas=alphas
+            alphas=1.0
         )
     else:
         # Randomly sample masks using the specified number of tokens per modality
+
+        torch.manual_seed(int(seed)) # change seed to resample new mask
+
         task_masks = {domain: torch.ones(1,196).long().to(device) for domain in DOMAINS}
         selected_rgb_idxs = torch.randperm(196)[:num_rgb]
         selected_depth_idxs = torch.randperm(196)[:num_depth]
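The hunk ends before the else-branch applies these indices, but the pattern it sets up is a per-modality binary mask over the 14×14 = 196 patch tokens. A self-contained sketch of that scheme, assuming the common 1 = masked / 0 = visible convention; the names here are illustrative stand-ins, not the app's code below the hunk:

import torch

DOMAINS = ['rgb', 'depth', 'semseg']
num_visible = {'rgb': 32, 'depth': 32, 'semseg': 32}

# Start with everything masked (1 = masked), then reveal a random
# subset of the 196 patch tokens per modality (0 = visible).
task_masks = {domain: torch.ones(1, 196).long() for domain in DOMAINS}
for domain in DOMAINS:
    visible_idxs = torch.randperm(196)[:num_visible[domain]]
    task_masks[domain][0, visible_idxs] = 0

# Sanity check: each mask now reveals exactly the requested count.
for domain in DOMAINS:
    assert (task_masks[domain] == 0).sum().item() == num_visible[domain]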
@@ -365,7 +367,7 @@ title = "MultiMAE"
 description = "Gradio demo for MultiMAE: Multi-modal Multi-task Masked Autoencoders. \
 Upload your own images or try one of the examples below to explore the multi-modal masked reconstruction of a pre-trained MultiMAE model. \
 Uploaded images are pseudo labeled using a DPT trained on Omnidata depth, and a Mask2Former trained on COCO. \
-Choose the number of visible tokens using the sliders below (or sample them randomly) and see how MultiMAE reconstructs the modalities!"
+Choose the number of visible tokens using the sliders below and see how MultiMAE reconstructs the modalities!"
 
 article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.01678' \
 target='_blank'>MultiMAE: Multi-modal Multi-task Masked Autoencoders</a> | \
@@ -375,24 +377,18 @@ css = '.output-image{height: 713px !important}'
 
 # Example images
 os.system("wget https://i.imgur.com/c9ObJdK.jpg")
-examples = [['c9ObJdK.jpg', 32, 32, 32, 0, True, 1.0, 98]]
+examples = [['c9ObJdK.jpg', 98, False, 32, 32, 32, 0]]
 
 gr.Interface(
     fn=inference,
     inputs=[
         gr.inputs.Image(label='RGB input image', type='filepath'),
+        gr.inputs.Slider(label='Number of input tokens', default=98, step=1, minimum=0, maximum=588),
+        gr.inputs.Checkbox(label='Manual mode: Check this to manually set the number of input tokens per modality using the sliders below', default=False),
         gr.inputs.Slider(label='Number of RGB input tokens', default=32, step=1, minimum=0, maximum=196),
         gr.inputs.Slider(label='Number of depth input tokens', default=32, step=1, minimum=0, maximum=196),
         gr.inputs.Slider(label='Number of semantic input tokens', default=32, step=1, minimum=0, maximum=196),
         gr.inputs.Number(label='Random seed: Change this to sample different masks', default=0),
-        gr.inputs.Checkbox(label='Randomize the number of tokens: Check this to ignore the above sliders and randomly sample the number \
-of tokens per modality using the parameters below', default=False),
-        gr.inputs.Slider(label='Symmetric Dirichlet concentration parameter (α > 0). Low values (α << 1.0) result in a sampling behavior, \
-where most of the time, all visible tokens will be sampled from a single modality. High values \
-(α >> 1.0) result in similar numbers of tokens being sampled for each modality. α = 1.0 is equivalent \
-to uniform sampling over the simplex and contains both previous cases and everything in between.',
-            default=1.0, step=0.1, minimum=0.1, maximum=5.0),
-        gr.inputs.Slider(label='Number of input tokens', default=98, step=1, minimum=0, maximum=588),
     ],
     outputs=[
         gr.outputs.Image(label='MultiMAE predictions', type='file')
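The removed α slider's label is the densest bit of documentation this commit drops: with alphas now pinned to 1.0, random mode always samples token proportions uniformly over the simplex. A short, hedged illustration of what the symmetric Dirichlet concentration controlled; the token budget and modality count are assumed here, not taken from the model code:

import torch

num_tokens = 98       # total visible-token budget (assumed)
num_modalities = 3    # RGB, depth, semantic segmentation

def sample_token_split(alpha: float) -> torch.Tensor:
    # Symmetric Dirichlet over the 3-simplex: alpha << 1 concentrates
    # nearly all tokens in one modality, alpha >> 1 gives near-even
    # splits, and alpha = 1 is uniform over the simplex.
    props = torch.distributions.Dirichlet(
        torch.full((num_modalities,), alpha)).sample()
    # Rounding means the counts can be off by one from num_tokens.
    return (props * num_tokens).round().long()

torch.manual_seed(0)
print(sample_token_split(0.1))  # typically one modality dominates
print(sample_token_split(1.0))  # any split equally likely
print(sample_token_split(5.0))  # roughly equal counts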
 