Spaces:
Runtime error
Runtime error
File size: 2,112 Bytes
14bb247 581506e 3d04a5c 14bb247 581506e 14bb247 581506e 859c3ef 581506e 14bb247 7440015 14bb247 7440015 14bb247 859c3ef 14bb247 581506e 859c3ef 14bb247 859c3ef 581506e 14bb247 893c03b 581506e 3d04a5c 859c3ef aa94b73 581506e aa94b73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
from transformers import ViTMAEForPreTraining, ViTImageProcessor
import numpy as np
import torch
import gradio as gr
# Hugging Face checkpoint shared by the image processor and the MAE model.
_CHECKPOINT = 'andrewbo29/vit-mae-base-formula1'

image_processor = ViTImageProcessor.from_pretrained(_CHECKPOINT)
model = ViTMAEForPreTraining.from_pretrained(_CHECKPOINT)

# Normalisation statistics read from the processor's configuration, kept as
# NumPy arrays so that prep_image can undo the normalisation for display.
imagenet_mean, imagenet_std = map(
    np.array,
    (image_processor.image_mean, image_processor.image_std),
)
def prep_image(image):
    """Convert a normalised image tensor back to displayable pixel values.

    Args:
        image: Tensor normalised with the processor's mean/std. The callers in
            this file pass a single channel-last (H, W, C) image.

    Returns:
        A NumPy integer array with values clipped to the range [0, 255].
    """
    # Undo the normalisation, then move from the [0, 1] scale to [0, 255].
    denormalized = image * imagenet_std + imagenet_mean
    scaled = denormalized * 255
    clipped = torch.clip(scaled, 0, 255)
    # NOTE(review): ``.int()`` yields int32 rather than uint8 -- confirm the
    # consumer of these arrays interprets int32 pixel values as intended.
    return clipped.int().cpu().numpy()
def reconstruct(image):
    """Run the masked autoencoder on ``image`` and build the gallery entries.

    The image is preprocessed with ``image_processor`` and passed through
    ``model``. The patch mask is taken from the model output, so the set of
    masked patches is whatever the model chose for this forward pass.

    Args:
        image: Input image in a format accepted by
            ``image_processor.preprocess`` (the UI in this file supplies a
            PIL image).

    Returns:
        A list of four ``(array, caption)`` tuples, in this order:
        ``'original'``, ``'masked'``, ``'reconstruction'`` and
        ``'reconstruction + visible'``. Each array comes from ``prep_image``.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Presumably what the UI passes when the button is clicked before an
        # image is supplied; fail with a readable message instead of an
        # opaque preprocessing error.
        raise gr.Error('Please upload an image first.')

    pixel_values = image_processor.preprocess(image, return_tensors='pt').pixel_values

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = model(pixel_values)
        y = model.unpatchify(outputs.logits)

        # Expand the per-patch mask to per-pixel values so it can be
        # unpatchified exactly like the logits.
        patch_dim = model.config.patch_size ** 2 * model.config.num_channels
        mask = outputs.mask.unsqueeze(-1).repeat(1, 1, patch_dim)  # (N, H*W, p*p*C)
        mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping

    # Move everything to channel-last layout on the CPU for display.
    y = torch.einsum('nchw->nhwc', y).detach().cpu()
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
    x = torch.einsum('nchw->nhwc', pixel_values).detach().cpu()

    # Masked image: only the patches the model kept remain visible.
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches.
    im_paste = im_masked + y * mask

    return [
        (prep_image(x[0]), 'original'),
        (prep_image(im_masked[0]), 'masked'),
        (prep_image(y[0]), 'reconstruction'),
        (prep_image(im_paste[0]), 'reconstruction + visible'),
    ]
# UI: an image input and a trigger button, with a one-row gallery for the
# four outputs returned by ``reconstruct``.
# NOTE(review): block nesting was reconstructed from a listing without
# indentation -- confirm the gallery belongs to the outer panel column.
with gr.Blocks() as demo:
    with gr.Column(variant='panel'):
        with gr.Column():
            img = gr.Image(container=False, type='pil')
            btn = gr.Button('Apply F1 MAE', scale=0)
        gallery = gr.Gallery(columns=4, rows=1, height='300px', object_fit='none')
    # Clicking the button sends the image to ``reconstruct`` and shows the result.
    btn.click(fn=reconstruct, inputs=img, outputs=gallery)

if __name__ == "__main__":
    demo.launch()
|