|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import time |
|
from collections import deque, defaultdict |
|
import pickle |
|
import shutil |
|
|
|
import numpy as np |
|
import paddle |
|
import paddle.nn.functional as F |
|
from paddleseg.utils import TimeAverager, calculate_eta, resume, logger |
|
|
|
from .val import evaluate |
|
|
|
|
|
def visual_in_traning(log_writer, vis_dict, step):
    """
    Log a dict of tensors to VisualDL as images.

    Each tensor is rescaled to the 0-255 range, converted to HWC uint8 and
    written under its dict key as the tag.

    Args:
        log_writer (LogWriter): The VisualDL log writer.
        vis_dict (dict): Mapping of tag -> tensor of shape (C, H, W).
        step (int): Global step recorded with each image.
    """
    for tag, tensor in vis_dict.items():
        # If the channel count is not displayable (neither 1 nor 3),
        # keep only the first channel as a single-channel image.
        if tensor.shape[0] not in (1, 3):
            tensor = tensor[0].unsqueeze(0)
        # CHW -> HWC, as expected by add_image.
        tensor = paddle.transpose(tensor, (1, 2, 0))
        low = paddle.min(tensor)
        high = paddle.max(tensor)
        if (low > 0) and (high < 1):
            # Looks like a [0, 1] image: scale up.
            tensor = tensor * 255
        elif (low < 0 and low >= -1) and (high <= 1):
            # Looks like a [-1, 1] image: shift then scale.
            tensor = (1 + tensor) / 2 * 255
        else:
            # Arbitrary range: min-max normalize into [0, 255].
            tensor = (tensor - low) / (high - low) * 255

        img = tensor.astype('uint8').numpy()
        log_writer.add_image(tag=tag, img=img, step=step)
|
|
|
|
|
def save_best(best_model_dir, metrics_data, iter):
    """
    Persist the best metric values plus the iteration they were reached at.

    Writes one "<key> <value>" line per metric to best_metrics.txt inside
    ``best_model_dir``, followed by a final "iter <iter>" line.

    Args:
        best_model_dir (str): Directory that holds the best checkpoint.
        metrics_data (dict): Metric name -> best value so far.
        iter (int): Iteration at which the best metrics were observed.
    """
    record_path = os.path.join(best_model_dir, 'best_metrics.txt')
    with open(record_path, 'w') as record:
        for name, score in metrics_data.items():
            record.write(name + ' ' + str(score) + '\n')
        # Trailing line marks which iteration produced these metrics.
        record.write('iter' + ' ' + str(iter) + '\n')
