# Captured from a Hugging Face Space page (Space status at capture: "Sleeping").
# Original file: 6,140 bytes, 142 lines, commit 8eec341.
import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from inference import init_model, GANLoss
class UnetBlock(nn.Module):
    """One level of a recursively nested U-Net, wrapping an optional inner block.

    Layer layouts (``x`` -> down path -> submodule -> up path):

    * outermost: Conv2d -> submodule -> ReLU -> ConvTranspose2d -> Tanh.
      The result is returned as-is (no skip concatenation).
    * innermost: LeakyReLU -> Conv2d -> ReLU -> ConvTranspose2d -> BatchNorm2d.
      ``submodule`` is not used.
    * otherwise: LeakyReLU -> Conv2d -> BatchNorm2d -> submodule -> ReLU ->
      ConvTranspose2d -> BatchNorm2d [-> Dropout(0.5)].

    Non-outermost blocks concatenate their input with the up-path output along
    dim 1, so they emit ``input_c + nf`` channels.

    Args:
        nf: channels produced by the up-path transposed convolution; also the
            default input channel count.
        ni: channels produced by the down-path convolution.
        submodule: the next-inner block; ignored when ``innermost`` is set.
        input_c: input channel count; defaults to ``nf``.
        dropout: append Dropout(0.5) to the up path (middle blocks only).
        innermost: build the innermost variant.
        outermost: build the outermost variant; takes precedence over
            ``innermost`` if both are set.

    NOTE: the LeakyReLU heading the down path is in-place. In non-outermost
    blocks ``x`` itself is therefore overwritten, and the skip connection
    carries LeakyReLU(x) rather than the raw input. This is kept deliberately:
    changing it would change the network's outputs.
    """

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        in_ch = nf if input_c is None else input_c

        # Build order matters: the Conv2d is created before the
        # ConvTranspose2d, so random initialization draws in the same order.
        conv_down = nn.Conv2d(in_ch, ni, kernel_size=4, stride=2, padding=1, bias=False)
        act_down = nn.LeakyReLU(0.2, True)
        bn_down = nn.BatchNorm2d(ni)
        act_up = nn.ReLU(True)
        bn_up = nn.BatchNorm2d(nf)

        if outermost:
            conv_up = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4, stride=2, padding=1)
            layers = [conv_down, submodule, act_up, conv_up, nn.Tanh()]
        elif innermost:
            # No submodule, so the up path only sees the ni channels from conv_down.
            conv_up = nn.ConvTranspose2d(ni, nf, kernel_size=4, stride=2, padding=1, bias=False)
            layers = [act_down, conv_down, act_up, conv_up, bn_up]
        else:
            # The submodule returns its input concatenated with its output: ni * 2 channels.
            conv_up = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4, stride=2, padding=1, bias=False)
            layers = [act_down, conv_down, bn_down, submodule, act_up, conv_up, bn_up]
            if dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block; non-outermost blocks prepend ``x`` along dim 1."""
        out = self.model(x)
        if self.outermost:
            return out
        # self.model(x) has already run, so this reads x after the in-place LeakyReLU.
        return torch.cat([x, out], 1)
class Unet(nn.Module):
    """U-Net generator built by nesting ``UnetBlock`` levels from the inside out.

    The network always contains one innermost block, three channel-halving
    blocks and one outermost block. The remaining ``n_down - 5`` levels are
    full-width (``num_filters * 8``) blocks with dropout.

    Every level downsamples with a stride-2, kernel-4, padding-1 convolution,
    so input height and width should be divisible by ``2 ** n_down``. Otherwise
    the skip connections do not line up.

    Args:
        input_c: number of input channels.
        output_c: number of output channels (produced through a final Tanh).
        n_down: number of downsampling levels; must be at least 5.
        num_filters: base filter count; the deepest levels use
            ``num_filters * 8`` channels.

    Raises:
        ValueError: if ``n_down`` is less than 5.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        if n_down < 5:
            # range(n_down - 5) would be empty, and the network would silently
            # keep 5 levels instead of the requested n_down.
            raise ValueError(f"n_down must be at least 5, got {n_down}")
        unet_block = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)
        # Extra full-width bottleneck levels, with dropout.
        for _ in range(n_down - 5):
            unet_block = UnetBlock(num_filters * 8, num_filters * 8,
                                   submodule=unet_block, dropout=True)
        # Three levels that halve the width: 8x -> 4x -> 2x -> 1x num_filters.
        out_filters = num_filters * 8
        for _ in range(3):
            unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block)
            out_filters //= 2
        self.model = UnetBlock(output_c, out_filters, input_c=input_c,
                               submodule=unet_block, outermost=True)

    def forward(self, x):
        """Run the full U-Net on ``x``."""
        return self.model(x)
