File size: 906 Bytes
4019e92 8993ad7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import gradio as gr
import torch
from Generator import Generator
from torchvision.utils import save_image
# Build the generator network and restore its trained weights. Passing
# map_location="cpu" remaps any tensors saved from a GPU so the demo can run on
# a CPU-only host.
# NOTE(review): torch.load unpickles the checkpoint, which is only safe for a
# trusted file — presumably generator.pth is bundled with the app; confirm.
# On torch versions that support it, weights_only=True would restrict loading
# to plain tensors (sufficient for a state_dict).
generator = Generator(1)
generator.load_state_dict(
    torch.load("./generator.pth", map_location="cpu")
)
# Put train/eval-sensitive layers (e.g. BatchNorm, Dropout), if the model has
# any, into inference behaviour.
generator.eval()
def generate(seed, num_img):
    """Sample fake images from the pretrained generator and save them as a grid.

    Args:
        seed: Seed for torch's global RNG, so the same seed reproduces the
            same images. Gradio sliders may deliver it as a float, so it is
            converted to ``int`` first.
        num_img: Number of images to sample (batch size of the latent noise).
            Converted to ``int`` because ``torch.randn`` rejects float sizes.

    Returns:
        The path ``'fake_img.png'`` of the PNG written by ``save_image``,
        usable directly as the value of a ``gr.Image`` output.
    """
    torch.manual_seed(int(seed))
    # Latent noise of shape (num_img, 100, 1, 1); the 100 has to match the
    # generator's latent size, which is defined in Generator.py.
    z = torch.randn(int(num_img), 100, 1, 1)
    # Inference only: don't build an autograd graph (replaces the former
    # forward-then-detach).
    with torch.no_grad():
        fake_img = generator(z)
    # The batch is passed to save_image without squeeze(): save_image tiles a
    # (batch, C, H, W) tensor into a grid itself, whereas squeezing a size-1
    # channel (or batch) dim would make it read N images as one N-channel
    # image. Assumes the generator returns a 4-D batch — TODO confirm.
    # NOTE(review): every call overwrites the same file in the working
    # directory; concurrent requests served without a queue could race.
    save_image(fake_img, "fake_img.png", normalize=True)
    return 'fake_img.png'
# Gradio UI: two sliders (seed, number of images) feed `generate`, and the PNG
# path it returns is displayed in an image component.
with gr.Blocks() as demo:
    gr.Markdown("DCGAN model that generates fake images")
    image_input = [
        # step=1 makes every integer seed selectable; without it Gradio derives
        # a step from the range (10 for 0-1000), so only every tenth seed
        # could be chosen.
        gr.Slider(0, 1000, label='Seed', step=1),
        gr.Slider(4, 64, label='Number of images', step=1),
    ]
    image_output = gr.Image()
    image_button = gr.Button("Generate")
    # Input values are passed to `generate` positionally, in list order:
    # (seed, num_img).
    image_button.click(generate, inputs=image_input, outputs=image_output)

# Start the Gradio app (runs whenever this module is executed, as before).
demo.launch()