import datetime
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import BatchNorm2d, Conv2d, ConvTranspose2d, LeakyReLU, ReLU
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
from tensorboardX import SummaryWriter
from accelerate import Accelerator, notebook_launcher

accelerator = Accelerator()
device = accelerator.device
def get_transforms():
    transform = transforms.Compose([
        transforms.ToTensor(),  # H,W,C -> C,H,W and [0,255] -> [0,1]
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # [0,1] -> [-1,1]
    ])
    return transform
class MYDataSet(Dataset):
    def __init__(self, src_data_path, dst_data_path):
        self.train_A_imglist = self.get_imglist(src_data_path)
        self.train_B_imglist = self.get_imglist(dst_data_path)
        self.transform = get_transforms()

    def get_imglist(self, img_dir):
        img_name_list = sorted(os.listdir(img_dir))
        # Keep only the first third of the (sorted) images to shrink the dataset.
        subset_len = len(img_name_list) // 3
        img_list = []
        for img_name in img_name_list[:subset_len]:
            img_path = os.path.join(img_dir, img_name)
            img_list.append(img_path)
        return img_list

    def __len__(self):
        return len(self.train_A_imglist)

    def __getitem__(self, index):
        train_A_img_path = self.train_A_imglist[index]
        train_B_img_path = self.train_B_imglist[index]
        # NOTE: cv2.imread returns images in BGR channel order.
        train_A_img = cv2.imread(train_A_img_path)
        train_B_img = cv2.imread(train_B_img_path)
        train_A_tensor = self.transform(train_A_img)
        train_B_tensor = self.transform(train_B_img)
        return [train_A_tensor, train_B_tensor]
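# A minimal sanity-check sketch for the dataset (the directory paths below are
# placeholders, not part of the original script): build the dataset, pull one
# batch through a DataLoader, and confirm the paired tensors line up.
#
# dataset = MYDataSet(src_data_path='path/to/trainA', dst_data_path='path/to/trainB')
# loader = DataLoader(dataset, batch_size=4, shuffle=False)
# batch_A, batch_B = next(iter(loader))
# print(batch_A.shape, batch_B.shape)  # expected: [4, 3, H, W] each, values in [-1, 1]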
def encoder_layer(in_channels, out_channels, kernel_size=4, stride=2, padding=1):
    # NOTE: padding here differs from the 'valid' padding used by the
    # TensorFlow version in the original GitHub repository.
    layer = nn.Sequential(
        Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
        BatchNorm2d(out_channels),
        LeakyReLU(0.2),
    )
    return layer
def decoder_layer(in_channels, out_channels, last_layer=False, kernel_size=4, stride=2, padding=1):
    # The final layer uses Tanh to map outputs into [-1, 1]; all other layers use ReLU.
    if not last_layer:
        layer = nn.Sequential(
            ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            ReLU(),
        )
    else:
        layer = nn.Sequential(
            ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.Tanh(),
        )
    return layer
def discrimiter_layer(in_channels, out_channels, kernel_size=4, stride=2, padding=1, wgan=False):
    if wgan:
        layer = nn.Sequential(
            Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            BatchNorm2d(out_channels),
            LeakyReLU(0.2),
        )
    else:
        layer = nn.Sequential(
            Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            LeakyReLU(0.2),
        )
    return layer
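# A quick shape sketch (illustrative only): each encoder_layer halves the
# spatial resolution with its stride-2 4x4 convolution, and each non-final
# decoder_layer doubles it again with the matching transposed convolution.
#
# x = torch.randn(2, 3, 256, 256)
# h = encoder_layer(3, 64)(x)
# print(h.shape)                        # torch.Size([2, 64, 128, 128])
# print(decoder_layer(64, 3)(h).shape)  # torch.Size([2, 3, 256, 256])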
class GeneratorNet(torch.nn.Module):
    # U-Net-style generator: eight stride-2 encoder blocks followed by eight
    # upsampling decoder blocks with skip connections from the encoder.
    def __init__(self):
        super(GeneratorNet, self).__init__()
        # Encoder
        self.enc_conv1 = encoder_layer(3, 64)
        self.enc_conv2 = encoder_layer(64, 128)
        self.enc_conv3 = encoder_layer(128, 256)
        self.enc_conv4 = encoder_layer(256, 512)
        self.enc_conv5 = encoder_layer(512, 512)
        self.enc_conv6 = encoder_layer(512, 512)
        self.enc_conv7 = encoder_layer(512, 512)
        self.enc_conv8 = encoder_layer(512, 512, padding=1)
        # Decoder: input channels double after the first block because each
        # stage consumes the skip-connection concatenation built in forward().
        self.dec_conv1 = decoder_layer(512, 512)
        self.dec_conv2 = decoder_layer(1024, 512)
        self.dec_conv3 = decoder_layer(1024, 512)
        self.dec_conv4 = decoder_layer(1024, 512)
        self.dec_conv5 = decoder_layer(1024, 256)
        self.dec_conv6 = decoder_layer(512, 128)
        self.dec_conv7 = decoder_layer(256, 64)
        self.dec_conv8 = decoder_layer(128, 3, last_layer=True)

    def forward(self, input_x):
        # Encoder
        output_enc_conv1 = self.enc_conv1(input_x)
        output_enc_conv2 = self.enc_conv2(output_enc_conv1)
        output_enc_conv3 = self.enc_conv3(output_enc_conv2)
        output_enc_conv4 = self.enc_conv4(output_enc_conv3)
        output_enc_conv5 = self.enc_conv5(output_enc_conv4)
        output_enc_conv6 = self.enc_conv6(output_enc_conv5)
        output_enc_conv7 = self.enc_conv7(output_enc_conv6)
        output_enc_conv8 = self.enc_conv8(output_enc_conv7)
        # Decoder: upsample, then concatenate the matching encoder feature map
        # along the channel dimension (skip connection)
        output_dec_conv1 = self.dec_conv1(output_enc_conv8)
        output_dec_conv1 = torch.cat([output_dec_conv1, output_enc_conv7], dim=1)
        output_dec_conv2 = self.dec_conv2(output_dec_conv1)
        output_dec_conv2 = torch.cat([output_dec_conv2, output_enc_conv6], dim=1)
        output_dec_conv3 = self.dec_conv3(output_dec_conv2)
        output_dec_conv3 = torch.cat([output_dec_conv3, output_enc_conv5], dim=1)
        output_dec_conv4 = self.dec_conv4(output_dec_conv3)
        output_dec_conv4 = torch.cat([output_dec_conv4, output_enc_conv4], dim=1)
        output_dec_conv5 = self.dec_conv5(output_dec_conv4)
        output_dec_conv5 = torch.cat([output_dec_conv5, output_enc_conv3], dim=1)
        output_dec_conv6 = self.dec_conv6(output_dec_conv5)
        output_dec_conv6 = torch.cat([output_dec_conv6, output_enc_conv2], dim=1)
        output_dec_conv7 = self.dec_conv7(output_dec_conv6)
        output_dec_conv7 = torch.cat([output_dec_conv7, output_enc_conv1], dim=1)
        output_dec_conv8 = self.dec_conv8(output_dec_conv7)
        return output_dec_conv8
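# A minimal forward-pass check for the generator (illustrative; assumes a
# 256x256 input so all eight stride-2 stages divide cleanly down to 1x1, and a
# batch size of at least 2 so BatchNorm can compute batch statistics):
#
# netG = GeneratorNet()
# with torch.no_grad():
#     out = netG(torch.randn(2, 3, 256, 256))
# print(out.shape)  # expected: torch.Size([2, 3, 256, 256]), values in [-1, 1] from Tanh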
class DiscrimiterNet(torch.nn.Module):
    def __init__(self, wgan_loss):
        super(DiscrimiterNet, self).__init__()
        self.wgan_loss = wgan_loss
        # The flag must be passed as the keyword `wgan`; passed positionally
        # it would land in `kernel_size` instead.
        self.conv1 = discrimiter_layer(3, 64, wgan=self.wgan_loss)
        self.conv2 = discrimiter_layer(64, 128, wgan=self.wgan_loss)
        self.conv3 = discrimiter_layer(128, 256, wgan=self.wgan_loss)
        self.conv4 = discrimiter_layer(256, 512, wgan=self.wgan_loss)
        # Final 1x1 conv scores each spatial patch; padding=0 so no scores are
        # computed on pure-padding pixels.
        self.conv5 = discrimiter_layer(512, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x
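# Discriminator shape sketch (illustrative; assumes a 256x256 input and
# padding=0 on the final 1x1 conv as above): four stride-2 blocks reduce the
# input to a 16x16 map of per-patch critic scores.
#
# netD = DiscrimiterNet(wgan_loss=True)
# with torch.no_grad():
#     scores = netD(torch.randn(2, 3, 256, 256))
# print(scores.shape)  # expected: torch.Size([2, 1, 16, 16])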
def tensor2img(one_tensor):  # [b,c,h,w], values in [-1,1]
    tensor = one_tensor.squeeze(0)        # [c,h,w], still [-1,1]
    tensor = (tensor * 0.5 + 0.5) * 255   # [c,h,w], [0,255]
    tensor_cpu = tensor.cpu()
    img = np.array(tensor_cpu, dtype=np.uint8)
    img = np.transpose(img, (1, 2, 0))    # [h,w,c]
    return img

def img2tensor(np_img):  # [h,w,c]
    # NOTE: the hard-coded .cuda() means this helper requires a GPU as written.
    tensor = get_transforms()(np_img).cuda()  # [c,h,w], [-1,1]
    tensor = tensor.unsqueeze(0)              # [b,c,h,w], [-1,1]
    return tensor
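# Usage sketch for the converters (illustrative; img2tensor needs CUDA as
# written, so only tensor2img is exercised here):
#
# fake_batch = torch.rand(1, 3, 8, 8) * 2 - 1   # values in [-1, 1]
# img = tensor2img(fake_batch)
# print(img.shape, img.dtype)  # expected: (8, 8, 3) uint8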
def weights_init(module):
    # DCGAN-style initialization: N(0, 0.02) for all conv layers.
    # (Defined here but not applied anywhere in train() below.)
    classname = module.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(module.weight.data, 0.0, 0.02)
def loss_gradient_difference(real_image, generated):  # b x c x h x w
    # Gradient difference loss: L2 distance between the absolute finite
    # differences of the real and generated images along each spatial axis.
    true_x_shifted_right = real_image[:, :, 1:, :]   # b x c x (h-1) x w
    true_x_shifted_left = real_image[:, :, :-1, :]
    true_x_gradient = torch.abs(true_x_shifted_left - true_x_shifted_right)
    generated_x_shifted_right = generated[:, :, 1:, :]
    generated_x_shifted_left = generated[:, :, :-1, :]
    generated_x_gradient = torch.abs(generated_x_shifted_left - generated_x_shifted_right)
    difference_x = true_x_gradient - generated_x_gradient
    loss_x_gradient = (torch.sum(difference_x ** 2)) / 2  # matches tf.nn.l2_loss
    true_y_shifted_right = real_image[:, :, :, 1:]
    true_y_shifted_left = real_image[:, :, :, :-1]
    true_y_gradient = torch.abs(true_y_shifted_left - true_y_shifted_right)
    generated_y_shifted_right = generated[:, :, :, 1:]
    generated_y_shifted_left = generated[:, :, :, :-1]
    generated_y_gradient = torch.abs(generated_y_shifted_left - generated_y_shifted_right)
    difference_y = true_y_gradient - generated_y_gradient
    loss_y_gradient = (torch.sum(difference_y ** 2)) / 2  # matches tf.nn.l2_loss
    igdl = loss_x_gradient + loss_y_gradient
    return igdl
def calculate_x_gradient(images):
    # Per-channel horizontal central-difference filter (groups=3 keeps the
    # RGB channels independent). NOTE: the hard-coded .cuda() requires a GPU.
    x_gradient_filter = torch.Tensor(
        [
            [[0, 0, 0], [-1, 0, 1], [0, 0, 0]],
            [[0, 0, 0], [-1, 0, 1], [0, 0, 0]],
            [[0, 0, 0], [-1, 0, 1], [0, 0, 0]],
        ]
    ).cuda()
    x_gradient_filter = x_gradient_filter.view(3, 1, 3, 3)
    result = torch.functional.F.conv2d(
        images, x_gradient_filter, groups=3, padding=(1, 1)
    )
    return result

def calculate_y_gradient(images):
    # Per-channel vertical central-difference filter.
    y_gradient_filter = torch.Tensor(
        [
            [[0, 1, 0], [0, 0, 0], [0, -1, 0]],
            [[0, 1, 0], [0, 0, 0], [0, -1, 0]],
            [[0, 1, 0], [0, 0, 0], [0, -1, 0]],
        ]
    ).cuda()
    y_gradient_filter = y_gradient_filter.view(3, 1, 3, 3)
    result = torch.functional.F.conv2d(
        images, y_gradient_filter, groups=3, padding=(1, 1)
    )
    return result
def loss_igdl(correct_images, generated_images):
    # Adapted from https://github.com/Arquestro/ugan-pytorch/blob/master/ops/loss_modules.py
    correct_images_gradient_x = calculate_x_gradient(correct_images)
    generated_images_gradient_x = calculate_x_gradient(generated_images)
    correct_images_gradient_y = calculate_y_gradient(correct_images)
    generated_images_gradient_y = calculate_y_gradient(generated_images)
    pairwise_p_distance = torch.nn.PairwiseDistance(p=1)
    distances_x_gradient = pairwise_p_distance(
        correct_images_gradient_x, generated_images_gradient_x
    )
    distances_y_gradient = pairwise_p_distance(
        correct_images_gradient_y, generated_images_gradient_y
    )
    loss_x_gradient = torch.mean(distances_x_gradient)
    loss_y_gradient = torch.mean(distances_y_gradient)
    loss = 0.5 * (loss_x_gradient + loss_y_gradient)
    return loss
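# Sanity-check sketch for the IGDL loss (illustrative; requires CUDA because
# the gradient filters are created with .cuda()): identical inputs have
# identical gradients, so the loss should be exactly zero.
#
# if torch.cuda.is_available():
#     imgs = torch.rand(2, 3, 32, 32).cuda()
#     print(loss_igdl(imgs, imgs).item())  # expected: 0.0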
def ToTensor(image):
    """Convert ndarrays in sample to Tensors."""
    # numpy image: H x W x C -> torch image: C x H x W
    # (Helper kept from the original; not used by the training loop below.)
    image = image.transpose((2, 0, 1))
    # Normalize image from [0, 255] to [0, 1]
    image = 1 / 255.0 * image
    return torch.from_numpy(image).type(dtype=torch.float)
def train():
    accelerator = Accelerator()
    device = accelerator.device
    print(device)

    # Hyperparameters (formerly argparse arguments)
    trainA_path = '/kaggle/input/underwater-img/underwater_imagenet/trainA'
    trainB_path = '/kaggle/input/underwater-img/underwater_imagenet/trainB'
    use_wgan = True
    lr = 1e-4
    max_epoch = 350
    bz = 32
    lbda1 = 100  # weight of the L1 loss
    lbda2 = 1    # weight of the IGDL loss
    num_workers = 4
    checkpoints_root = 'checkpoints'
    log_root = './log'

    wgan = use_wgan
    learning_rate = lr
    max_epochs = max_epoch
    batch_size = bz
    lambda_1 = lbda1
    lambda_2 = lbda2

    # NOTE: wrapping in DataParallel and then calling accelerator.prepare() is
    # redundant; Accelerate handles device placement and multi-GPU itself.
    netG = torch.nn.DataParallel(GeneratorNet()).to(device)
    netD = torch.nn.DataParallel(DiscrimiterNet(wgan_loss=wgan)).to(device)
    optimizer_g = optim.Adam(netG.parameters(), lr=learning_rate)
    optimizer_d = optim.Adam(netD.parameters(), lr=learning_rate)

    dataset = MYDataSet(src_data_path=trainA_path, dst_data_path=trainB_path)
    datasetloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    date = datetime.datetime.now().strftime('%F_%T').replace(':', '_')
    log_dir = os.path.join(log_root, date)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    writer = SummaryWriter(log_dir=log_dir)

    checkpoint_dir = os.path.join(checkpoints_root, date)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    netG, netD, optimizer_g, optimizer_d, datasetloader = accelerator.prepare(
        netG, netD, optimizer_g, optimizer_d, datasetloader
    )

    for epoch in range(0, max_epochs):
        print("epoch:", epoch + 1)
        d_loss_log_list = []
        g_loss_log_list = []
        for iteration, data in enumerate(tqdm(datasetloader)):
            batchtensor_A = data[0].to(device)
            batchtensor_B = data[1].to(device)
            generated_batchtensor = netG(batchtensor_A)

            # Critic update: 5 steps per generator step for WGAN, else 1
            num_critic = 1 if not wgan else 5
            for i in range(num_critic):
                optimizer_d.zero_grad()
                d_fake = netD(generated_batchtensor.detach())
                d_real = netD(batchtensor_B)
                d_loss = torch.mean(d_fake) - torch.mean(d_real)  # WGAN critic loss
                if wgan:
                    # Gradient penalty (WGAN-GP): penalize the critic's gradient
                    # norm at random interpolates between real and fake images.
                    # The fake images are detached so the penalty does not
                    # backpropagate into the generator.
                    lambda_gp = 10
                    epsilon = torch.rand(batchtensor_B.size()[0], 1, 1, 1).to(device)
                    x_hat = (epsilon * batchtensor_B + (1 - epsilon) * generated_batchtensor.detach()).requires_grad_(True)
                    d_hat = netD(x_hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat, inputs=x_hat,
                        grad_outputs=torch.ones_like(d_hat),
                        create_graph=True, retain_graph=True,
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_penalty = lambda_gp * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
                    d_loss += gradient_penalty
                accelerator.backward(d_loss)
                optimizer_d.step()
                d_loss_log = d_loss.item()
                d_loss_log_list.append(d_loss_log)

            # Generator update: adversarial + weighted L1 + weighted IGDL
            optimizer_g.zero_grad()
            d_fake = netD(generated_batchtensor)
            g_loss = -torch.mean(d_fake)
            l1_loss = torch.mean(torch.abs(generated_batchtensor - batchtensor_B))
            igdl_loss = loss_igdl(batchtensor_B, generated_batchtensor)
            g_loss += lambda_1 * l1_loss + lambda_2 * igdl_loss
            accelerator.backward(g_loss)
            optimizer_g.step()
            g_loss_log = g_loss.item()
            g_loss_log_list.append(g_loss_log)

            writer.add_scalar('G_loss', g_loss_log, epoch * len(datasetloader) + iteration)
            writer.add_scalar('D_loss', d_loss_log, epoch * len(datasetloader) + iteration)
        d_loss_average_log = np.array(d_loss_log_list).mean()
        g_loss_average_log = np.array(g_loss_log_list).mean()
        writer.add_scalar('D_loss_epoch', d_loss_average_log, epoch)
        writer.add_scalar('G_loss_epoch', g_loss_average_log, epoch)
        print('D_loss_epoch', d_loss_average_log)
        print('G_loss_epoch', g_loss_average_log)

        if (epoch + 1) % 50 == 0:
            # Save into the timestamped checkpoint directory created above
            torch.save({
                'epoch': epoch,
                'netG_state_dict': netG.state_dict(),
                'netD_state_dict': netD.state_dict(),
                'optimizer_g_state_dict': optimizer_g.state_dict(),
                'optimizer_d_state_dict': optimizer_d.state_dict(),
                'g_loss': g_loss_log_list[-1],
                'd_loss': d_loss_log_list[-1],
            }, os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch + 1}.pth"))

    writer.close()

if __name__ == "__main__":
    notebook_launcher(train, num_processes=1)
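# Inference sketch (illustrative, not part of the original script): load a
# saved checkpoint and enhance a single image on CPU. The checkpoint and image
# paths are placeholders; the DataParallel wrapper is recreated so the
# 'module.'-prefixed state-dict keys match.
#
# ckpt = torch.load('checkpoints/<run>/checkpoint_epoch_350.pth', map_location='cpu')
# netG = torch.nn.DataParallel(GeneratorNet())
# netG.load_state_dict(ckpt['netG_state_dict'])
# netG.eval()
# img = cv2.imread('some_underwater_image.png')   # H x W x C, BGR
# img = cv2.resize(img, (256, 256))               # match the training resolution
# x = get_transforms()(img).unsqueeze(0)          # [1, 3, 256, 256] in [-1, 1]
# with torch.no_grad():
#     out = netG(x)
# cv2.imwrite('enhanced.png', tensor2img(out))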