import os
import glob
import time
from pathlib import Path

import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader

# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class UnetBlock(nn.Module):
    """One encoder/decoder level of a U-Net, wrapping the deeper levels.

    The block applies a 4x4 stride-2 convolution (halving H and W for even
    sizes), runs ``submodule`` on the result, then a 4x4 stride-2 transposed
    convolution (doubling H and W). Every block except the outermost
    concatenates its input with its output along dim 1 (the channel axis of
    the (N, C, H, W) tensors that ``nn.Conv2d`` consumes), forming the skip
    connection.

    Args:
        nf: Channels produced by the up path; also the input channel count
            when ``input_c`` is not given.
        ni: Channels produced by the down convolution.
        submodule: Inner block placed between the down and up paths. Its
            output must have ``ni * 2`` channels, because the up convolution
            of a non-innermost block expects that many. Unused when only
            ``innermost`` is set.
        input_c: Overrides the input channel count (``Unet`` uses this for
            its outermost block, whose input is the network input).
        dropout: Append ``nn.Dropout(0.5)`` to the up path. Only honoured by
            middle blocks (neither innermost nor outermost).
        innermost: Bottleneck block: no submodule, and the up path consumes
            the ``ni`` down-path channels directly.
        outermost: Top-level block: no activation before the down conv, no
            normalisation after the up conv, a final ``Tanh``, and no skip
            concatenation. Takes precedence over ``innermost``.
    """

    def __init__(
        self,
        nf,
        ni,
        submodule=None,
        input_c=None,
        dropout=False,
        innermost=False,
        outermost=False,
    ):
        super().__init__()
        self.outermost = outermost

        # Shared geometry for every (transposed) convolution: 4x4 kernel,
        # stride 2, padding 1 -> exact halving / doubling of even H and W.
        geometry = dict(kernel_size=4, stride=2, padding=1)
        in_channels = nf if input_c is None else input_c
        # Construct the down conv before the up conv: with a fixed seed, the
        # random weight initialisation depends on construction order.
        down_conv = nn.Conv2d(in_channels, ni, bias=False, **geometry)

        if outermost:
            # No BatchNorm follows this up conv, so it keeps its bias.
            up_conv = nn.ConvTranspose2d(ni * 2, nf, **geometry)
            layers = [down_conv, submodule, nn.ReLU(True), up_conv, nn.Tanh()]
        elif innermost:
            # Bottleneck: nothing is nested, so the up path sees only ni
            # channels.
            up_conv = nn.ConvTranspose2d(ni, nf, bias=False, **geometry)
            layers = [
                nn.LeakyReLU(0.2, True),
                down_conv,
                nn.ReLU(True),
                up_conv,
                nn.BatchNorm2d(nf),
            ]
        else:
            # A nested UnetBlock built with nf == this block's ni returns
            # ni (its input) + ni (its output) channels, hence ni * 2 here.
            up_conv = nn.ConvTranspose2d(ni * 2, nf, bias=False, **geometry)
            layers = [
                nn.LeakyReLU(0.2, True),
                down_conv,
                nn.BatchNorm2d(ni),
                submodule,
                nn.ReLU(True),
                up_conv,
                nn.BatchNorm2d(nf),
            ]
            if dropout:
                layers.append(nn.Dropout(0.5))

        # The Sequential's numeric child names become the state_dict keys,
        # so the layer order above must stay fixed.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        out = self.model(x)
        if self.outermost:
            return out
        # Skip connection: stack the block's input with its decoded output.
        return torch.cat([x, out], dim=1)


class Unet(nn.Module):
    """U-Net generator assembled from nested ``UnetBlock`` levels.

    Levels are built from the bottleneck outwards:
    - one innermost block;
    - ``n_down - 5`` dropout blocks at ``num_filters * 8`` channels;
    - three blocks whose down-conv width halves at each step outward;
    - an outermost block mapping ``input_c`` input channels to ``output_c``
      output channels.

    With ``n_down >= 5`` the network downsamples ``n_down`` times. Smaller
    values leave the dropout loop empty and behave exactly like
    ``n_down == 5``. The output passes through ``Tanh``, so it lies in
    [-1, 1].

    Args:
        input_c: Channels of the input tensor. Defaults to 1; presumably the
            L channel of a Lab image, given the ``rgb2lab`` import (confirm).
        output_c: Channels of the output tensor. Defaults to 2; presumably
            the ab channels (confirm).
        n_down: Number of downsampling levels.
        num_filters: Width of the outermost down convolution. Inner levels
            use multiples of it, up to ``num_filters * 8``.

    NOTE(review): each level halves H and W, so inputs whose H and W are
    divisible by ``2 ** max(n_down, 5)`` avoid size mismatches at the skip
    concatenations.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        widest = num_filters * 8

        # Bottleneck first; every later block wraps the one built before it.
        block = UnetBlock(widest, widest, innermost=True)
        for _ in range(n_down - 5):
            block = UnetBlock(widest, widest, submodule=block, dropout=True)

        # Three levels that narrow the channel count by half on the way out.
        channels = widest
        for _ in range(3):
            narrower = channels // 2
            block = UnetBlock(narrower, channels, submodule=block)
            channels = narrower

        self.model = UnetBlock(
            output_c, channels, input_c=input_c, submodule=block, outermost=True
        )

    def forward(self, x):
        return self.model(x)