|
|
|
|
|
def get_best(best_file, metrics, resume_model=None):
    """
    Get the best metrics and the iteration they were reached at.

    When resuming and a best-metrics record exists, parse it; otherwise
    initialize every requested metric to +inf (lower is better) with
    iteration -1.

    Args:
        best_file (str): Path to the best_metrics.txt record written by
            ``save_best`` ("<key> <value>" per line, plus an "iter" line).
        metrics (list): Metric names used to seed defaults when there is
            no record to resume from.
        resume_model (str, optional): Resume path; the record is only read
            when this is not None.

    Returns:
        tuple: (best_metrics_data (dict), best_iter (int)).
    """
    best_metrics_data = {}
    # Initialize up front so best_iter is defined even if the record file
    # has no 'iter' line (the original code could raise UnboundLocalError).
    best_iter = -1
    if os.path.exists(best_file) and (resume_model is not None):
        with open(best_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # maxsplit=1 tolerates values that contain spaces.
                key, value = line.split(' ', 1)
                if key == 'iter':
                    # Parse numerically instead of eval(): eval is unsafe on
                    # file contents and fails on the literal 'inf', which is
                    # exactly what np.inf serializes to.
                    best_iter = int(value)
                    best_metrics_data[key] = best_iter
                else:
                    best_metrics_data[key] = float(value)
    else:
        for key in metrics:
            best_metrics_data[key] = np.inf
    return best_metrics_data, best_iter
|
|
|
|
|
def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          log_image_iters=1000,
          num_workers=0,
          use_vdl=False,
          losses=None,
          keep_checkpoint_max=5,
          eval_begin_iters=None,
          metrics='sad'):
    """
    Launch training.

    Args:
        model(nn.Layer): A matting model.
        train_dataset (paddle.io.Dataset): Used to read and process training datasets.
        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
        optimizer (paddle.optimizer.Optimizer): The optimizer.
        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
        iters (int, optional): How many iters to train the model. Default: 10000.
        batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
        resume_model (str, optional): The path of resume model.
        save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
        log_iters (int, optional): Display logging information at every log_iters. Default: 10.
        log_image_iters (int, optional): Log image to vdl. Default: 1000.
        num_workers (int, optional): Num workers for data loader. Default: 0.
        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
        losses (dict, optional): A dict of loss, refer to the loss function of the model for details. Default: None.
        keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
        eval_begin_iters (int): The iters begin evaluation. It will evaluate at iters/2 if it is None. Default: None.
        metrics(str|list, optional): The metrics to evaluate, it may be the combination of ("sad", "mse", "grad", "conn").
    """
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    # Optionally restore model/optimizer state and continue counting iters
    # from where the checkpoint left off.
    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    # Ensure save_dir is a directory; a plain file with that name is removed.
    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)

    if nranks > 1:
        # Multi-GPU: wrap the model for data-parallel training, initializing
        # the parallel context first if needed.
        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
        ):
            paddle.distributed.init_parallel_env()
            ddp_model = paddle.DataParallel(model)
        else:
            ddp_model = paddle.DataParallel(model)

    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True, )

    if use_vdl:
        # Imported lazily so visualdl is only required when actually used.
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    # Normalize metrics to a list; anything that is neither str nor list
    # falls back to the default ['sad'].
    if isinstance(metrics, str):
        metrics = [metrics]
    elif not isinstance(metrics, list):
        metrics = ['sad']
    # Seed best-so-far values (from a previous run's record when resuming).
    best_metrics_data, best_iter = get_best(
        os.path.join(save_dir, 'best_model', 'best_metrics.txt'),
        metrics,
        resume_model=resume_model)

    avg_loss = defaultdict(float)
    iters_per_epoch = len(batch_sampler)
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    # Rolling window of checkpoint dirs for keep_checkpoint_max pruning.
    save_models = deque()
    batch_start = time.time()

    # NOTE(review): `iter` shadows the builtin throughout this function.
    iter = start_iter
    while iter < iters:
        for data in loader:
            iter += 1
            if iter > iters:
                break
            reader_cost_averager.record(time.time() - batch_start)

            # Forward pass; the model returns (logits dict, losses dict).
            logit_dict, loss_dict = ddp_model(data) if nranks > 1 else model(
                data)

            # 'all' is the combined loss the model is optimized on.
            loss_dict['all'].backward()

            optimizer.step()
            lr = optimizer.get_lr()
            # Step the LR scheduler per iteration when one is attached.
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            model.clear_gradients()

            # Accumulate each loss component for windowed averaging.
            for key, value in loss_dict.items():
                avg_loss[key] += value.numpy()[0]
            batch_cost_averager.record(
                time.time() - batch_start, num_samples=batch_size)

            if (iter) % log_iters == 0 and local_rank == 0:
                # Convert accumulated sums into per-window averages.
                for key, value in avg_loss.items():
                    avg_loss[key] = value / log_iters
                remain_iters = iters - iter
                avg_train_batch_cost = batch_cost_averager.get_average()
                avg_train_reader_cost = reader_cost_averager.get_average()
                eta = calculate_eta(remain_iters, avg_train_batch_cost)

                # Build the per-component loss line ('all' is shown in the
                # main log line, so it is skipped here).
                loss_str = ' ' * 26 + '\t[LOSSES]'
                loss_str = loss_str
                for key, value in avg_loss.items():
                    if key != 'all':
                        loss_str = loss_str + ' ' + key + '={:.4f}'.format(
                            value)
                logger.info(
                    "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.5f}, ips={:.4f} samples/sec | ETA {}\n{}\n"
                    .format((iter - 1) // iters_per_epoch + 1, iter, iters,
                            avg_loss['all'], lr, avg_train_batch_cost,
                            avg_train_reader_cost,
                            batch_cost_averager.get_ips_average(
                            ), eta, loss_str))
                if use_vdl:
                    # Scalar curves for every loss component plus lr/costs.
                    for key, value in avg_loss.items():
                        log_tag = 'Train/' + key
                        log_writer.add_scalar(log_tag, value, iter)

                    log_writer.add_scalar('Train/lr', lr, iter)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, iter)
                    log_writer.add_scalar('Train/reader_cost',
                                          avg_train_reader_cost, iter)
                    if iter % log_image_iters == 0:
                        # Periodically log the first sample of the batch:
                        # input image, its ground-truth fields, and the
                        # corresponding predictions.
                        vis_dict = {}
                        vis_dict['ground truth/img'] = data['img'][0]
                        for key in data['gt_fields']:
                            # gt_fields entries are batched; take the field
                            # name of the first sample.
                            key = key[0]
                            vis_dict['/'.join(['ground truth', key])] = data[
                                key][0]
                        for key, value in logit_dict.items():
                            vis_dict['/'.join(['predict', key])] = logit_dict[
                                key][0]
                        visual_in_traning(
                            log_writer=log_writer, vis_dict=vis_dict, step=iter)

                # Reset the logging window.
                for key in avg_loss.keys():
                    avg_loss[key] = 0.
                reader_cost_averager.reset()
                batch_cost_averager.reset()

            # Periodic checkpointing (rank 0 only), pruning the oldest
            # checkpoint once more than keep_checkpoint_max are kept.
            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                save_models.append(current_save_dir)
                if len(save_models) > keep_checkpoint_max > 0:
                    model_to_remove = save_models.popleft()
                    shutil.rmtree(model_to_remove)

            # Evaluation only starts at eval_begin_iters (iters/2 by default)
            # and runs on the same schedule as checkpointing.
            if eval_begin_iters is None:
                eval_begin_iters = iters // 2
            if (iter % save_interval == 0 or iter == iters) and (
                    val_dataset is not None
            ) and local_rank == 0 and iter >= eval_begin_iters:
                # NOTE(review): this clamped num_workers value is computed but
                # never used — the evaluate() call below hardcodes
                # num_workers=1. Presumably one of the two is a leftover;
                # confirm which is intended.
                num_workers = 1 if num_workers > 0 else 0
                metrics_data = evaluate(
                    model,
                    val_dataset,
                    num_workers=1,
                    print_detail=True,
                    save_results=False,
                    metrics=metrics)
                # evaluate() switches the model to eval mode; restore it.
                model.train()

            # Track and persist the best model. Lower is better for all
            # supported metrics; only the first metric in `metrics` decides.
            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
                if val_dataset is not None and iter >= eval_begin_iters:
                    if metrics_data[metrics[0]] < best_metrics_data[metrics[0]]:
                        best_iter = iter
                        best_metrics_data = metrics_data.copy()
                        best_model_dir = os.path.join(save_dir, "best_model")
                        paddle.save(
                            model.state_dict(),
                            os.path.join(best_model_dir, 'model.pdparams'))
                        save_best(best_model_dir, best_metrics_data, iter)

                    show_list = []
                    for key, value in best_metrics_data.items():
                        show_list.append((key, value))
                    log_str = '[EVAL] The model with the best validation {} ({:.4f}) was saved at iter {}.'.format(
                        show_list[0][0], show_list[0][1], best_iter)
                    if len(show_list) > 1:
                        log_str += " While"
                        for i in range(1, len(show_list)):
                            log_str = log_str + ' {}: {:.4f},'.format(
                                show_list[i][0], show_list[i][1])
                        # Drop the trailing comma.
                        log_str = log_str[:-1]
                    logger.info(log_str)

                    if use_vdl:
                        for key, value in metrics_data.items():
                            log_writer.add_scalar('Evaluate/' + key, value,
                                                  iter)

            batch_start = time.time()

    # Brief pause before shutdown — presumably to let dataloader workers /
    # the vdl writer flush; confirm whether it is still needed.
    time.sleep(0.5)
    if use_vdl:
        log_writer.close()
|
|