hyliu's picture
Upload folder using huggingface_hub
8ec10cf verified
import os
from importlib import import_module
import torch
from torch import nn
import torch.distributed as dist
import matplotlib.pyplot as plt
plt.switch_backend('agg') # https://github.com/matplotlib/matplotlib/issues/3466
from .metric import PSNR, SSIM
from utils import interact
class Loss(torch.nn.modules.loss._Loss):
def __init__(self, args, epoch=None, model=None, optimizer=None):
"""
input:
args.loss use '+' to sum over different loss functions
use '*' to specify the loss weight
example:
1*MSE+0.5*VGG54
loss = sum of MSE and VGG54(weight=0.5)
args.measure similar to args.loss, but without weight
example:
MSE+PSNR
measure MSE and PSNR, independently
"""
super(Loss, self).__init__()
self.args = args
self.rgb_range = args.rgb_range
self.device_type = args.device_type
self.synchronized = False
self.epoch = args.start_epoch if epoch is None else epoch
self.save_dir = args.save_dir
self.save_name = os.path.join(self.save_dir, 'loss.pt')
# self.training = True
self.validating = False
self.testing = False
self.mode = 'train'
self.modes = ('train', 'val', 'test')
# Loss
self.loss = nn.ModuleDict()
self.loss_types = []
self.weight = {}
self.loss_stat = {mode:{} for mode in self.modes}
# loss_stat[mode][loss_type][epoch] = loss_value
# loss_stat[mode]['Total'][epoch] = loss_total
for weighted_loss in args.loss.split('+'):
w, l = weighted_loss.split('*')
l = l.upper()
if l in ('ABS', 'L1'):
loss_type = 'L1'
func = nn.L1Loss()
elif l in ('MSE', 'L2'):
loss_type = 'L2'
func = nn.MSELoss()
elif l in ('ADV', 'GAN'):
loss_type = 'ADV'
m = import_module('loss.adversarial')
func = getattr(m, 'Adversarial')(args, model, optimizer)
else:
loss_type = l
m = import_module*'loss.{}'.format(l.lower())
func = getattr(m, l)(args)
self.loss_types += [loss_type]
self.loss[loss_type] = func
self.weight[loss_type] = float(w)
print('Loss function: {}'.format(args.loss))
# Metrics
self.do_measure = args.metric.lower() != 'none'
self.metric = nn.ModuleDict()
self.metric_types = []
self.metric_stat = {mode:{} for mode in self.modes}
# metric_stat[mode][metric_type][epoch] = metric_value
if self.do_measure:
for metric_type in args.metric.split(','):
metric_type = metric_type.upper()
if metric_type == 'PSNR':
metric_func = PSNR()
elif metric_type == 'SSIM':
metric_func = SSIM(args.device_type) # single precision
else:
raise NotImplementedError
self.metric_types += [metric_type]
self.metric[metric_type] = metric_func
print('Metrics: {}'.format(args.metric))
if args.start_epoch != 1:
self.load(args.start_epoch - 1)
for mode in self.modes:
for loss_type in self.loss:
if loss_type not in self.loss_stat[mode]:
self.loss_stat[mode][loss_type] = {} # initialize loss
if 'Total' not in self.loss_stat[mode]:
self.loss_stat[mode]['Total'] = {}
if self.do_measure:
for metric_type in self.metric:
if metric_type not in self.metric_stat[mode]:
self.metric_stat[mode][metric_type] = {}
self.count = 0
self.count_m = 0
self.to(args.device, dtype=args.dtype)
def train(self, mode=True):
super(Loss, self).train(mode)
if mode:
self.validating = False
self.testing = False
self.mode = 'train'
else: # default test mode
self.validating = False
self.testing = True
self.mode = 'test'
def validate(self):
super(Loss, self).eval()
# self.training = False
self.validating = True
self.testing = False
self.mode = 'val'
def test(self):
super(Loss, self).eval()
# self.training = False
self.validating = False
self.testing = True
self.mode = 'test'
def forward(self, input, target):
self.synchronized = False
loss = 0
def _ms_forward(input, target, func):
if isinstance(input, (list, tuple)): # loss for list output
_loss = []
for (input_i, target_i) in zip(input, target):
_loss += [func(input_i, target_i)]
return sum(_loss)
elif isinstance(input, dict): # loss for dict output
_loss = []
for key in input:
_loss += [func(input[key], target[key])]
return sum(_loss)
else: # loss for tensor output
return func(input, target)
# initialize
if self.count == 0:
for loss_type in self.loss_types:
self.loss_stat[self.mode][loss_type][self.epoch] = 0
self.loss_stat[self.mode]['Total'][self.epoch] = 0
if isinstance(input, list):
count = input[0].shape[0]
else: # Tensor
count = input.shape[0] # batch size
isnan = False
for loss_type in self.loss_types:
if loss_type == 'ADV':
_loss = self.loss[loss_type](input[0], target[0], self.training) * self.weight[loss_type]
else:
_loss = _ms_forward(input, target, self.loss[loss_type]) * self.weight[loss_type]
if torch.isnan(_loss):
isnan = True # skip recording (will also be skipped at backprop)
else:
self.loss_stat[self.mode][loss_type][self.epoch] += _loss.item() * count
self.loss_stat[self.mode]['Total'][self.epoch] += _loss.item() * count
loss += _loss
if not isnan:
self.count += count
if not self.training and self.do_measure:
self.measure(input, target)
return loss
def measure(self, input, target):
if isinstance(input, (list, tuple)):
self.measure(input[0], target[0])
return
elif isinstance(input, dict):
first_key = list(input.keys())[0]
self.measure(input[first_key], target[first_key])
return
else:
pass
if self.count_m == 0:
for metric_type in self.metric_stat[self.mode]:
self.metric_stat[self.mode][metric_type][self.epoch] = 0
if isinstance(input, list):
count = input[0].shape[0]
else: # Tensor
count = input.shape[0] # batch size
for metric_type in self.metric_stat[self.mode]:
input = input.clamp(0, self.rgb_range) # not in_place
if self.rgb_range == 255:
input.round_()
_metric = self.metric[metric_type](input, target)
self.metric_stat[self.mode][metric_type][self.epoch] += _metric.item() * count
self.count_m += count
return
def normalize(self):
if self.args.distributed:
dist.barrier()
if not self.synchronized:
self.all_reduce()
if self.count > 0:
for loss_type in self.loss_stat[self.mode]: # including 'Total'
self.loss_stat[self.mode][loss_type][self.epoch] /= self.count
self.count = 0
if self.count_m > 0:
for metric_type in self.metric_stat[self.mode]:
self.metric_stat[self.mode][metric_type][self.epoch] /= self.count_m
self.count_m = 0
return
def all_reduce(self, epoch=None):
# synchronize loss for distributed GPU processes
if epoch is None:
epoch = self.epoch
def _reduce_value(value, ReduceOp=dist.ReduceOp.SUM):
value_tensor = torch.Tensor([value]).to(self.args.device, self.args.dtype, non_blocking=True)
dist.all_reduce(value_tensor, ReduceOp, async_op=False)
value = value_tensor.item()
del value_tensor
return value
dist.barrier()
if self.count > 0: # I assume this should be true
self.count = _reduce_value(self.count, dist.ReduceOp.SUM)
for loss_type in self.loss_stat[self.mode]:
self.loss_stat[self.mode][loss_type][epoch] = _reduce_value(
self.loss_stat[self.mode][loss_type][epoch],
dist.ReduceOp.SUM
)
if self.count_m > 0:
self.count_m = _reduce_value(self.count_m, dist.ReduceOp.SUM)
for metric_type in self.metric_stat[self.mode]:
self.metric_stat[self.mode][metric_type][epoch] = _reduce_value(
self.metric_stat[self.mode][metric_type][epoch],
dist.ReduceOp.SUM
)
self.synchronized = True
return
def print_metrics(self):
print(self.get_metric_desc())
return
def get_last_loss(self):
return self.loss_stat[self.mode]['Total'][self.epoch]
def get_loss_desc(self):
if self.mode == 'train':
desc_prefix = 'Train'
elif self.mode == 'val':
desc_prefix = 'Validation'
else:
desc_prefix = 'Test'
loss = self.loss_stat[self.mode]['Total'][self.epoch]
if self.count > 0:
loss /= self.count
desc = '{} Loss: {:.1f}'.format(desc_prefix, loss)
if self.mode in ('val', 'test'):
metric_desc = self.get_metric_desc()
desc = '{}{}'.format(desc, metric_desc)
return desc
def get_metric_desc(self):
desc = ''
for metric_type in self.metric_stat[self.mode]:
measured = self.metric_stat[self.mode][metric_type][self.epoch]
if self.count_m > 0:
measured /= self.count_m
if metric_type == 'PSNR':
desc += ' {}: {:2.2f}'.format(metric_type, measured)
elif metric_type == 'SSIM':
desc += ' {}: {:1.4f}'.format(metric_type, measured)
else:
desc += ' {}: {:2.4f}'.format(metric_type, measured)
return desc
def step(self, plot_name=None):
self.normalize()
self.plot(plot_name)
if not self.training and self.do_measure:
# self.print_metrics()
self.plot_metric()
# self.epoch += 1
return
def save(self):
state = {
'loss_stat': self.loss_stat,
'metric_stat': self.metric_stat,
}
torch.save(state, self.save_name)
return
def load(self, epoch=None):
print('Loading loss record from {}'.format(self.save_name))
if os.path.exists(self.save_name):
state = torch.load(self.save_name, map_location=self.args.device)
self.loss_stat = state['loss_stat']
if 'metric_stat' in state:
self.metric_stat = state['metric_stat']
else:
pass
else:
print('no loss record found for {}!'.format(self.save_name))
if epoch is not None:
self.epoch = epoch
return
def plot(self, plot_name=None, metric=False):
self.plot_loss(plot_name)
if metric:
self.plot_metric(plot_name)
# else:
# self.plot_loss(plot_name)
return
def plot_loss(self, plot_name=None):
if plot_name is None:
plot_name = os.path.join(self.save_dir, "{}_loss.pdf".format(self.mode))
title = "{} loss".format(self.mode)
fig = plt.figure()
plt.title(title)
plt.xlabel('epochs')
plt.ylabel('loss')
plt.grid(True, linestyle=':')
for loss_type, loss_record in self.loss_stat[self.mode].items(): # including Total
axis = sorted([epoch for epoch in loss_record.keys() if epoch <= self.epoch])
value = [self.loss_stat[self.mode][loss_type][epoch] for epoch in axis]
label = loss_type
plt.plot(axis, value, label=label)
plt.xlim(0, self.epoch)
plt.legend()
plt.savefig(plot_name)
plt.close(fig)
return
def plot_metric(self, plot_name=None):
# assume there are only max 2 metrics
if plot_name is None:
plot_name = os.path.join(self.save_dir, "{}_metric.pdf".format(self.mode))
title = "{} metrics".format(self.mode)
fig, ax1 = plt.subplots()
plt.title(title)
plt.grid(True, linestyle=':')
ax1.set_xlabel('epochs')
plots = None
for metric_type, metric_record in self.metric_stat[self.mode].items():
axis = sorted([epoch for epoch in metric_record.keys() if epoch <= self.epoch])
value = [metric_record[epoch] for epoch in axis]
label = metric_type
if metric_type == 'PSNR':
ax = ax1
color='C0'
elif metric_type == 'SSIM':
ax2 = ax1.twinx()
ax = ax2
color='C1'
ax.set_ylabel(metric_type)
if plots is None:
plots = ax.plot(axis, value, label=label, color=color)
else:
plots += ax.plot(axis, value, label=label, color=color)
labels = [plot.get_label() for plot in plots]
plt.legend(plots, labels)
plt.xlim(0, self.epoch)
plt.savefig(plot_name)
plt.close(fig)
return
def sort(self):
# sort the loss/metric record
for mode in self.modes:
for loss_type, loss_epochs in self.loss_stat[mode].items():
self.loss_stat[mode][loss_type] = {epoch: loss_epochs[epoch] for epoch in sorted(loss_epochs)}
for metric_type, metric_epochs in self.metric_stat[mode].items():
self.metric_stat[mode][metric_type] = {epoch: metric_epochs[epoch] for epoch in sorted(metric_epochs)}
return self