Spaces:
Sleeping
Sleeping
# pip install gradio | |
import gradio as gr | |
import torch | |
import torch.nn as nn | |
import torch.nn.parallel | |
import torch.utils.data | |
import torchvision.transforms as transforms | |
import torchvision.utils as vutils | |
# Set this to True if loading the checkpoint fails with
# `RuntimeError: Error(s) in loading state_dict for Generator`.
# When enabled, a leading `module.` is stripped from every checkpoint key
# before the state dict is loaded.
omit_module = True

# Spatial size (height and width) of the training images.
image_size = 64

# Number of image channels: 3 for color images, 1 is used here.
nc = 1

# Length of the latent vector z, i.e. the size of the generator input.
nz = 100

# Base number of feature maps in the generator.
ngf = 64

# Base number of feature maps in the discriminator.
ndf = 64

# Adam optimizer hyperparameters: learning rate and beta1.
lr = 2e-4
beta1 = 0.5

# Number of GPUs available; 0 selects CPU mode.
ngpu = 0

# Run on the first CUDA device only when GPUs were requested and CUDA is
# actually available; otherwise fall back to the CPU.
device = torch.device(
    "cuda:0" if ngpu > 0 and torch.cuda.is_available() else "cpu"
)
# Custom weight initialization; applied to every submodule of a network
# through `Module.apply`.
def weights_init(m):
    """Re-initialize a single module in place, based on its class name.

    * Class name contains ``Conv``: weights are drawn from a normal
      distribution with mean 0.0 and std 0.02.
    * Class name contains ``BatchNorm``: weights are drawn from a normal
      distribution with mean 1.0 and std 0.02, and the bias is set to 0.
    * Any other module is left untouched.

    Args:
        m: The module to initialize.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)
class Generator(nn.Module):
    """Transposed-convolution generator mapping a latent vector Z to an image.

    The network upsamples through four ConvTranspose2d -> BatchNorm2d -> ReLU
    stages and finishes with a ConvTranspose2d -> Tanh stage. Intermediate
    state sizes are (ngf*8) x 4 x 4, (ngf*4) x 8 x 8, (ngf*2) x 16 x 16 and
    (ngf) x 32 x 32; the output is (nc) x 64 x 64.

    Args:
        ngpu: Number of GPUs; stored on the instance as ``self.ngpu``.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # One row per upsampling stage: (in_channels, out_channels, stride,
        # padding). Every stage uses a 4x4 kernel and no convolution bias.
        stages = (
            (nz, ngf * 8, 1, 0),        # Z -> (ngf*8) x 4 x 4
            (ngf * 8, ngf * 4, 2, 1),   # -> (ngf*4) x 8 x 8
            (ngf * 4, ngf * 2, 2, 1),   # -> (ngf*2) x 16 x 16
            (ngf * 2, ngf, 2, 1),       # -> (ngf) x 32 x 32
        )

        layers = []
        for in_ch, out_ch, stride, padding in stages:
            layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))

        # Output stage: (ngf) x 32 x 32 -> (nc) x 64 x 64, squashed by Tanh.
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        # The layers are registered in the same order as before, so the
        # Sequential indices (main.0 ... main.13) and therefore the
        # state_dict keys are unchanged.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.main(input)
# Create the generator on the selected device.
netG = Generator(ngpu).to(device)

# Handle multi-GPU if desired.
if ngpu > 1 and device.type == 'cuda':
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply weights_init to randomly initialize the weights. The checkpoint
# loaded below then replaces these values.
netG.apply(weights_init)

# NOTE(security): torch.load unpickles the file, so only load checkpoints
# from a trusted source.
checkpoint = torch.load("checkpoints/epoch1100.ckpt", map_location=torch.device('cpu'))
netG_state_dict = checkpoint['netG_state_dict']

if omit_module:
    # Strip the leading 'module.' from checkpoint keys (the prefix a
    # DataParallel wrapper puts on its parameter names). Keys are renamed
    # in place so the checkpoint's own mapping object is what gets loaded.
    module_prefix = 'module.'
    for key in list(netG_state_dict.keys()):
        if str(key).startswith(module_prefix):
            netG_state_dict[key[len(module_prefix):]] = netG_state_dict.pop(key)

# Once the prefix is stripped the keys match the bare Generator, not a
# DataParallel wrapper (which expects 'module.'-prefixed keys). Load into the
# wrapped module in that case so the multi-GPU path does not fail on a key
# mismatch.
if omit_module and isinstance(netG, nn.DataParallel):
    load_target = netG.module
else:
    load_target = netG
load_target.load_state_dict(netG_state_dict)
def genImg():
    """Generate a batch of 64 images and return them as a single PIL image.

    Fresh latent vectors of shape (64, nz, 1, 1) are sampled on every call,
    passed through ``netG`` without gradient tracking, and the outputs are
    tiled into one normalized grid.

    Returns:
        The image grid converted to a PIL image.
    """
    with torch.no_grad():
        latent = torch.randn(64, nz, 1, 1, device=device)
        samples = netG(latent)
    samples = samples.detach().cpu()
    # normalize=True rescales the grid values into the [0, 1] range.
    grid = vutils.make_grid(samples, padding=2, normalize=True)
    return transforms.functional.to_pil_image(grid)
# Expose genImg through a Gradio interface: it takes no inputs and returns
# a single image. Launching starts the web UI.
demo = gr.Interface(
    fn=genImg,
    inputs=None,
    outputs="image",
)
demo.launch()