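"""PyTorch-Lightning module that trains NormalNet to regress front ("F") and
back ("B") normal maps. Each sub-network (netF, netB) gets its own Adam
optimizer and MultiStepLR scheduler, stepped manually in training_step."""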
from lib.net import NormalNet
# train_util supplies the helpers used below
# (export_cfg, batch_mean, tf_log_convert, bar_log_convert).
from lib.common.train_util import *
import logging
import torch
import numpy as np
from torch import nn
from skimage.transform import resize
import pytorch_lightning as pl

# Let cuDNN auto-tune convolution algorithms (fastest for fixed-size inputs).
torch.backends.cudnn.benchmark = True

# Keep Lightning's own logger quiet.
logging.getLogger("lightning").setLevel(logging.ERROR)

class Normal(pl.LightningModule):

    def __init__(self, cfg):
        super(Normal, self).__init__()
        self.cfg = cfg
        self.batch_size = self.cfg.batch_size
        self.lr_N = self.cfg.lr_N

        # Both optimizers are stepped by hand in training_step, so
        # Lightning's automatic optimization must be switched off.
        self.automatic_optimization = False

        self.schedulers = []

        self.netG = NormalNet(self.cfg, error_term=nn.SmoothL1Loss())

        # Names of the network's input tensors, taken from cfg.net.in_nml.
        self.in_nml = [item[0] for item in cfg.net.in_nml]

    def get_progress_bar_dict(self):
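        """Drop Lightning's version-number entry from the progress bar."""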
        tqdm_dict = super().get_progress_bar_dict()
        if "v_num" in tqdm_dict:
            del tqdm_dict["v_num"]
        return tqdm_dict

    def configure_optimizers(self):
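        """Build one Adam optimizer and one MultiStepLR scheduler per
        sub-network (netF for front normals, netB for back normals)."""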
        weight_decay = self.cfg.weight_decay

        optim_params_N_F = [
            {"params": self.netG.netF.parameters(), "lr": self.lr_N}]
        optim_params_N_B = [
            {"params": self.netG.netB.parameters(), "lr": self.lr_N}]

        optimizer_N_F = torch.optim.Adam(
            optim_params_N_F, lr=self.lr_N, weight_decay=weight_decay
        )
        optimizer_N_B = torch.optim.Adam(
            optim_params_N_B, lr=self.lr_N, weight_decay=weight_decay
        )

        # Decay both learning rates by cfg.gamma at the epochs in cfg.schedule.
        scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )
        scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )

        self.schedulers = [scheduler_N_F, scheduler_N_B]
        optims = [optimizer_N_F, optimizer_N_B]

        return optims, self.schedulers

    def render_func(self, render_tensor):
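        """Tile every tensor in render_tensor into one side-by-side image.

        Each entry is a (1, C, H, W) tensor in [-1, 1]; it is mapped to
        [0, 1], converted to HWC, resized to a square, and all results are
        concatenated horizontally for TensorBoard."""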
        height = render_tensor["image"].shape[2]
        result_list = []

        for name in render_tensor.keys():
            result_list.append(
                resize(
                    ((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(1, 2, 0),
                    (height, height),
                    anti_aliasing=True,
                )
            )
        result_array = np.concatenate(result_list, axis=1)

        return result_array

    # With manual optimization, Lightning does not pass an optimizer index.
    def training_step(self, batch, batch_idx):
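        """Run one manual optimization step over both normal-map heads."""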
        # Save the run's config through the logger (helper from train_util).
        export_cfg(self.logger, self.cfg)

        # Gather the network inputs and the ground-truth normal maps.
        in_tensor = {}
        for name in self.in_nml:
            in_tensor[name] = batch[name]

        FB_tensor = {"normal_F": batch["normal_F"],
                     "normal_B": batch["normal_B"]}

        self.netG.train()

        preds_F, preds_B = self.netG(in_tensor)
        error_NF, error_NB = self.netG.get_norm_error(
            preds_F, preds_B, FB_tensor)

        # Manual optimization: back-propagate and step each head separately.
        opt_nf, opt_nb = self.optimizers()

        opt_nf.zero_grad()
        opt_nb.zero_grad()

        self.manual_backward(error_NF)
        self.manual_backward(error_NB)

        opt_nf.step()
        opt_nb.step()

        # Periodically render the inputs and predictions to TensorBoard.
        if batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0:

            self.netG.eval()
            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
                result_array = self.render_func(in_tensor)

            self.logger.experiment.add_image(
                tag=f"Normal-train/{self.global_step}",
                img_tensor=result_array.transpose(2, 0, 1),
                global_step=self.global_step,
            )

        metrics_log = {
            "train_loss-NF": error_NF.item(),
            "train_loss-NB": error_NB.item(),
        }

        tf_log = tf_log_convert(metrics_log)
        bar_log = bar_log_convert(metrics_log)

        return {
            "loss": error_NF + error_NB,
            "loss-NF": error_NF,
            "loss-NB": error_NB,
            "log": tf_log,
            "progress_bar": bar_log,
        }

    def training_epoch_end(self, outputs):
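        """Log epoch-mean training losses and the current learning rates."""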
        # outputs may arrive nested (one list per optimizer); keep the first.
        if [] in outputs:
            outputs = outputs[0]

        # Under manual optimization, Lightning does not step the schedulers,
        # so advance them once per epoch here.
        for scheduler in self.schedulers:
            scheduler.step()

        metrics_log = {
            "train_avgloss": batch_mean(outputs, "loss"),
            "train_avgloss-NF": batch_mean(outputs, "loss-NF"),
            "train_avgloss-NB": batch_mean(outputs, "loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)

        tf_log["lr-NF"] = self.schedulers[0].get_last_lr()[0]
        tf_log["lr-NB"] = self.schedulers[1].get_last_lr()[0]

        return {"log": tf_log}

    def validation_step(self, batch, batch_idx):
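        """Compute validation losses and periodically log rendered results."""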
        in_tensor = {}
        for name in self.in_nml:
            in_tensor[name] = batch[name]

        FB_tensor = {"normal_F": batch["normal_F"],
                     "normal_B": batch["normal_B"]}

        # The network stays in train mode here (as in training_step), so
        # BatchNorm and Dropout layers behave as they do during training.
        self.netG.train()

        preds_F, preds_B = self.netG(in_tensor)
        error_NF, error_NB = self.netG.get_norm_error(
            preds_F, preds_B, FB_tensor)

        # Render on the first batch and every freq_show_train batches.
        if batch_idx % int(self.cfg.freq_show_train) == 0:

            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
                result_array = self.render_func(in_tensor)

            self.logger.experiment.add_image(
                tag=f"Normal-val/{self.global_step}",
                img_tensor=result_array.transpose(2, 0, 1),
                global_step=self.global_step,
            )

        return {
            "val_loss": error_NF + error_NB,
            "val_loss-NF": error_NF,
            "val_loss-NB": error_NB,
        }

    def validation_epoch_end(self, outputs):
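        """Log epoch-mean validation losses."""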
        metrics_log = {
            "val_avgloss": batch_mean(outputs, "val_loss"),
            "val_avgloss-NF": batch_mean(outputs, "val_loss-NF"),
            "val_avgloss-NB": batch_mean(outputs, "val_loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)

        return {"log": tf_log}
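
# A minimal usage sketch (not part of the module): how this LightningModule is
# typically driven. `load_cfg` and `NormalDataModule` are hypothetical
# placeholders for the project's config loader and datamodule; the cfg must
# provide the fields read above (batch_size, lr_N, weight_decay, schedule,
# gamma, freq_show_train, net.in_nml).
#
#   cfg = load_cfg("configs/normal.yaml")                 # hypothetical loader
#   model = Normal(cfg)
#   trainer = pl.Trainer(gpus=1, max_epochs=20)
#   trainer.fit(model, datamodule=NormalDataModule(cfg))  # hypothetical data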