class PatchDiscriminator(nn.Module):
    """Convolutional discriminator that outputs a map of per-patch scores.

    The stack has three parts:

    * one Conv-LeakyReLU block without normalization;
    * ``n_down`` Conv-BatchNorm-LeakyReLU blocks, each doubling the channel
      count (stride 2, except the last, which uses stride 1);
    * a final 1-channel convolution with neither normalization nor activation.

    The output is therefore raw (un-activated) scores.

    Args:
        input_c: channels of the image being judged.
        num_filters: output channels of the first block.
        n_down: number of channel-doubling blocks.
    """

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        blocks = [self.get_layers(input_c, num_filters, norm=False)]
        for i in range(n_down):
            # The last doubling block uses stride 1 instead of 2.
            stride = 1 if i == n_down - 1 else 2
            blocks.append(
                self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1), s=stride)
            )
        # Output head: a single score channel, no BatchNorm, no activation.
        blocks.append(self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False))
        self.model = nn.Sequential(*blocks)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Build one block: Conv2d, then an optional BatchNorm2d and LeakyReLU(0.2).

        The convolution carries a bias only when no BatchNorm follows it.

        Args:
            ni: input channels.
            nf: output channels.
            k: kernel size.
            s: stride.
            p: padding.
            norm: append BatchNorm2d(nf).
            act: append an in-place LeakyReLU(0.2).
        """
        conv = nn.Conv2d(ni, nf, k, s, p, bias=not norm)
        tail = []
        if norm:
            tail.append(nn.BatchNorm2d(nf))
        if act:
            tail.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(conv, *tail)

    def forward(self, x):
        """Return the patch score map for ``x``."""
        return self.model(x)
class MainModel(nn.Module):
    """Conditional-GAN colorization trainer: generator, discriminator, losses, optimizers.

    ``net_G`` maps the L channel to ab channels. ``net_D`` scores (L, ab)
    pairs concatenated along dim 1. With the default networks, L has 1
    channel, ab has 2, and the discriminator sees 3.

    Args:
        net_G: optional pre-built generator. When None, a fresh ``Unet`` is
            built and passed through ``init_model`` (defined in ``inference``;
            presumably weight init plus device placement — TODO confirm).
            A supplied generator is only moved to the device.
        lr_G: Adam learning rate for the generator.
        lr_D: Adam learning rate for the discriminator.
        beta1: first Adam beta, shared by both optimizers.
        beta2: second Adam beta, shared by both optimizers.
        lambda_L1: weight of the L1 reconstruction term in the generator loss.
    """

    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1
        if net_G is None:
            self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        # Called as GANcriterion(preds, target_is_real); GANLoss comes from `inference`.
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``.

        ``optimize`` uses this to freeze ``net_D`` during the generator step.
        """
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        """Move a batch onto ``self.device``.

        Args:
            data: mapping with the generator input under ``'L'`` and the
                target colour channels under ``'ab'``.
        """
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        """Predict ab channels for the current batch into ``self.fake_color``."""
        self.fake_color = self.net_G(self.L)

    def _with_L(self, ab):
        """Concatenate the current L channel with ``ab`` along dim 1 (discriminator input)."""
        return torch.cat([self.L, ab], dim=1)

    def backward_D(self):
        """Compute and backpropagate the discriminator loss.

        The fake pair is detached, so no gradient reaches ``net_G``. The loss
        is the mean of the fake and real terms.
        """
        fake_preds = self.net_D(self._with_L(self.fake_color).detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_preds = self.net_D(self._with_L(self.ab))
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute and backpropagate the generator loss.

        The loss has two terms:

        * adversarial: ``net_D`` should judge the generated pair as real;
        * reconstruction: L1 distance to ``self.ab``, scaled by ``lambda_L1``.
        """
        fake_preds = self.net_D(self._with_L(self.fake_color))
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        """Run one training iteration: a discriminator step, then a generator step.

        Call ``setup_input`` first. On return, ``net_D``'s parameters have
        ``requires_grad=False``; the next call re-enables them before the D step.
        """
        # Switch the generator to training mode *before* its forward pass, so
        # the pass the losses are built on uses train-mode BatchNorm/Dropout
        # even if net_G was left in eval mode.
        self.net_G.train()
        self.forward()

        # Discriminator update.
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        # Generator update. D is frozen, but gradients still flow through D
        # into fake_color.
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()