BioMike's picture
Update vae.py
c10e08c verified
raw
history blame contribute delete
No virus
2.05 kB
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
from model import model
# Run the model on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Encoder-side preprocessing: fixed 128x128 resize (aspect ratio is not
# preserved), PIL -> float tensor, then (x - 0.5) / 0.5 per channel. The
# single mean/std value is broadcast across every channel.
transform1 = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Display-side postprocessing: upscale the reconstruction to 512x512.
transform2 = transforms.Compose([transforms.Resize((512, 512))])
def load_image(image):
    """Convert an uploaded image array into a batched model input on ``device``.

    The array is wrapped as a PIL image and forced to RGB. It is then passed
    through ``transform1`` and given a leading batch axis of size 1.
    """
    pil_rgb = Image.fromarray(image).convert("RGB")
    return transform1(pil_rgb)[None].to(device)
def infer_image(image, noise_level):
    """Reconstruct an uploaded image through the VAE, optionally perturbing the latent.

    Args:
        image: The uploaded image as a NumPy array (as produced by the
            ``gr.Image(type="numpy")`` input), or ``None`` when the input is
            empty.
        noise_level: Multiplier applied to the standard-normal sample before
            it is scaled by the latent std and added to ``mu``. A value of
            ``0`` decodes ``mu`` itself.

    Returns:
        ``None`` if ``image`` is ``None``. Otherwise, the reconstruction as a
        uint8 NumPy array resized to 512x512, with the channel axis last. The
        channel count depends on the decoder (presumably 3 -- confirm against
        ``model``).
    """
    # Gradio fires ``change`` with ``None`` when the image is cleared, and the
    # slider's ``change`` fires even before anything is uploaded. Return None
    # (which clears the output) instead of letting ``Image.fromarray(None)``
    # raise.
    if image is None:
        return None

    batch = load_image(image)
    with torch.no_grad():
        mu, logvar = model.encode(batch)
        # Reparameterisation: z = mu + noise_level * eps * sigma, eps ~ N(0, I).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std) * noise_level
        z = mu + eps * std
        decoded = model.decode(z)

    # Drop only the batch axis and move channels last. The map x * 0.5 + 0.5
    # inverts the Normalize((0.5,), (0.5,)) step of ``transform1``. This
    # assumes the decoder outputs values in that same normalised space.
    array = decoded.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.float32) * 0.5 + 0.5
    array = np.clip(array, 0, 1)
    pil_image = Image.fromarray((array * 255).astype(np.uint8))
    return np.array(transform2(pil_image))
# Clickable examples for the UI. Each row is [image path, noise level],
# matching the order of the ``inputs`` list passed to ``gr.Examples``.
examples = [
    [path, level]
    for path, level in (
        ("example_images/image5.png", 0.98),
        ("example_images/image1.jpg", 0.1),
        ("example_images/image2.png", 0.5),
        ("example_images/image3.jpg", 1.0),
    )
]
# Gradio UI: a noise slider above a side-by-side input/output image pair.
with gr.Blocks() as vae:
    noise_slider = gr.Slider(
        minimum=0, maximum=10, value=0.01, step=0.01, label="Noise Level"
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload an image", type="numpy")
        with gr.Column():
            output_image = gr.Image(label="Reconstructed Image")

    # Re-run the reconstruction whenever either control changes. The image
    # listener is registered first, then the slider listener.
    for trigger in (input_image, noise_slider):
        trigger.change(
            fn=infer_image,
            inputs=[input_image, noise_slider],
            outputs=output_image,
        )

    gr.Examples(examples=examples, inputs=[input_image, noise_slider])