therealcyberlord committed on
Commit 1afdd99 · 1 Parent(s): 0cc3712

renamed files

Files changed (2)
  1. app.py +86 -0
  2. utils.py +40 -0
app.py ADDED
@@ -0,0 +1,86 @@
+ import streamlit as st
+ import torch
+ import DCGAN
+ import SRGAN
+ from utils import color_histogram_mapping, denormalize_images
+ import torch.nn as nn
+ import random
+
+ device = torch.device("cpu")
+
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+
+ latent_size = 100
+ display_width = 450
+ checkpoint_path = "Checkpoints/150epochs.chkpt"
+
+ st.title("Generating Abstract Art")
+ st.text("Start generating from the left sidebar")
+ st.text("Made by Xingyu B.")
+
+ st.sidebar.subheader("Configurations")
+ seed = st.sidebar.slider('Seed', -10000, 10000, 0)
+
+ num_images = st.sidebar.slider('Number of Images', 1, 10, 1)
+
+ use_srgan = st.sidebar.selectbox(
+     'Apply image enhancement',
+     ('Yes', 'No')
+ )
+
+ generate = st.sidebar.button("Generate")
+
+
+ # cache the expensive model loading so it only runs once per session
+
+ @st.cache(allow_output_mutation=True)
+ def load_dcgan():
+     model = torch.jit.load('Checkpoints/dcgan.pt', map_location=device)
+     return model
+
+ @st.cache(allow_output_mutation=True)
+ def load_esrgan():
+     model_state_dict = torch.load("Checkpoints/esrgan.pt", map_location=device)
+     return model_state_dict
+
+ # if the user wants to generate something new
+ if generate:
+     torch.manual_seed(seed)
+     random.seed(seed)
+
+     sampled_noise = torch.randn(num_images, latent_size, 1, 1, device=device)
+     generator = load_dcgan()
+     generator.eval()
+
+     with torch.no_grad():
+         fakes = generator(sampled_noise).detach()
+
+     # use ESRGAN for super resolution
+     if use_srgan == "Yes":
+         # restore the ESRGAN generator from its checkpoint
+         st.write("Using DCGAN then ESRGAN upscale...")
+         esrgan_generator = SRGAN.GeneratorRRDB(channels=3, filters=64, num_res_blocks=23).to(device)
+         esrgan_checkpoint = load_esrgan()
+         esrgan_generator.load_state_dict(esrgan_checkpoint)
+
+         esrgan_generator.eval()
+         with torch.no_grad():
+             enhanced_fakes = esrgan_generator(fakes).detach().cpu()
+             color_match = color_histogram_mapping(enhanced_fakes, fakes.cpu())
+
+         for i in range(len(color_match)):
+             # denormalize and permute to channel-last for display
+             st.image(denormalize_images(color_match[i]).permute(1, 2, 0).numpy(), width=display_width)
+
+
+     # default setting -> vanilla DCGAN generation
+     if use_srgan == "No":
+         fakes = fakes.cpu()
+         st.write("Using DCGAN Model...")
+         for i in range(len(fakes)):
+             st.image(denormalize_images(fakes[i]).permute(1, 2, 0).numpy(), width=display_width)
+
+
+
+
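For reference, the same generation path can be exercised outside Streamlit. The sketch below is illustrative only: it assumes the repo's SRGAN module, the utils helpers, and the Checkpoints/dcgan.pt and Checkpoints/esrgan.pt files referenced in the diff above are available; the batch size and seed are arbitrary.

# sketch: run the DCGAN -> ESRGAN pipeline from a plain script (no Streamlit)
import torch
import SRGAN                                   # repo module, as imported in app.py
from utils import color_histogram_mapping, denormalize_images

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_size = 100

torch.manual_seed(0)                           # arbitrary fixed seed
noise = torch.randn(4, latent_size, 1, 1, device=device)

generator = torch.jit.load("Checkpoints/dcgan.pt", map_location=device)
generator.eval()
with torch.no_grad():
    fakes = generator(noise).detach()

esrgan = SRGAN.GeneratorRRDB(channels=3, filters=64, num_res_blocks=23).to(device)
esrgan.load_state_dict(torch.load("Checkpoints/esrgan.pt", map_location=device))
esrgan.eval()
with torch.no_grad():
    enhanced = esrgan(fakes).detach().cpu()

# match colors back to the DCGAN output, then denormalize to roughly [0, 1]
matched = color_histogram_mapping(enhanced, fakes.cpu())
images = [denormalize_images(img) for img in matched]

The app itself is started the usual Streamlit way, e.g. streamlit run app.py.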
utils.py ADDED
@@ -0,0 +1,40 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torchvision.utils as vutils
+ import torchvision.transforms as transforms
+ from skimage.exposure import match_histograms
+ import torch
+
+ # utility functions used by the main program
+
+ # match the color histogram of the super-resolution output to the original images
+ def color_histogram_mapping(images, references):
+     matched_list = []
+     for i in range(len(images)):
+         matched = match_histograms(images[i].permute(1, 2, 0).numpy(), references[i].permute(1, 2, 0).numpy(),
+                                    channel_axis=-1)
+         matched_list.append(matched)
+     return torch.tensor(np.array(matched_list)).permute(0, 3, 1, 2)
+
+
+ def visualize_generations(seed, images):
+     plt.figure(figsize=(16, 16))
+     plt.title(f"Seed: {seed}")
+     plt.axis("off")
+     plt.imshow(np.transpose(vutils.make_grid(images, padding=2, nrow=5, normalize=True), (1, 2, 0)))
+     plt.show()
+
+
+ # denormalize the images for proper display
+ def denormalize_images(images):
+     mean = [0.5, 0.5, 0.5]
+     std = [0.5, 0.5, 0.5]
+     inv_normalize = transforms.Normalize(
+         mean=[-m / s for m, s in zip(mean, std)],
+         std=[1 / s for s in std]
+     )
+     return inv_normalize(images)
+
+
+
+
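As a quick sanity check (not part of the commit), denormalize_images should invert the Normalize(mean=0.5, std=0.5) transform implied by the hard-coded values above, mapping a [-1, 1] tensor back to [0, 1]. The tensor shape below is arbitrary.

# sanity check: denormalize_images inverts Normalize(mean=0.5, std=0.5)
import torch
from utils import denormalize_images

x = torch.rand(3, 64, 64)                      # fake image in [0, 1], CHW layout
normalized = (x - 0.5) / 0.5                   # what Normalize(0.5, 0.5) does -> [-1, 1]
restored = denormalize_images(normalized)      # inverse transform from utils.py

print(torch.allclose(restored, x, atol=1e-6))  # expected: True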