therealcyberlord commited on
Commit
e4eca40
·
1 Parent(s): 380135b

huggingface deployment

Browse files
Files changed (7) hide show
  1. App.py +86 -0
  2. Checkpoints/dcgan.pt +3 -0
  3. Checkpoints/esrgan.pt +3 -0
  4. DCGAN.py +32 -0
  5. SRGAN.py +79 -0
  6. Utils.py +40 -0
  7. requirements.txt +6 -0
App.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import DCGAN
4
+ import SRGAN
5
+ from Utils import color_histogram_mapping, denormalize_images
6
+ import torch.nn as nn
7
+ import random
8
+
9
+ device = torch.device("cpu")
10
+
11
+ if torch.cuda.is_available():
12
+ device = torch.device("cuda")
13
+
14
+ latent_size = 100
15
+ display_width = 450
16
+ checkpoint_path = "Checkpoints/150epochs.chkpt"
17
+
18
+ st.title("Generating Abstract Art")
19
+ st.text("start generating (left side bar)")
20
+ st.text("Made by Xingyu B.")
21
+
22
+ st.sidebar.subheader("Configurations")
23
+ seed = st.sidebar.slider('Seed', -10000, 10000, 0)
24
+
25
+ num_images = st.sidebar.slider('Number of Images', 1, 10, 1)
26
+
27
+ use_srgan = st.sidebar.selectbox(
28
+ 'Apply image enhancement',
29
+ ('Yes', 'No')
30
+ )
31
+
32
+ generate = st.sidebar.button("Generate")
33
+
34
+
35
+ # caching the expensive model loading
36
+
37
+ @st.cache(allow_output_mutation=True)
38
+ def load_dcgan():
39
+ model = torch.jit.load('Checkpoints/dcgan.pt', map_location=device)
40
+ return model
41
+
42
+ @st.cache(allow_output_mutation=True)
43
+ def load_esrgan():
44
+ model_state_dict = torch.load("Checkpoints/esrgan.pt", map_location=device)
45
+ return model_state_dict
46
+
47
+ # if the user wants to generate something new
48
+ if generate:
49
+ torch.manual_seed(seed)
50
+ random.seed(seed)
51
+
52
+ sampled_noise = torch.randn(num_images, latent_size, 1, 1, device=device)
53
+ generator = load_dcgan()
54
+ generator.eval()
55
+
56
+ with torch.no_grad():
57
+ fakes = generator(sampled_noise).detach()
58
+
59
+ # use srgan for super resolution
60
+ if use_srgan == "Yes":
61
+ # restore to the checkpoint
62
+ st.write("Using DCGAN then ESRGAN upscale...")
63
+ esrgan_generator = SRGAN.GeneratorRRDB(channels=3, filters=64, num_res_blocks=23).to(device)
64
+ esrgan_checkpoint = load_esrgan()
65
+ esrgan_generator.load_state_dict(esrgan_checkpoint)
66
+
67
+ esrgan_generator.eval()
68
+ with torch.no_grad():
69
+ enhanced_fakes = esrgan_generator(fakes).detach().cpu()
70
+ color_match = color_histogram_mapping(enhanced_fakes, fakes.cpu())
71
+
72
+ for i in range(len(color_match)):
73
+ # denormalize and permute to correct color channel
74
+ st.image(denormalize_images(color_match[i]).permute(1, 2, 0).numpy(), width=display_width)
75
+
76
+
77
+ # default setting -> vanilla dcgan generation
78
+ if use_srgan == "No":
79
+ fakes = fakes.cpu()
80
+ st.write("Using DCGAN Model...")
81
+ for i in range(len(fakes)):
82
+ st.image(denormalize_images(fakes[i]).permute(1, 2, 0).numpy(), width=display_width)
83
+
84
+
85
+
86
+
Checkpoints/dcgan.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5dc5b82eab432b52635f67ae7abf6901c36daa58ff71445ff9df01cc6b3193f2
3
+ size 14352101
Checkpoints/esrgan.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ec7aa5dd51df901367a6ee1d03c2cbbf72acadad01288040ab723860e96ffe4
3
+ size 154489349
DCGAN.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ # Generator Code
6
+
7
+ ngf = 64
8
+ num_channels = 3
9
+
10
+ class Generator(nn.Module):
11
+ def __init__(self, latent_size):
12
+ super(Generator, self).__init__()
13
+
14
+ self.latent_size = latent_size
15
+ self.conv1 = nn.ConvTranspose2d(
16
+ self.latent_size, ngf*8, 4, 1, 0, bias=False)
17
+ self.bn1 = nn.BatchNorm2d(ngf*8)
18
+ self.conv2 = nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False)
19
+ self.bn2 = nn.BatchNorm2d(ngf*4)
20
+ self.conv3 = nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False)
21
+ self.bn3 = nn.BatchNorm2d(ngf*2)
22
+ self.conv4 = nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False)
23
+ self.bn4 = nn.BatchNorm2d(ngf)
24
+
25
+ self.conv5 = nn.ConvTranspose2d(ngf, num_channels, 4, 2, 1, bias=False)
26
+
27
+ def forward(self, x):
28
+ x = F.relu(self.bn1(self.conv1(x)), inplace=True)
29
+ x = F.relu(self.bn2(self.conv2(x)), inplace=True)
30
+ x = F.relu(self.bn3(self.conv3(x)), inplace=True)
31
+ x = F.relu(self.bn4(self.conv4(x)), inplace=True)
32
+ return torch.tanh(self.conv5(x))
SRGAN.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ class DenseResidualBlock(nn.Module):
5
+ """
6
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
7
+ """
8
+
9
+ def __init__(self, filters, res_scale=0.2):
10
+ super(DenseResidualBlock, self).__init__()
11
+ self.res_scale = res_scale
12
+
13
+ def block(in_features, non_linearity=True):
14
+ layers = [nn.Conv2d(in_features, filters, 3, 1, 1, bias=True)]
15
+ if non_linearity:
16
+ layers += [nn.LeakyReLU()]
17
+ return nn.Sequential(*layers)
18
+
19
+ self.b1 = block(in_features=1 * filters)
20
+ self.b2 = block(in_features=2 * filters)
21
+ self.b3 = block(in_features=3 * filters)
22
+ self.b4 = block(in_features=4 * filters)
23
+ self.b5 = block(in_features=5 * filters, non_linearity=False)
24
+ self.blocks = [self.b1, self.b2, self.b3, self.b4, self.b5]
25
+
26
+ def forward(self, x):
27
+ inputs = x
28
+ for block in self.blocks:
29
+ out = block(inputs)
30
+ inputs = torch.cat([inputs, out], 1)
31
+ return out.mul(self.res_scale) + x
32
+
33
+
34
+ class ResidualInResidualDenseBlock(nn.Module):
35
+ def __init__(self, filters, res_scale=0.2):
36
+ super(ResidualInResidualDenseBlock, self).__init__()
37
+ self.res_scale = res_scale
38
+ self.dense_blocks = nn.Sequential(
39
+ DenseResidualBlock(filters), DenseResidualBlock(filters), DenseResidualBlock(filters)
40
+ )
41
+
42
+ def forward(self, x):
43
+ return self.dense_blocks(x).mul(self.res_scale) + x
44
+
45
+
46
+ class GeneratorRRDB(nn.Module):
47
+ def __init__(self, channels, filters=64, num_res_blocks=16, num_upsample=2):
48
+ super(GeneratorRRDB, self).__init__()
49
+
50
+ # First layer
51
+ self.conv1 = nn.Conv2d(channels, filters, kernel_size=3, stride=1, padding=1)
52
+ # Residual blocks
53
+ self.res_blocks = nn.Sequential(*[ResidualInResidualDenseBlock(filters) for _ in range(num_res_blocks)])
54
+ # Second conv layer post residual blocks
55
+ self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)
56
+ # Upsampling layers
57
+ upsample_layers = []
58
+ for _ in range(num_upsample):
59
+ upsample_layers += [
60
+ nn.Conv2d(filters, filters * 4, kernel_size=3, stride=1, padding=1),
61
+ nn.LeakyReLU(),
62
+ nn.PixelShuffle(upscale_factor=2),
63
+ ]
64
+ self.upsampling = nn.Sequential(*upsample_layers)
65
+ # Final output block
66
+ self.conv3 = nn.Sequential(
67
+ nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1),
68
+ nn.LeakyReLU(),
69
+ nn.Conv2d(filters, channels, kernel_size=3, stride=1, padding=1),
70
+ )
71
+
72
+ def forward(self, x):
73
+ out1 = self.conv1(x)
74
+ out = self.res_blocks(out1)
75
+ out2 = self.conv2(out)
76
+ out = torch.add(out1, out2)
77
+ out = self.upsampling(out)
78
+ out = self.conv3(out)
79
+ return out
Utils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import torchvision.utils as vutils
4
+ import torchvision.transforms as transforms
5
+ from skimage.exposure import match_histograms
6
+ import torch
7
+
8
+ # contains utility functions that we need in the main program
9
+
10
+ # matches the color histogram of original and the super resolution output
11
+ def color_histogram_mapping(images, references):
12
+ matched_list = []
13
+ for i in range(len(images)):
14
+ matched = match_histograms(images[i].permute(1, 2, 0).numpy(), references[i].permute(1, 2, 0).numpy(),
15
+ channel_axis=-1)
16
+ matched_list.append(matched)
17
+ return torch.tensor(np.array(matched_list)).permute(0, 3, 1, 2)
18
+
19
+
20
+ def visualize_generations(seed, images):
21
+ plt.figure(figsize=(16, 16))
22
+ plt.title(f"Seed: {seed}")
23
+ plt.axis("off")
24
+ plt.imshow(np.transpose(vutils.make_grid(images, padding=2, nrow=5, normalize=True), (2, 1, 0)))
25
+ plt.show()
26
+
27
+
28
+ # denormalize the images for proper display
29
+ def denormalize_images(images):
30
+ mean= [0.5, 0.5, 0.5]
31
+ std= [0.5, 0.5, 0.5]
32
+ inv_normalize = transforms.Normalize(
33
+ mean=[-m / s for m, s in zip(mean, std)],
34
+ std=[1 / s for s in std]
35
+ )
36
+ return inv_normalize(images)
37
+
38
+
39
+
40
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ matplotlib==3.5.2
2
+ numpy==1.23.0
3
+ torch==1.12.0
4
+ torchvision==0.13.0
5
+ scikit-image~=0.19.3
6
+ streamlit==1.11.0