Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import streamlit as st | |
import torchvision.utils as vutils | |
import matplotlib.pyplot as plt | |
class Generator(nn.Module):
    """DCGAN-style generator that upsamples a latent column into an image.

    The latent input is expected as ``N x channels_noise x 1 x 1``.
    - Four transposed-convolution stages produce 4x4, 8x8, 16x16 and 32x32
      feature maps.
    - A final transposed convolution followed by ``tanh`` yields
      ``N x channels_img x 64 x 64`` with values in ``[-1, 1]``.

    Args:
        channels_noise: Number of latent channels fed to the first stage.
        channels_img: Number of channels in the generated image.
        features_g: Base width; the hidden stages use 16x, 8x, 4x and 2x it.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()
        widths = [features_g * mult for mult in (16, 8, 4, 2)]

        # First stage: 1x1 latent -> 4x4 map (kernel 4, stride 1, no padding).
        stages = [self._block(channels_noise, widths[0], 4, 1, 0)]
        # Remaining stages each double the spatial size (kernel 4, stride 2,
        # padding 1) while halving the channel width.
        stages += [
            self._block(c_in, c_out, 4, 2, 1)
            for c_in, c_out in zip(widths, widths[1:])
        ]

        # The layer order inside ``self.net`` must not change: checkpoints
        # address parameters by index (e.g. ``net.0.0.weight``).
        self.net = nn.Sequential(
            *stages,
            nn.ConvTranspose2d(
                widths[-1], channels_img, kernel_size=4, stride=2, padding=1
            ),
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one bias-free transposed-conv -> BatchNorm -> ReLU stage."""
        # No conv bias: the following BatchNorm re-centres the activations,
        # so a bias term would be redundant.
        upsample = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        """Map the latent batch ``x`` to generated images."""
        return self.net(x)
# Load the trained model
def load_model(
    model_path="gan_final.pth",
    noise_dim=100,
    device="cpu",
    channels_img=3,
    features_g=64,
):
    """Rebuild the generator and load its trained weights from a checkpoint.

    Args:
        model_path: Path to a checkpoint written with ``torch.save``. It must
            be a mapping that stores the generator's state dict under the
            ``"generator"`` key.
        noise_dim: Latent channel count the generator was trained with.
        device: Device used both to map the checkpoint tensors and to place
            the model.
        channels_img: Output image channels the generator was trained with.
        features_g: Generator base width it was trained with.

    Returns:
        The ``Generator`` on ``device``, switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        KeyError: If the checkpoint has no ``"generator"`` entry.
        RuntimeError: If the stored weights do not match the architecture
            built from the given hyper-parameters.
    """
    # weights_only=True limits unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)

    gen = Generator(
        channels_noise=noise_dim,
        channels_img=channels_img,
        features_g=features_g,
    ).to(device)
    gen.load_state_dict(checkpoint["generator"])
    # Inference mode: BatchNorm uses its stored running statistics.
    gen.eval()
    return gen
# Function to generate images
def generate_images(generator, num_images=1, noise_dim=100, device="cpu"):
    """Sample ``num_images`` images from ``generator``.

    Steps:
    - Draw standard-normal latents shaped ``num_images x noise_dim x 1 x 1``
      on ``device``.
    - Run them through ``generator`` with gradient tracking disabled and
      move the result to the CPU.
    - Rescale with ``x * 0.5 + 0.5``, which maps a ``[-1, 1]`` output onto
      ``[0, 1]``.

    Returns:
        The rescaled generator output as a CPU tensor.
    """
    latent = torch.randn(num_images, noise_dim, 1, 1, device=device)
    with torch.no_grad():
        samples = generator(latent).cpu()
    # Denormalize from [-1, 1] to [0, 1].
    return samples.mul(0.5).add(0.5)
# Streamlit UI
st.title("GAN Image Generator 🎨")
st.write("Generate images using a trained GAN model.")

# Pick the device once; the cached loader below is keyed on it.
device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _cached_generator(device):
    """Load the generator once per process rather than on every rerun.

    Streamlit re-executes this whole script on each widget interaction, so an
    uncached ``load_model`` call would re-read the checkpoint every time the
    slider moves or the button is pressed.
    """
    return load_model(device=device)


try:
    generator = _cached_generator(device)
except FileNotFoundError:
    st.error("Model checkpoint 'gan_final.pth' was not found.")
    st.stop()

# User input for number of images
num_images = st.slider("Select number of images", 1, 8, 4)

# Generate button
if st.button("Generate Images"):
    # A spinner disappears once generation finishes, unlike a plain st.write
    # message that would stay on the page afterwards.
    with st.spinner("🏞️ Generating images..."):
        fake_images = generate_images(
            generator, num_images=num_images, device=device
        )

    # Display images: make_grid returns C x H x W, imshow expects H x W x C.
    grid = vutils.make_grid(fake_images, padding=2, normalize=False)
    fig, ax = plt.subplots(figsize=(num_images, num_images))
    ax.axis("off")
    ax.imshow(grid.permute(1, 2, 0).numpy())
    st.pyplot(fig)
    # Release the figure so repeated clicks don't accumulate open figures.
    plt.close(fig)