# Hugging Face Spaces app (Space status shown when this page was captured: "Sleeping").
import gradio as gr | |
import torch | |
import torchvision | |
from PIL import Image | |
import numpy as np | |
import random | |
from einops import rearrange | |
import matplotlib.pyplot as plt | |
from torchvision.transforms import v2 | |
from model import MAE_ViT, MAE_Encoder, MAE_Decoder, MAE_Encoder_FeatureExtractor | |
# Sample input images, one file path per row.
path = [['images/cat.jpg'], ['images/dog.jpg']]

# The checkpoint holds a whole pickled model object (the result is used directly
# as a module via .eval()/.to()), not a state_dict. Unpickling it presumably
# relies on the MAE classes imported from `model` at the top of the file.
# Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses
# arbitrary pickled objects, so opt out explicitly.
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
model_name = "vit-t-mae-pretrain.pt"
device = torch.device("cpu")
model = torch.load(model_name, map_location=device, weights_only=False)
model.eval()
model.to(device)

# Preprocessing: resize to 32x32, convert to a float tensor, then apply
# (x - 0.5) / 0.5 per channel, i.e. map [0, 1] to [-1, 1].
transform = v2.Compose([
    v2.Resize((32, 32)),
    v2.ToTensor(),
    v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# Load and Preprocess the Image
def load_image(image_path, transform):
    """Read an image file and turn it into a batched model input.

    Args:
        image_path: Path to an image file that PIL can open.
        transform: Callable applied to the RGB ``PIL.Image``; assumed to
            return a ``(C, H, W)`` tensor (TODO confirm for other transforms).

    Returns:
        The transformed tensor with a leading batch dimension of size 1.
    """
    # The context manager releases the file handle deterministically, even if
    # decoding fails. convert() loads the pixel data before the file is closed
    # and returns an independent image.
    with Image.open(image_path) as raw:
        img = raw.convert('RGB')
    return transform(img).unsqueeze(0)  # Add batch dimension
def show_image(img, title):
    """Draw one image tensor on the current matplotlib axes.

    Args:
        img: Channels-first ``(C, H, W)`` tensor, expected to be in the
            ``[-1, 1]`` range produced by ``Normalize(mean=0.5, std=0.5)``.
        title: Title shown above the axes.
    """
    img = rearrange(img, "c h w -> h w c")
    # Undo Normalize(0.5, 0.5): map [-1, 1] back to [0, 1]. Model outputs can
    # overshoot that range, so clip explicitly rather than letting imshow clip
    # RGB float data with a logged warning.
    img = np.clip((img.cpu().detach().numpy() + 1) / 2, 0.0, 1.0)
    plt.imshow(img)
    plt.axis('off')
    plt.title(title)
# Visualize a Single Image
def visualize_single_image(image_path, image_name, model, device):
    """Run the MAE model on one image and plot the four comparison panels.

    Panels, left to right: the original image, the masked input, the raw
    reconstruction, and the reconstruction with the visible patches pasted
    back from the original.

    Args:
        image_path: Path to the input image file.
        image_name: Name of the image; currently unused by this function.
        model: Callable returning ``(predicted_img, mask)`` for a batched
            image tensor.
        device: Torch device the input tensor is moved to.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as the current
        figure.
    """
    img = load_image(image_path, transform).to(device)

    model.eval()
    with torch.no_grad():
        predicted_img, mask = model(img)

    # Assumes `mask` broadcasts against `img` and is 1 on masked pixels and 0 on
    # visible ones -- TODO confirm against the model's forward().
    im_masked = img * (1 - mask)
    # MAE reconstruction pasted with visible patches
    im_paste = im_masked + predicted_img * mask

    # Each call opens a new figure. In a long-running server nothing else
    # closes them, so discard earlier figures to keep memory from growing.
    plt.close('all')
    plt.figure(figsize=(12, 4))
    panels = [
        (img[0], "original"),
        (im_masked[0], "masked"),
        (predicted_img[0], "reconstruction"),
        (im_paste[0], "reconstruction + visible"),
    ]
    for position, (tensor, title) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        show_image(tensor, title)
    plt.tight_layout()
    return plt
# Example Usage
# Run the pipeline once at start-up on a bundled image as a smoke test.
image_path = 'images/dog.jpg'  # Replace with the actual path to your image
# take the string after the last '/' as the image name
image_name = image_path.split('/')[-1].split('.')[0]
visualize_single_image(image_path, image_name, model, device)


def _predict(filepath):
    """Gradio callback: visualize the MAE reconstruction of one uploaded image.

    The Interface supplies a single value (the input image's file path), but
    visualize_single_image takes four arguments. This adapter supplies the
    name, model and device from the module-level setup.

    Returns:
        The matplotlib Figure holding the comparison panels.

    Raises:
        gr.Error: If no image was provided.
    """
    if filepath is None:
        raise gr.Error("Please provide an input image.")
    name = filepath.split('/')[-1].split('.')[0]
    # visualize_single_image returns the pyplot module; gr.Plot expects the
    # Figure object itself.
    return visualize_single_image(filepath, name, model, device).gcf()


inputs_image = [
    gr.Image(type="filepath", label="Input Image"),
]
outputs_image = [
    gr.Plot(label="Output Image"),
]
# gr.outputs.* and the allow_screenshot / allow_remote_access options are not
# part of the current Interface API. Flagging takes a string mode rather than a bool.
demo = gr.Interface(
    fn=_predict,
    inputs=inputs_image,
    outputs=outputs_image,
    title="MAE-ViT Image Reconstruction",
    description="This is a demo of the MAE-ViT model for image reconstruction.",
    allow_flagging="never",
)
demo.launch()