# Uploaded by hyliu using huggingface_hub (commit 8ec10cf, verified).
import os
from tqdm import tqdm
import torch
import data.common
from utils import interact, MultiSaver
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torch.cuda.amp as amp
class Trainer():
    """Run training and evaluation epochs for a model.

    The trainer wires together a model, a criterion (loss/metric tracker),
    an optimizer wrapper and a dict of data loaders keyed by mode. On the
    main process (``is_slave`` is False) it also writes images and per-scale
    losses to TensorBoard through ``self.writter``.
    """

    def __init__(self, args, model, criterion, optimizer, loaders):
        """Store the collaborators and prepare output facilities.

        Args:
            args: Parsed options. Fields read here: ``start_epoch``,
                ``save_dir``, ``do_train``, ``do_validate``, ``do_test``,
                ``device``, ``dtype``, ``precision``, ``n_scales``, ``demo``,
                ``demo_output_dir``, ``launched``, ``rank``, ``init_scale``
                and ``amp``.
            model: Model wrapper; must provide ``save``/``load``.
            criterion: Loss wrapper; must provide ``save``/``load``.
            optimizer: Optimizer wrapper; must provide ``save``/``load``.
            loaders: Mapping from mode name to an iterable of batches.
        """
        print('===> Initializing trainer')
        self.args = args
        self.mode = 'train'  # later switched to 'val' / 'test' / 'demo'
        self.epoch = args.start_epoch
        self.save_dir = args.save_dir

        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.loaders = loaders

        self.do_train = args.do_train
        self.do_validate = args.do_validate
        self.do_test = args.do_test

        self.device = args.device
        self.dtype = args.dtype
        self.dtype_eval = torch.float32 if args.precision == 'single' else torch.float16
        # Number of per-scale loss entries read from criterion.buffer.
        self.recurrence = args.n_scales

        if self.args.demo and self.args.demo_output_dir:
            self.result_dir = self.args.demo_output_dir
        else:
            self.result_dir = os.path.join(self.save_dir, 'result')
        os.makedirs(self.result_dir, exist_ok=True)
        print('results are saved in {}'.format(self.result_dir))

        self.imsaver = MultiSaver(self.result_dir)

        # Any launched process other than rank 0 skips progress bars/logging.
        self.is_slave = self.args.launched and self.args.rank != 0

        self.scaler = amp.GradScaler(
            init_scale=self.args.init_scale,
            enabled=self.args.amp
        )

        # Only the main process owns a TensorBoard writer.
        if not self.is_slave:
            self.writter = SummaryWriter(f"runs/{args.save_dir}")

    def _make_progress(self, loader):
        """Return ``loader`` wrapped in a tqdm bar on the main process only."""
        if self.is_slave:
            return loader
        return tqdm(loader, ncols=80, smoothing=0, bar_format='{desc}|{bar}{r_bar}')

    def save(self, epoch=None):
        """Save state every ``args.save_every`` epochs.

        The model and optimizer are saved only in 'train' mode; the
        criterion is saved in every mode.

        Args:
            epoch: Epoch to save under; defaults to ``self.epoch``.
        """
        epoch = self.epoch if epoch is None else epoch
        if epoch % self.args.save_every == 0:
            if self.mode == 'train':
                self.model.save(epoch)
                self.optimizer.save(epoch)
            self.criterion.save()

    def load(self, epoch=None, pretrained=None):
        """Load model, optimizer and criterion state for ``epoch``.

        Args:
            epoch: Epoch to load; defaults to ``args.load_epoch``.
            pretrained: Forwarded unchanged to ``model.load``.
        """
        if epoch is None:
            epoch = self.args.load_epoch
        self.epoch = epoch
        self.model.load(epoch, pretrained)
        self.optimizer.load(epoch)
        self.criterion.load(epoch)

    def train(self, epoch):
        """Run one training epoch, then log, schedule and save.

        Args:
            epoch: Index of the epoch being trained.
        """
        self.mode = 'train'
        self.epoch = epoch

        self.model.train()
        self.model.to(dtype=self.dtype)

        self.criterion.train()
        self.criterion.epoch = epoch

        if not self.is_slave:
            print('[Epoch {} / lr {:.2e}]'.format(
                epoch, self.optimizer.get_lr()
            ))

        loader = self.loaders[self.mode]
        if self.args.distributed:
            loader.sampler.set_epoch(epoch)
        tq = self._make_progress(loader)

        # Per-scale loss accumulated over every batch of this epoch.
        scale_losses = [0.0] * self.recurrence
        # Stay None when the loader yields nothing, so logging can be skipped.
        inputs = target = output = None

        torch.set_grad_enabled(True)
        for batch in tq:
            self.optimizer.zero_grad()

            inputs, target = data.common.to(
                batch[0], batch[1], device=self.device, dtype=self.dtype)

            with amp.autocast(self.args.amp):
                output = self.model(inputs)
                loss = self.criterion(output, target)

            for i in range(self.recurrence):
                scale_losses[i] += self.criterion.buffer[i]

            self.scaler.scale(loss).backward()
            if self.args.clip > 0:
                # Unscale first so the clip threshold applies to the real
                # gradient magnitudes rather than the AMP-scaled ones.
                self.scaler.unscale_(self.optimizer.G)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
            self.scaler.step(self.optimizer.G)
            self.scaler.update()

            if isinstance(tq, tqdm):
                tq.set_description(self.criterion.get_loss_desc())

        # NOTE(review): logging is done once per epoch, using the last batch
        # for the images -- confirm this matches the intended placement.
        if not self.is_slave and output is not None:
            rgb_range = self.args.rgb_range
            for i, scale_output in enumerate(output):
                grid = torchvision.utils.make_grid(scale_output)
                self.writter.add_image(f"Output{i}", grid / rgb_range, epoch)
                self.writter.add_scalar(f"Loss{i}", scale_losses[i], epoch)
            self.writter.add_image(
                "Input", torchvision.utils.make_grid(inputs[0]) / rgb_range, epoch)
            self.writter.add_image(
                "Target", torchvision.utils.make_grid(target[0]) / rgb_range, epoch)

        self.criterion.normalize()
        if isinstance(tq, tqdm):
            tq.set_description(self.criterion.get_loss_desc())
            tq.display(pos=-1)  # overwrite with synchronized loss

        self.criterion.step()
        self.optimizer.schedule(self.criterion.get_last_loss())

        if self.args.rank == 0:
            self.save(epoch)

    def evaluate(self, epoch, mode='val'):
        """Run the model over the ``mode`` loader with gradients disabled.

        The loss is computed unless a batch target is a ``torch.BoolTensor``;
        result images are saved according to ``args.save_results``.

        Args:
            epoch: Epoch index recorded on the trainer and the criterion.
            mode: Loader key; 'val' and 'test' switch the criterion to the
                matching state, 'demo' additionally strips padding from the
                first output.
        """
        self.mode = mode
        self.epoch = epoch

        self.model.eval()
        self.model.to(dtype=self.dtype_eval)

        if mode == 'val':
            self.criterion.validate()
        elif mode == 'test':
            self.criterion.test()
        self.criterion.epoch = epoch

        self.imsaver.join_background()

        tq = self._make_progress(self.loaders[self.mode])

        compute_loss = True
        torch.set_grad_enabled(False)
        for batch in tq:
            inputs, target = data.common.to(
                batch[0], batch[1], device=self.device, dtype=self.dtype_eval)
            with amp.autocast(self.args.amp):
                output = self.model(inputs)

            if mode == 'demo':  # remove padded part
                pad_width = batch[2]
                output[0], _ = data.common.pad(output[0], pad_width=pad_width, negative=True)

            # Presumably a BoolTensor target is the dataset's placeholder for
            # "no ground truth"; once seen, loss stays off for the whole run.
            if isinstance(batch[1], torch.BoolTensor):
                compute_loss = False

            if compute_loss:
                self.criterion(output, target)
                if isinstance(tq, tqdm):
                    tq.set_description(self.criterion.get_loss_desc())

            if self.args.save_results != 'none':
                if isinstance(output, (list, tuple)):
                    result = output[-1]  # select last output in a pyramid
                elif isinstance(output, torch.Tensor):
                    result = output

                names = batch[-1]

                if self.args.save_results == 'part' and compute_loss:  # save all when GT not available
                    indices = batch[-2]
                    # Keep only the samples whose dataset index is a multiple of 10.
                    save_ids = [
                        pos for pos, sample_idx in enumerate(indices)
                        if sample_idx % 10 == 0
                    ]

                    result = result[save_ids]
                    names = [names[pos] for pos in save_ids]

                self.imsaver.save_image(result, names)

        if compute_loss:
            self.criterion.normalize()
            if isinstance(tq, tqdm):
                tq.set_description(self.criterion.get_loss_desc())
                tq.display(pos=-1)  # overwrite with synchronized loss

            self.criterion.step()
            if self.args.rank == 0:
                self.save()

        self.imsaver.end_background()

    def validate(self, epoch):
        """Evaluate on the 'val' loader."""
        self.evaluate(epoch, 'val')

    def test(self, epoch):
        """Evaluate on the 'test' loader."""
        self.evaluate(epoch, 'test')

    def fill_evaluation(self, epoch, mode=None, force=False):
        """Evaluate ``epoch`` if its loss or metric records are missing.

        Args:
            epoch: Epoch to check; values <= 0 are ignored.
            mode: If given, replaces ``self.mode`` before checking.
            force: Evaluate even when records already exist.
        """
        if epoch <= 0:
            return

        if mode is not None:
            self.mode = mode

        do_eval = force
        if not force:
            # should it switch to all loss types?
            loss_missing = epoch not in self.criterion.loss_stat[self.mode]['Total']
            # Look up with self.mode: the `mode` argument may be None.
            metric_missing = any(
                epoch not in self.criterion.metric_stat[self.mode][metric_type]
                for metric_type in self.criterion.metric
            )
            do_eval = loss_missing or metric_missing

        if do_eval:
            try:
                self.load(epoch)
                self.evaluate(epoch, self.mode)
            except Exception as err:
                # Best effort: a checkpoint for this epoch may not exist.
                # Report instead of hiding the cause, and let
                # KeyboardInterrupt/SystemExit propagate.
                if not self.is_slave:
                    print('skipping evaluation at epoch {}: {}'.format(epoch, err))