Spaces:
Running
Running
Uploaded files
Browse files- app.py +84 -0
- gan_final.pth +3 -0
- requirements.txt +4 -0
app.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import streamlit as st
|
4 |
+
import torchvision.utils as vutils
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
|
7 |
+
class Generator(nn.Module):
|
8 |
+
def __init__(self, channels_noise, channels_img, features_g):
|
9 |
+
super(Generator, self).__init__()
|
10 |
+
self.net = nn.Sequential(
|
11 |
+
# Input: N x channels_noise x 1 x 1
|
12 |
+
self._block(channels_noise, features_g * 16, 4, 1, 0), # img: 4x4
|
13 |
+
self._block(features_g * 16, features_g * 8, 4, 2, 1), # img: 8x8
|
14 |
+
self._block(features_g * 8, features_g * 4, 4, 2, 1), # img: 16x16
|
15 |
+
self._block(features_g * 4, features_g * 2, 4, 2, 1), # img: 32x32
|
16 |
+
nn.ConvTranspose2d(
|
17 |
+
features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
|
18 |
+
),
|
19 |
+
# Output: N x channels_img x 64 x 64
|
20 |
+
nn.Tanh(),
|
21 |
+
)
|
22 |
+
|
23 |
+
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
|
24 |
+
return nn.Sequential(
|
25 |
+
nn.ConvTranspose2d(
|
26 |
+
in_channels,
|
27 |
+
out_channels,
|
28 |
+
kernel_size,
|
29 |
+
stride,
|
30 |
+
padding,
|
31 |
+
bias=False,
|
32 |
+
),
|
33 |
+
nn.BatchNorm2d(out_channels),
|
34 |
+
nn.ReLU(),
|
35 |
+
)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
return self.net(x)
|
39 |
+
|
40 |
+
|
41 |
+
# Load the trained model
|
42 |
+
@st.cache_resource
|
43 |
+
def load_model(model_path="gan_final.pth", noise_dim=100, device="cpu"):
|
44 |
+
checkpoint = torch.load(model_path, map_location=device)
|
45 |
+
|
46 |
+
# Recreate generator model
|
47 |
+
gen = Generator(channels_noise=noise_dim, channels_img=3, features_g=64).to(device)
|
48 |
+
gen.load_state_dict(checkpoint["generator"])
|
49 |
+
gen.eval()
|
50 |
+
|
51 |
+
return gen
|
52 |
+
|
53 |
+
# Function to generate images
|
54 |
+
def generate_images(generator, num_images=1, noise_dim=100, device="cpu"):
|
55 |
+
noise = torch.randn(num_images, noise_dim, 1, 1, device=device)
|
56 |
+
with torch.no_grad():
|
57 |
+
fake_images = generator(noise).cpu()
|
58 |
+
|
59 |
+
# Denormalize from [-1,1] to [0,1]
|
60 |
+
fake_images = (fake_images * 0.5) + 0.5
|
61 |
+
|
62 |
+
return fake_images
|
63 |
+
|
64 |
+
# Streamlit UI
|
65 |
+
st.title("GAN Image Generator 🎨")
|
66 |
+
st.write("Generate images using a trained GAN model.")
|
67 |
+
|
68 |
+
# Load the model
|
69 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
70 |
+
generator = load_model(device=device)
|
71 |
+
|
72 |
+
# User input for number of images
|
73 |
+
num_images = st.slider("Select number of images", 1, 8, 4)
|
74 |
+
|
75 |
+
# Generate button
|
76 |
+
if st.button("Generate Images"):
|
77 |
+
st.write("🖌️ Generating images...")
|
78 |
+
fake_images = generate_images(generator, num_images=num_images, device=device)
|
79 |
+
|
80 |
+
# Display images
|
81 |
+
fig, ax = plt.subplots(figsize=(num_images, num_images))
|
82 |
+
ax.axis("off")
|
83 |
+
ax.imshow(vutils.make_grid(fake_images, padding=2, normalize=False).permute(1, 2, 0))
|
84 |
+
st.pyplot(fig)
|
gan_final.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:585d62c44b3618d62e4f46f6c1abadc5fb852685093f615347c6bd6e4b08b93d
|
3 |
+
size 61734830
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
torch
|
3 |
+
torchvision
|
4 |
+
matplotlib
|