Bachmann Roman Christian committed
Commit · a6ebf2a · Parent(s): 3b49518

app.py CHANGED
@@ -26,6 +26,7 @@ from mpl_toolkits.axes_grid1 import ImageGrid
 from tqdm import tqdm
 import random
 from functools import partial
+import time
 
 # import some common detectron2 utilities
 from detectron2 import model_zoo
@@ -290,7 +291,7 @@ def plot_predictions(input_dict, preds, masks, image_size=224):
     plt.close()
 
 
-def inference(img, num_rgb, num_depth, num_semseg, seed, perform_sampling, alphas, num_tokens):
+def inference(img, num_tokens, perform_sampling, num_rgb, num_depth, num_semseg, seed):
     im = Image.open(img)
 
     # Center crop and resize RGB
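The parameter reorder in this hunk tracks the reordered inputs list later in the diff: gr.Interface passes its input components to fn positionally, so the parameters of inference must follow the widget order (image, token budget, mode checkbox, then the per-modality sliders and the seed). A toy sketch of that contract, with a hypothetical function and components rather than the app's real ones:

import gradio as gr

# gr.Interface maps input components to fn's parameters by position:
# the Image feeds `img`, the Slider feeds `n`.
def toy_fn(img, n):
    return f"image given: {img is not None}, n = {n}"

gr.Interface(
    fn=toy_fn,
    inputs=[gr.inputs.Image(type='filepath'), gr.inputs.Slider(minimum=0, maximum=10)],
    outputs=gr.outputs.Textbox(),
)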
@@ -324,21 +325,22 @@ def inference(img, num_rgb, num_depth, num_semseg, seed, perform_sampling, alphas, num_tokens):
     input_dict = {k: v.to(device) for k,v in input_dict.items()}
 
 
-    torch.manual_seed(int(seed)) # change seed to resample new mask
-
     if perform_sampling:
         # Randomly sample masks
 
-
+        torch.manual_seed(int(time.time())) # Random mode is random
 
         preds, masks = multimae.forward(
             input_dict,
             mask_inputs=True, # True if forward pass should sample random masks
             num_encoded_tokens=num_tokens,
-            alphas=alphas
+            alphas=1.0
         )
     else:
         # Randomly sample masks using the specified number of tokens per modality
+
+        torch.manual_seed(int(seed)) # change seed to resample new mask
+
         task_masks = {domain: torch.ones(1,196).long().to(device) for domain in DOMAINS}
         selected_rgb_idxs = torch.randperm(196)[:num_rgb]
         selected_depth_idxs = torch.randperm(196)[:num_depth]
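The seeding logic now splits by mode: random mode reseeds from the wall clock so each run draws a fresh mask, while manual mode keeps the user-controlled seed, moved into the else branch, so the same seed reproduces the same mask. A condensed sketch of the resulting control flow, using the names from the diff:

import time
import torch

# Seed the RNG that drives mask sampling.
def set_mask_seed(perform_sampling, seed):
    if perform_sampling:
        # Random mode: wall-clock seed, so every call resamples a new mask
        torch.manual_seed(int(time.time()))
    else:
        # Manual mode: user-supplied seed, so masks are reproducible
        torch.manual_seed(int(seed))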
@@ -365,7 +367,7 @@ title = "MultiMAE"
 description = "Gradio demo for MultiMAE: Multi-modal Multi-task Masked Autoencoders. \
     Upload your own images or try one of the examples below to explore the multi-modal masked reconstruction of a pre-trained MultiMAE model. \
     Uploaded images are pseudo labeled using a DPT trained on Omnidata depth, and a Mask2Former trained on COCO. \
-    Choose the number of visible tokens using the sliders below."
+    Choose the number of visible tokens using the sliders below and see how MultiMAE reconstructs the modalities!"
 
 article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.01678' \
     target='_blank'>MultiMAE: Multi-modal Multi-task Masked Autoencoders</a> | \
@@ -375,24 +377,18 @@ css = '.output-image{height: 713px !important}'
 
 # Example images
 os.system("wget https://i.imgur.com/c9ObJdK.jpg")
-examples = [['c9ObJdK.jpg', 32, 32, 32, 0, False, 1.0, 98]]
+examples = [['c9ObJdK.jpg', 98, False, 32, 32, 32, 0]]
 
 gr.Interface(
     fn=inference,
     inputs=[
         gr.inputs.Image(label='RGB input image', type='filepath'),
+        gr.inputs.Slider(label='Number of input tokens', default=98, step=1, minimum=0, maximum=588),
+        gr.inputs.Checkbox(label='Manual mode: Check this to manually set the number of input tokens per modality using the sliders below', default=False),
         gr.inputs.Slider(label='Number of RGB input tokens', default=32, step=1, minimum=0, maximum=196),
         gr.inputs.Slider(label='Number of depth input tokens', default=32, step=1, minimum=0, maximum=196),
         gr.inputs.Slider(label='Number of semantic input tokens', default=32, step=1, minimum=0, maximum=196),
         gr.inputs.Number(label='Random seed: Change this to sample different masks', default=0),
-        gr.inputs.Checkbox(label='Randomize the number of tokens: Check this to ignore the above sliders and randomly sample the number \
-            of tokens per modality using the parameters below', default=False),
-        gr.inputs.Slider(label='Symmetric Dirichlet concentration parameter (α > 0). Low values (α << 1.0) result in a sampling behavior, \
-            where most of the time, all visible tokens will be sampled from a single modality. High values \
-            (α >> 1.0) result in similar numbers of tokens being sampled for each modality. α = 1.0 is equivalent \
-            to uniform sampling over the simplex and contains both previous cases and everything in between.',
-            default=1.0, step=0.1, minimum=0.1, maximum=5.0),
-        gr.inputs.Slider(label='Number of input tokens', default=98, step=1, minimum=0, maximum=588),
     ],
     outputs=[
         gr.outputs.Image(label='MultiMAE predictions', type='file')
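For reference, the deleted slider label above documents what the now hard-coded alphas=1.0 in the forward call controls: a symmetric Dirichlet concentration over the modalities. Low values (α << 1.0) tend to draw nearly all visible tokens from a single modality, high values (α >> 1.0) give each modality a similar share, and α = 1.0 samples uniformly over the simplex. A rough sketch of that allocation, assuming three modalities and a crude rounding fix-up (not the model's internal implementation):

import torch

# Split a budget of visible tokens across modalities by sampling
# proportions from a symmetric Dirichlet(alpha, ..., alpha).
def sample_token_split(num_tokens=98, alpha=1.0, num_modalities=3):
    props = torch.distributions.Dirichlet(torch.full((num_modalities,), alpha)).sample()
    counts = (props * num_tokens).floor().long()
    counts[0] += num_tokens - counts.sum()  # make the counts sum to the budget
    return counts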
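In manual mode, the context lines in the third hunk build the per-modality masks with torch.randperm: every one of the 196 patch tokens starts out masked, and the first num_* indices of a random permutation are revealed. A self-contained sketch of that pattern, assuming the convention that 1 means masked and 0 means visible (the revealing step itself is outside the visible context lines):

import torch

DOMAINS = ['rgb', 'depth', 'semseg']  # as in the app

# Build a {domain: (1, num_patches)} mask dict; 1 = masked, 0 = visible.
def build_task_masks(num_visible, num_patches=196):
    task_masks = {d: torch.ones(1, num_patches).long() for d in DOMAINS}
    for d in DOMAINS:
        visible_idxs = torch.randperm(num_patches)[:num_visible[d]]
        task_masks[d][0, visible_idxs] = 0
    return task_masks

masks = build_task_masks({'rgb': 32, 'depth': 32, 'semseg': 32})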