Sohaib9920 commited on
Commit
3b72de3
·
verified ·
1 Parent(s): 73c7fcc

Uploaded files

Browse files
Files changed (3) hide show
  1. app.py +84 -0
  2. gan_final.pth +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import streamlit as st
4
+ import torchvision.utils as vutils
5
+ import matplotlib.pyplot as plt
6
+
7
+ class Generator(nn.Module):
8
+ def __init__(self, channels_noise, channels_img, features_g):
9
+ super(Generator, self).__init__()
10
+ self.net = nn.Sequential(
11
+ # Input: N x channels_noise x 1 x 1
12
+ self._block(channels_noise, features_g * 16, 4, 1, 0), # img: 4x4
13
+ self._block(features_g * 16, features_g * 8, 4, 2, 1), # img: 8x8
14
+ self._block(features_g * 8, features_g * 4, 4, 2, 1), # img: 16x16
15
+ self._block(features_g * 4, features_g * 2, 4, 2, 1), # img: 32x32
16
+ nn.ConvTranspose2d(
17
+ features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
18
+ ),
19
+ # Output: N x channels_img x 64 x 64
20
+ nn.Tanh(),
21
+ )
22
+
23
+ def _block(self, in_channels, out_channels, kernel_size, stride, padding):
24
+ return nn.Sequential(
25
+ nn.ConvTranspose2d(
26
+ in_channels,
27
+ out_channels,
28
+ kernel_size,
29
+ stride,
30
+ padding,
31
+ bias=False,
32
+ ),
33
+ nn.BatchNorm2d(out_channels),
34
+ nn.ReLU(),
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.net(x)
39
+
40
+
41
+ # Load the trained model
42
+ @st.cache_resource
43
+ def load_model(model_path="gan_final.pth", noise_dim=100, device="cpu"):
44
+ checkpoint = torch.load(model_path, map_location=device)
45
+
46
+ # Recreate generator model
47
+ gen = Generator(channels_noise=noise_dim, channels_img=3, features_g=64).to(device)
48
+ gen.load_state_dict(checkpoint["generator"])
49
+ gen.eval()
50
+
51
+ return gen
52
+
53
+ # Function to generate images
54
+ def generate_images(generator, num_images=1, noise_dim=100, device="cpu"):
55
+ noise = torch.randn(num_images, noise_dim, 1, 1, device=device)
56
+ with torch.no_grad():
57
+ fake_images = generator(noise).cpu()
58
+
59
+ # Denormalize from [-1,1] to [0,1]
60
+ fake_images = (fake_images * 0.5) + 0.5
61
+
62
+ return fake_images
63
+
64
+ # Streamlit UI
65
+ st.title("GAN Image Generator 🎨")
66
+ st.write("Generate images using a trained GAN model.")
67
+
68
+ # Load the model
69
+ device = "cuda" if torch.cuda.is_available() else "cpu"
70
+ generator = load_model(device=device)
71
+
72
+ # User input for number of images
73
+ num_images = st.slider("Select number of images", 1, 8, 4)
74
+
75
+ # Generate button
76
+ if st.button("Generate Images"):
77
+ st.write("🖌️ Generating images...")
78
+ fake_images = generate_images(generator, num_images=num_images, device=device)
79
+
80
+ # Display images
81
+ fig, ax = plt.subplots(figsize=(num_images, num_images))
82
+ ax.axis("off")
83
+ ax.imshow(vutils.make_grid(fake_images, padding=2, normalize=False).permute(1, 2, 0))
84
+ st.pyplot(fig)
gan_final.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:585d62c44b3618d62e4f46f6c1abadc5fb852685093f615347c6bd6e4b08b93d
3
+ size 61734830
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ torchvision
4
+ matplotlib