|
import os
|
|
from tqdm import tqdm
|
|
|
|
import torch
|
|
|
|
import data.common
|
|
from utils import interact, MultiSaver
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
import torchvision
|
|
|
|
import torch.cuda.amp as amp
|
|
|
|
class Trainer():
    """Drive training, evaluation and checkpointing for a model.

    The model, criterion, optimizer and data loaders are project objects
    supplied by the caller; configuration is read from ``args``. Evaluation
    outputs are written through a ``MultiSaver`` into ``result_dir``. Processes
    that are not slaves (see ``is_slave``) also print progress and log images
    and per-output losses to TensorBoard.
    """

    def __init__(self, args, model, criterion, optimizer, loaders):
        """Store collaborators and configuration, and create output helpers.

        Args:
            args: configuration namespace. Attributes read here: start_epoch,
                save_dir, do_train, do_validate, do_test, device, dtype,
                precision, n_scales, demo, demo_output_dir, launched, rank,
                init_scale, amp.
            model: project model wrapper (train/eval/to/save/load are used).
            criterion: project loss wrapper.
            optimizer: project optimizer wrapper; its ``G`` attribute is handed
                to the AMP grad scaler.
            loaders: mapping from mode name ('train', 'val', 'test', 'demo')
                to an iterable of batches.
        """
        print('===> Initializing trainer')
        self.args = args
        self.mode = 'train'
        self.epoch = args.start_epoch
        self.save_dir = args.save_dir

        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.loaders = loaders

        self.do_train = args.do_train
        self.do_validate = args.do_validate
        self.do_test = args.do_test

        self.device = args.device
        self.dtype = args.dtype
        # Evaluation runs in float32 for 'single' precision, float16 otherwise.
        self.dtype_eval = torch.float32 if args.precision == 'single' else torch.float16
        # Number of per-output loss terms accumulated in train().
        self.recurrence = args.n_scales

        if self.args.demo and self.args.demo_output_dir:
            self.result_dir = self.args.demo_output_dir
        else:
            self.result_dir = os.path.join(self.save_dir, 'result')
        os.makedirs(self.result_dir, exist_ok=True)
        print('results are saved in {}'.format(self.result_dir))

        self.imsaver = MultiSaver(self.result_dir)

        # Non-zero ranks of a launched run skip progress bars, epoch prints and
        # TensorBoard logging.
        self.is_slave = self.args.launched and self.args.rank != 0

        self.scaler = amp.GradScaler(
            init_scale=self.args.init_scale,
            enabled=self.args.amp,
        )
        if not self.is_slave:
            # Attribute name kept as spelled ("writter") for existing callers.
            self.writter = SummaryWriter(f"runs/{args.save_dir}")

    def _progress(self, loader):
        """Wrap ``loader`` in a tqdm bar, or return it unchanged on slave ranks."""
        if self.is_slave:
            return loader
        return tqdm(loader, ncols=80, smoothing=0, bar_format='{desc}|{bar}{r_bar}')

    def save(self, epoch=None):
        """Checkpoint when ``epoch`` is a multiple of ``args.save_every``.

        Model and optimizer are saved only in 'train' mode; the criterion is
        saved in every mode.

        Args:
            epoch: epoch to save under; defaults to ``self.epoch``.
        """
        epoch = self.epoch if epoch is None else epoch
        if epoch % self.args.save_every != 0:
            return
        if self.mode == 'train':
            self.model.save(epoch)
            self.optimizer.save(epoch)
        self.criterion.save()

    def load(self, epoch=None, pretrained=None):
        """Restore model, optimizer and criterion state for ``epoch``.

        Args:
            epoch: epoch to load; defaults to ``args.load_epoch``. Also becomes
                ``self.epoch``.
            pretrained: forwarded unchanged to ``model.load``.
        """
        if epoch is None:
            epoch = self.args.load_epoch
        self.epoch = epoch
        self.model.load(epoch, pretrained)
        self.optimizer.load(epoch)
        self.criterion.load(epoch)

    def _log_to_tensorboard(self, epoch, input, target, output, loss_sums):
        """Log the last batch's images and the epoch's per-output loss sums."""
        rgb_range = self.args.rgb_range
        for i, out in enumerate(output):
            grid = torchvision.utils.make_grid(out)
            self.writter.add_image(f"Output{i}", grid / rgb_range, epoch)
            self.writter.add_scalar(f"Loss{i}", loss_sums[i], epoch)
        self.writter.add_image("Input", torchvision.utils.make_grid(input[0]) / rgb_range, epoch)
        self.writter.add_image("Target", torchvision.utils.make_grid(target[0]) / rgb_range, epoch)

    def train(self, epoch):
        """Run one training epoch with optional AMP and gradient clipping.

        Steps the optimizer once per batch and steps the criterion and the
        learning-rate schedule once per epoch. The non-slave process also logs
        to TensorBoard. Rank 0 checkpoints via ``save(epoch)``.

        Args:
            epoch: current epoch number.
        """
        self.mode = 'train'
        self.epoch = epoch

        self.model.train()
        self.model.to(dtype=self.dtype)

        self.criterion.train()
        self.criterion.epoch = epoch

        if not self.is_slave:
            print('[Epoch {} / lr {:.2e}]'.format(epoch, self.optimizer.get_lr()))

        loader = self.loaders[self.mode]
        if self.args.distributed:
            loader.sampler.set_epoch(epoch)
        tq = self._progress(loader)

        # Running sum over the epoch of each output's loss term (for TensorBoard).
        loss_sums = [0.0] * self.recurrence
        # Stays None when the loader yields no batches; guards the logging below.
        output = None
        torch.set_grad_enabled(True)
        for batch in tq:
            self.optimizer.zero_grad()

            input, target = data.common.to(
                batch[0], batch[1], device=self.device, dtype=self.dtype)

            with amp.autocast(self.args.amp):
                output = self.model(input)
                loss = self.criterion(output, target)

            for i in range(self.recurrence):
                term = self.criterion.buffer[i]
                # If the criterion stores graph-attached tensors, detach them so
                # the epoch-long sum does not keep every iteration's graph alive.
                if torch.is_tensor(term):
                    term = term.detach()
                loss_sums[i] += term

            self.scaler.scale(loss).backward()
            if self.args.clip > 0:
                # Gradients must be unscaled before clipping their norm.
                self.scaler.unscale_(self.optimizer.G)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
            self.scaler.step(self.optimizer.G)
            self.scaler.update()

            if isinstance(tq, tqdm):
                tq.set_description(self.criterion.get_loss_desc())

        if not self.is_slave and output is not None:
            self._log_to_tensorboard(epoch, input, target, output, loss_sums)

        self.criterion.normalize()
        if isinstance(tq, tqdm):
            # Redraw the bar with the description computed after normalize().
            tq.set_description(self.criterion.get_loss_desc())
            tq.display(pos=-1)

        self.criterion.step()
        self.optimizer.schedule(self.criterion.get_last_loss())

        if self.args.rank == 0:
            self.save(epoch)

    def _save_results(self, batch, output, compute_loss):
        """Hand the image(s) to be kept from one evaluated batch to ``imsaver``.

        The last element of a list/tuple output is saved; a tensor output is
        saved as-is. With ``save_results == 'part'`` and ground truth present,
        only samples whose index in ``batch[-2]`` is a multiple of 10 are kept.
        ``batch[-1]`` supplies the file names.
        """
        if isinstance(output, (list, tuple)):
            result = output[-1]
        elif isinstance(output, torch.Tensor):
            result = output
        else:
            raise TypeError('unsupported model output type: {}'.format(type(output).__name__))

        names = batch[-1]
        if self.args.save_results == 'part' and compute_loss:
            indices = batch[-2]
            keep = [pos for pos, sample_idx in enumerate(indices) if sample_idx % 10 == 0]
            result = result[keep]
            names = [names[pos] for pos in keep]

        self.imsaver.save_image(result, names)

    def evaluate(self, epoch, mode='val'):
        """Run the model over ``self.loaders[mode]`` with gradients disabled.

        Losses are computed through the criterion until a batch arrives whose
        target is a ``torch.BoolTensor`` placeholder. From then on, loss
        computation is off for the rest of the pass. Unless
        ``args.save_results == 'none'``, outputs are saved via ``imsaver``.

        Args:
            epoch: epoch number recorded on the criterion.
            mode: loader key. 'val' and 'test' switch the criterion's mode.
                'demo' also crops the loader's padding (``batch[2]``) from the
                saved output.
        """
        self.mode = mode
        self.epoch = epoch

        self.model.eval()
        self.model.to(dtype=self.dtype_eval)

        if mode == 'val':
            self.criterion.validate()
        elif mode == 'test':
            self.criterion.test()
        self.criterion.epoch = epoch

        self.imsaver.join_background()

        tq = self._progress(self.loaders[self.mode])

        compute_loss = True
        torch.set_grad_enabled(False)
        for batch in tq:
            input, target = data.common.to(
                batch[0], batch[1], device=self.device, dtype=self.dtype_eval)
            with amp.autocast(self.args.amp):
                output = self.model(input)

            if mode == 'demo':
                # Crop the padding off output[-1], the element that
                # _save_results writes. Previously output[0] was cropped while
                # output[-1] was saved, so saved demo images kept the padding.
                pad_width = batch[2]
                output[-1], _ = data.common.pad(output[-1], pad_width=pad_width, negative=True)

            # A BoolTensor target is a placeholder: no ground truth to score.
            if isinstance(batch[1], torch.BoolTensor):
                compute_loss = False

            if compute_loss:
                self.criterion(output, target)
                if isinstance(tq, tqdm):
                    tq.set_description(self.criterion.get_loss_desc())

            if self.args.save_results != 'none':
                self._save_results(batch, output, compute_loss)

        if compute_loss:
            self.criterion.normalize()
            if isinstance(tq, tqdm):
                tq.set_description(self.criterion.get_loss_desc())
                tq.display(pos=-1)
            self.criterion.step()
            if self.args.rank == 0:
                self.save()

        self.imsaver.end_background()

    def validate(self, epoch):
        """Evaluate on the 'val' loader."""
        self.evaluate(epoch, 'val')

    def test(self, epoch):
        """Evaluate on the 'test' loader."""
        self.evaluate(epoch, 'test')

    def fill_evaluation(self, epoch, mode=None, force=False):
        """Evaluate a saved epoch if its loss/metric records are missing.

        This is best-effort: if loading or evaluating the checkpoint fails,
        the failure is reported and skipped rather than raised.

        Args:
            epoch: epoch to (re)evaluate; values <= 0 are ignored.
            mode: evaluation mode; when given it becomes ``self.mode``,
                otherwise the current ``self.mode`` is used.
            force: evaluate even if the statistics already exist.
        """
        if epoch <= 0:
            return

        if mode is not None:
            self.mode = mode

        do_eval = force
        if not force:
            loss_missing = epoch not in self.criterion.loss_stat[self.mode]['Total']
            # Use self.mode (not the raw argument, which may be None) so the
            # metric lookup agrees with the loss lookup above.
            metric_missing = any(
                epoch not in self.criterion.metric_stat[self.mode][metric_type]
                for metric_type in self.criterion.metric
            )
            do_eval = loss_missing or metric_missing

        if do_eval:
            try:
                self.load(epoch)
                self.evaluate(epoch, self.mode)
            except Exception as err:
                # Typically the checkpoint for this epoch does not exist; keep
                # going, but do not hide the reason.
                print('===> Skipped {} evaluation of epoch {}: {!r}'.format(self.mode, epoch, err))
|
|
|