import os
import math
from decimal import Decimal

import utility

import torch
import torch.nn.utils as utils
from tqdm import tqdm
import torch.cuda.amp as amp
from torch.utils.tensorboard import SummaryWriter
import torchvision
import numpy as np


class Trainer():
    """Drives training and evaluation of a super-resolution model.

    Wires together the data loaders, model, loss, and optimizer provided by
    the caller, runs mixed-precision training with gradient scaling, and
    (when ``args.recurrence > 1``) logs per-recurrence outputs and losses to
    TensorBoard.
    """

    def __init__(self, args, loader, my_model, my_loss, ckp):
        """Store collaborators and build the optimizer.

        Args:
            args: parsed command-line options (scale, amp, recurrence, ...).
            loader: object exposing ``loader_train`` and ``loader_test``.
            my_model: the network; called as ``model(lr, idx_scale)``.
            my_loss: loss module exposing ``step/start_log/end_log/buffer``.
            ckp: checkpoint helper used for logging and saving results.
        """
        self.args = args
        self.scale = args.scale
        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)

        # Resume optimizer state when continuing from a saved run.
        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        self.error_last = 1e8
        # Gradient scaler is a no-op when AMP is disabled.
        self.scaler = amp.GradScaler(enabled=args.amp)
        # NOTE: attribute name "writter" (sic) kept for backward compatibility.
        self.writter = None
        self.recurrence = args.recurrence
        # TensorBoard logging is only useful for recurrent (multi-output) models.
        if args.recurrence > 1:
            self.writter = SummaryWriter(f"runs/{args.save}")

    def train(self):
        """Run one training epoch over ``loader_train``."""
        self.loss.step()
        epoch = self.optimizer.get_last_epoch() + 1
        # Renamed from ``lr``: the loop below reuses ``lr`` for the
        # low-resolution batch, which used to shadow the learning rate.
        learning_rate = self.optimizer.get_lr()

        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(
                epoch, Decimal(learning_rate)
            )
        )
        self.loss.start_log()
        self.model.train()

        timer_data, timer_model = utility.timer(), utility.timer()
        # TEMP: train on the first scale only.
        self.loader_train.dataset.set_scale(0)
        total = len(self.loader_train)
        # Per-recurrence running loss sums, averaged for TensorBoard below.
        buffer = [0.0] * self.recurrence
        for batch, (lr, hr, _,) in enumerate(self.loader_train):
            lr, hr = self.prepare(lr, hr)
            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            with amp.autocast(enabled=self.args.amp):
                sr = self.model(lr, 0)
                # Type check must come first: ``len`` on a 0-dim tensor raises.
                if isinstance(sr, list) and len(sr) == 1:
                    sr = sr[0]
                loss = self.loss(sr, hr)
            self.scaler.scale(loss).backward()
            if self.args.gclip > 0:
                # Gradients must be unscaled before value clipping so the
                # clip threshold applies to true gradient magnitudes.
                self.scaler.unscale_(self.optimizer)
                utils.clip_grad_value_(
                    self.model.parameters(),
                    self.args.gclip
                )
            self.scaler.step(self.optimizer)
            self.scaler.update()
            for i in range(self.recurrence):
                buffer[i] += self.loss.buffer[i]

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    self.loss.display_loss(batch),
                    timer_model.release(),
                    timer_data.release()))

            timer_data.tic()

        # Epoch-level TensorBoard logging: last batch's images plus the
        # epoch-averaged loss per recurrence step.
        if self.writter:
            for i in range(self.recurrence):
                grid = torchvision.utils.make_grid(sr[i])
                self.writter.add_image(f"Output{i}", grid, epoch)
                self.writter.add_scalar(f"Loss{i}", buffer[i] / total, epoch)
            self.writter.add_image(
                "Input", torchvision.utils.make_grid(lr), epoch
            )
            self.writter.add_image(
                "Target", torchvision.utils.make_grid(hr), epoch
            )

        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
        self.optimizer.schedule()

    def test(self):
        """Evaluate on every test set and scale, logging PSNR per pair."""
        torch.set_grad_enabled(False)

        epoch = self.optimizer.get_last_epoch()
        self.ckp.write_log('\nEvaluation:')
        # One log row per epoch: [1, n_datasets, n_scales].
        self.ckp.add_log(
            torch.zeros(1, len(self.loader_test), len(self.scale))
        )
        self.model.eval()

        timer_test = utility.timer()
        if self.args.save_results:
            self.ckp.begin_background()
        for idx_data, d in enumerate(self.loader_test):
            for idx_scale, scale in enumerate(self.scale):
                d.dataset.set_scale(idx_scale)
                for lr, hr, filename in tqdm(d, ncols=80):
                    lr, hr = self.prepare(lr, hr)
                    with amp.autocast(enabled=self.args.amp):
                        sr = self.model(lr, idx_scale)
                    # Recurrent models return a list; score the final output.
                    if isinstance(sr, list):
                        sr = sr[-1]
                    sr = utility.quantize(sr, self.args.rgb_range)

                    save_list = [sr]
                    self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
                        sr, hr, scale, self.args.rgb_range, dataset=d
                    )
                    if self.args.save_gt:
                        save_list.extend([lr, hr])

                    if self.args.save_results:
                        self.ckp.save_results(d, filename[0], save_list, scale)

                self.ckp.log[-1, idx_data, idx_scale] /= len(d)
                best = self.ckp.log.max(0)
                self.ckp.write_log(
                    '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                        d.dataset.name,
                        scale,
                        self.ckp.log[-1, idx_data, idx_scale],
                        best[0][idx_data, idx_scale],
                        best[1][idx_data, idx_scale] + 1
                    )
                )

        self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
        self.ckp.write_log('Saving...')

        if self.args.save_results:
            self.ckp.end_background()

        if not self.args.test_only:
            # Best-epoch test uses the first dataset/scale pair.
            self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))

        self.ckp.write_log(
            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )

        torch.set_grad_enabled(True)

    def prepare(self, *args):
        """Move tensors to the target device, halving precision if requested.

        Returns:
            list: the input tensors, converted and device-placed, in order.
        """
        device = torch.device('cpu' if self.args.cpu else 'cuda')

        def _prepare(tensor):
            if self.args.precision == 'half':
                tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(a) for a in args]

    def terminate(self):
        """Return True when the run should stop.

        In ``test_only`` mode this also performs the evaluation pass.
        """
        if self.args.test_only:
            self.test()
            return True
        else:
            epoch = self.optimizer.get_last_epoch() + 1
            return epoch >= self.args.epochs