# Anime-GAN / app.py
# Author: Sohaib9920 — commit 3b72de3 ("Uploaded files", verified)
import torch
import torch.nn as nn
import streamlit as st
import torchvision.utils as vutils
import matplotlib.pyplot as plt
class Generator(nn.Module):
    """DCGAN-style generator turning a latent tensor into a 64x64 image.

    Input shape is ``N x channels_noise x 1 x 1``. Four transposed-convolution
    blocks grow the spatial size 1 -> 4 -> 8 -> 16 -> 32. A final transposed
    convolution produces ``N x channels_img x 64 x 64``, and ``tanh`` bounds
    the result to [-1, 1].

    Args:
        channels_noise: Channel count of the latent input.
        channels_img: Channel count of the generated image.
        features_g: Base width; the hidden blocks use 16x, 8x, 4x and 2x it.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()
        hidden = [features_g * mult for mult in (16, 8, 4, 2)]
        in_widths = [channels_noise] + hidden[:-1]
        # (stride, padding) per block, all with a 4x4 kernel: the first block
        # projects 1x1 -> 4x4, and each later one doubles the spatial size.
        geometry = [(1, 0)] + [(2, 1)] * 3
        layers = [
            self._block(c_in, c_out, 4, stride, pad)
            for (c_in, c_out), (stride, pad) in zip(zip(in_widths, hidden), geometry)
        ]
        # 32x32 -> 64x64 projection onto image channels (keeps its bias, no BatchNorm).
        layers.append(
            nn.ConvTranspose2d(
                hidden[-1], channels_img, kernel_size=4, stride=2, padding=1
            )
        )
        layers.append(nn.Tanh())
        # Indices 0-3 are the blocks, 4 the output conv, 5 tanh. The resulting
        # "net.<i>..." state-dict keys are what saved checkpoints refer to.
        self.net = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU stage."""
        # A conv bias would be redundant: BatchNorm subtracts the per-channel mean.
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.BatchNorm2d(out_channels)
        return nn.Sequential(upsample, norm, nn.ReLU())

    def forward(self, x):
        """Run the latent batch ``x`` through the network."""
        return self.net(x)
# Load the trained model
@st.cache_resource
def load_model(
    model_path="gan_final.pth",
    noise_dim=100,
    device="cpu",
    channels_img=3,
    features_g=64,
):
    """Build a ``Generator`` and load trained weights into it.

    Decorated with ``st.cache_resource`` so Streamlit builds the model once per
    distinct argument set instead of on every script rerun.

    Args:
        model_path: Checkpoint file read with ``torch.load``. It is expected to
            hold a mapping whose ``"generator"`` entry is the generator's
            state dict.
        noise_dim: Latent channel count; must match the trained checkpoint.
        device: Device the checkpoint tensors are mapped to and the model is
            placed on.
        channels_img: Output image channels; must match the checkpoint
            (default 3).
        features_g: Generator base width; must match the checkpoint
            (default 64).

    Returns:
        The generator with weights loaded, switched to eval mode.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code when weights_only is not enabled. Only load checkpoints you trust;
    # on PyTorch >= 1.13 consider passing weights_only=True.
    checkpoint = torch.load(model_path, map_location=device)
    gen = Generator(
        channels_noise=noise_dim, channels_img=channels_img, features_g=features_g
    ).to(device)
    # load_state_dict is strict by default, so architecture arguments that do
    # not match the checkpoint fail loudly here instead of silently.
    gen.load_state_dict(checkpoint["generator"])
    # Eval mode: BatchNorm layers use their running statistics, not batch stats.
    gen.eval()
    return gen
# Function to generate images
def generate_images(generator, num_images=1, noise_dim=100, device="cpu"):
    """Sample a batch of images from ``generator``.

    Draws standard-normal latents shaped ``(num_images, noise_dim, 1, 1)`` on
    ``device`` and runs the generator with autograd disabled. The output is
    moved to the CPU and rescaled with ``x * 0.5 + 0.5``, which maps a
    [-1, 1] (tanh) range onto [0, 1].

    Returns:
        CPU tensor holding the rescaled generator output.
    """
    latent = torch.randn(num_images, noise_dim, 1, 1, device=device)
    with torch.no_grad():
        samples = generator(latent).cpu()
    return samples.mul(0.5).add(0.5)
# Streamlit UI
st.title("GAN Image Generator 🎨")
st.write("Generate images using a trained GAN model.")

# Load the model (load_model is wrapped in st.cache_resource, so reruns reuse it).
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = load_model(device=device)

# User input for number of images
num_images = st.slider("Select number of images", 1, 8, 4)

# Generate button
if st.button("Generate Images"):
    st.write("🖌️ Generating images...")
    fake_images = generate_images(generator, num_images=num_images, device=device)

    # Tile the samples at most 8 per row (make_grid's default row length) and
    # size the figure to the grid's aspect ratio, about 2 inches per tile.
    # The old square (n x n) figure left large empty margins around a 1 x n strip.
    grid_cols = min(num_images, 8)
    grid_rows = -(-num_images // grid_cols)  # ceiling division
    grid = vutils.make_grid(fake_images, nrow=grid_cols, padding=2, normalize=False)

    # Display images
    fig, ax = plt.subplots(figsize=(2 * grid_cols, 2 * grid_rows))
    ax.axis("off")
    # make_grid returns C x H x W; imshow expects H x W x C.
    ax.imshow(grid.permute(1, 2, 0).numpy())
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed explicitly; without
    # this, each click leaks one figure.
    plt.close(fig)