mlgawd commited on
Commit
eee507e
·
verified ·
1 Parent(s): 2ac3e6a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from model import Generator
4
+ import torchvision.utils as vutils
5
+ import os
6
+ from math import log2
7
+
8
+ # Function to generate images
9
+ def generate_images():
10
+ Z_DIM = 256
11
+ IN_CHANNELS = 256
12
+
13
+ # Load pretrained generator weights
14
+ checkpoint = torch.load("generator.pth", map_location=torch.device('cpu'))
15
+
16
+ # Filter out optimizer-related keys
17
+ state_dict = checkpoint['state_dict']
18
+
19
+ # Load the filtered state dictionary into the model
20
+ generator = Generator(Z_DIM, IN_CHANNELS, img_channels=3)
21
+ generator.load_state_dict(state_dict)
22
+ generator.eval()
23
+
24
+ # Set output directory
25
+ output_dir = "generated_images"
26
+ os.makedirs(output_dir, exist_ok=True)
27
+
28
+ # Generate images
29
+ img_sizes = [256]
30
+ images = []
31
+ for img_size in img_sizes:
32
+ num_steps = int(log2(img_size / 4))
33
+ x = torch.randn((6, Z_DIM, 1, 1)) # Generate a batch of 6 images
34
+ with torch.no_grad():
35
+ z = generator(x, alpha=0.5, steps=num_steps)
36
+
37
+ # Normalize the generated images to the range [-1, 1]
38
+ z = (z + 1) / 2
39
+
40
+ assert z.shape == (6, 3, img_size, img_size)
41
+
42
+ # Append generated images to the list
43
+ for i in range(6):
44
+ images.append(z[i].detach())
45
+
46
+ return images
47
+
48
+ # Main function to create Streamlit web app
49
+ def main():
50
+ st.title('Image Generation with pro-gan 🤖')
51
+ st.write("Click the buttons below to generate images.")
52
+ st.write("Trained on CelebHQ dataset.")
53
+
54
+ # Prompt message about image size
55
+ st.write("Note: Due to limited resources, the model has been trained to generate 256x256 size images. They are still awesome!")
56
+
57
+ # Generate images on button click
58
+ if st.button('Generate Images'):
59
+ images = generate_images()
60
+ # Display the generated images
61
+ for i, image in enumerate(images):
62
+ st.image(image.permute(1, 2, 0).cpu().numpy(), caption=f'Generated Image {i+1}', use_column_width=True)
63
+
64
+ if __name__ == '__main__':
65
+ main()