therealcyberlord commited on
Commit
0cc3712
·
1 Parent(s): e4eca40

removed files

Browse files
Files changed (2) hide show
  1. App.py +0 -86
  2. Utils.py +0 -40
App.py DELETED
@@ -1,86 +0,0 @@
1
- import streamlit as st
2
- import torch
3
- import DCGAN
4
- import SRGAN
5
- from Utils import color_histogram_mapping, denormalize_images
6
- import torch.nn as nn
7
- import random
8
-
9
- device = torch.device("cpu")
10
-
11
- if torch.cuda.is_available():
12
- device = torch.device("cuda")
13
-
14
- latent_size = 100
15
- display_width = 450
16
- checkpoint_path = "Checkpoints/150epochs.chkpt"
17
-
18
- st.title("Generating Abstract Art")
19
- st.text("start generating (left side bar)")
20
- st.text("Made by Xingyu B.")
21
-
22
- st.sidebar.subheader("Configurations")
23
- seed = st.sidebar.slider('Seed', -10000, 10000, 0)
24
-
25
- num_images = st.sidebar.slider('Number of Images', 1, 10, 1)
26
-
27
- use_srgan = st.sidebar.selectbox(
28
- 'Apply image enhancement',
29
- ('Yes', 'No')
30
- )
31
-
32
- generate = st.sidebar.button("Generate")
33
-
34
-
35
- # caching the expensive model loading
36
-
37
- @st.cache(allow_output_mutation=True)
38
- def load_dcgan():
39
- model = torch.jit.load('Checkpoints/dcgan.pt', map_location=device)
40
- return model
41
-
42
- @st.cache(allow_output_mutation=True)
43
- def load_esrgan():
44
- model_state_dict = torch.load("Checkpoints/esrgan.pt", map_location=device)
45
- return model_state_dict
46
-
47
- # if the user wants to generate something new
48
- if generate:
49
- torch.manual_seed(seed)
50
- random.seed(seed)
51
-
52
- sampled_noise = torch.randn(num_images, latent_size, 1, 1, device=device)
53
- generator = load_dcgan()
54
- generator.eval()
55
-
56
- with torch.no_grad():
57
- fakes = generator(sampled_noise).detach()
58
-
59
- # use srgan for super resolution
60
- if use_srgan == "Yes":
61
- # restore to the checkpoint
62
- st.write("Using DCGAN then ESRGAN upscale...")
63
- esrgan_generator = SRGAN.GeneratorRRDB(channels=3, filters=64, num_res_blocks=23).to(device)
64
- esrgan_checkpoint = load_esrgan()
65
- esrgan_generator.load_state_dict(esrgan_checkpoint)
66
-
67
- esrgan_generator.eval()
68
- with torch.no_grad():
69
- enhanced_fakes = esrgan_generator(fakes).detach().cpu()
70
- color_match = color_histogram_mapping(enhanced_fakes, fakes.cpu())
71
-
72
- for i in range(len(color_match)):
73
- # denormalize and permute to correct color channel
74
- st.image(denormalize_images(color_match[i]).permute(1, 2, 0).numpy(), width=display_width)
75
-
76
-
77
- # default setting -> vanilla dcgan generation
78
- if use_srgan == "No":
79
- fakes = fakes.cpu()
80
- st.write("Using DCGAN Model...")
81
- for i in range(len(fakes)):
82
- st.image(denormalize_images(fakes[i]).permute(1, 2, 0).numpy(), width=display_width)
83
-
84
-
85
-
86
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Utils.py DELETED
@@ -1,40 +0,0 @@
1
- import matplotlib.pyplot as plt
2
- import numpy as np
3
- import torchvision.utils as vutils
4
- import torchvision.transforms as transforms
5
- from skimage.exposure import match_histograms
6
- import torch
7
-
8
- # contains utility functions that we need in the main program
9
-
10
- # matches the color histogram of original and the super resolution output
11
- def color_histogram_mapping(images, references):
12
- matched_list = []
13
- for i in range(len(images)):
14
- matched = match_histograms(images[i].permute(1, 2, 0).numpy(), references[i].permute(1, 2, 0).numpy(),
15
- channel_axis=-1)
16
- matched_list.append(matched)
17
- return torch.tensor(np.array(matched_list)).permute(0, 3, 1, 2)
18
-
19
-
20
- def visualize_generations(seed, images):
21
- plt.figure(figsize=(16, 16))
22
- plt.title(f"Seed: {seed}")
23
- plt.axis("off")
24
- plt.imshow(np.transpose(vutils.make_grid(images, padding=2, nrow=5, normalize=True), (2, 1, 0)))
25
- plt.show()
26
-
27
-
28
- # denormalize the images for proper display
29
- def denormalize_images(images):
30
- mean= [0.5, 0.5, 0.5]
31
- std= [0.5, 0.5, 0.5]
32
- inv_normalize = transforms.Normalize(
33
- mean=[-m / s for m, s in zip(mean, std)],
34
- std=[1 / s for s in std]
35
- )
36
- return inv_normalize(images)
37
-
38
-
39
-
40
-