|
import torch.nn as nn |
|
import functools |
|
import torch.optim as optim |
|
import options as opt |
|
import time |
|
|
|
from helpers import * |
|
from dataset import GridDataset, CharMap |
|
from datetime import datetime as Datetime |
|
from models.LipNet import LipNet |
|
from tqdm.auto import tqdm |
|
from PauseChecker import PauseChecker |
|
from torch.utils.data import DataLoader |
|
from torch.multiprocessing import Manager |
|
from BaseTrainer import BaseTrainer |
|
|
|
|
|
class Trainer(BaseTrainer):
    """Training / evaluation harness for the LipNet lip-reading model.

    Wraps dataset construction, model creation, CTC-loss training with
    NaN/inf batch skipping, periodic validation, and checkpointing of the
    best model (by validation loss).
    """

    def __init__(
        self, name=opt.run_name, write_logs=True,
        num_workers=None, base_dir='', char_map=opt.char_map,
        pre_gru_repeats=None
    ):
        """
        :param name: run name used for logging / checkpoint directories
        :param write_logs: when True, initialise tensorboard logging
        :param num_workers: DataLoader worker count
            (defaults to opt.num_workers)
        :param base_dir: base directory prepended to dataset / weight paths
        :param char_map: output alphabet selector
            (letters / phonemes / cmu_phonemes)
        :param pre_gru_repeats: temporal repeat factor applied before the
            GRU; must be a positive int (defaults to opt.pre_gru_repeats)
        """
        super().__init__(name=name, base_dir=base_dir)

        images_dir = opt.images_dir
        if opt.use_lip_crops:
            images_dir = opt.crop_images_dir
        if num_workers is None:
            num_workers = opt.num_workers
        if pre_gru_repeats is None:
            pre_gru_repeats = opt.pre_gru_repeats

        # Check the type before the value so a non-int (e.g. 1.5) trips
        # the type assertion rather than slipping past the >= comparison.
        assert isinstance(pre_gru_repeats, int)
        assert pre_gru_repeats >= 1

        self.images_dir = images_dir
        self.num_workers = num_workers
        self.pre_gru_repeats = pre_gru_repeats
        self.char_map = char_map

        # Shared dict lets DataLoader worker processes reuse decoded
        # videos when video caching is enabled.
        manager = Manager()
        if opt.cache_videos:
            shared_dict = manager.dict()
        else:
            shared_dict = None

        self.shared_dict = shared_dict
        self.dataset_kwargs = self.get_dataset_kwargs(
            shared_dict=shared_dict, base_dir=self.base_dir,
            char_map=self.char_map
        )

        self.best_test_loss = float('inf')
        self.train_dataset = None
        self.test_dataset = None
        self.model = None   # raw LipNet module
        self.net = None     # DataParallel wrapper around self.model

        if write_logs:
            self.init_tensorboard()

    def load_datasets(self):
        """Lazily construct the train / validation GridDatasets."""
        if self.train_dataset is None:
            self.train_dataset = GridDataset(
                **self.dataset_kwargs, phase='train',
                file_list=opt.train_list
            )
        if self.test_dataset is None:
            self.test_dataset = GridDataset(
                **self.dataset_kwargs, phase='test',
                file_list=opt.val_list
            )

    def create_model(self):
        """Build the LipNet model and its DataParallel wrapper on GPU.

        Requires load_datasets() to have been called first: the output
        layer size comes from the training dataset's character mapping.
        """
        output_classes = len(self.train_dataset.get_char_mapping())

        if self.model is None:
            self.model = LipNet(
                output_classes=output_classes,
                pre_gru_repeats=self.pre_gru_repeats
            )
            self.model = self.model.cuda()
        if self.net is None:
            self.net = nn.DataParallel(self.model).cuda()

    def load_weights(self, weights_path):
        """Partially load a checkpoint into the model.

        Only tensors whose name AND shape match the current model are
        loaded; everything else is reported and left at its current value.

        :param weights_path: checkpoint path, relative to self.base_dir
        """
        self.load_datasets()
        self.create_model()

        weights_path = os.path.join(self.base_dir, weights_path)
        pretrained_dict = torch.load(weights_path)
        model_dict = self.model.state_dict()
        # Keep only entries present in the model with an identical shape.
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict and v.size() == model_dict[k].size()
        }

        # Model parameters that the checkpoint did not provide.
        missed_params = [
            k for k in model_dict if k not in pretrained_dict
        ]

        print('loaded params/tot params: {}/{}'.format(
            len(pretrained_dict), len(model_dict)
        ))
        print('miss matched params:{}'.format(missed_params))
        model_dict.update(pretrained_dict)
        self.model.load_state_dict(model_dict)

    @staticmethod
    def make_date_stamp():
        """Return a compact timestamp string (yyMMdd-HHmm)."""
        return Datetime.now().strftime("%y%m%d-%H%M")

    @staticmethod
    def dataset2dataloader(
        dataset, num_workers, shuffle=True
    ):
        """Wrap a dataset in a DataLoader using the configured batch size."""
        return DataLoader(
            dataset,
            batch_size=opt.batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=False
        )

    def test(self):
        """Run one full validation pass over the test dataset.

        :return: tuple (mean CTC loss, mean WER, mean CER)
        """
        dataset = self.test_dataset

        with torch.no_grad():
            print('num_test_data:{}'.format(len(dataset.data)))
            self.model.eval()
            loader = self.dataset2dataloader(
                dataset, shuffle=False, num_workers=self.num_workers
            )

            loss_list = []
            wer = []
            cer = []
            crit = nn.CTCLoss(zero_infinity=True)
            tic = time.time()
            print('RUNNING VALIDATION')

            pbar = tqdm(loader)
            for (i_iter, input_sample) in enumerate(pbar):
                PauseChecker.check()

                vid = input_sample.get('vid').cuda()
                vid_len = input_sample.get('vid_len').cuda()
                txt, txt_len = self.extract_char_output(input_sample)
                y = self.net(vid)

                # CTC needs input sequences comfortably longer than the
                # targets; pre_gru_repeats upsamples the input lengths.
                assert (
                    self.pre_gru_repeats * vid_len.view(-1) >
                    2 * txt_len.view(-1)
                ).all()

                loss = crit(
                    y.transpose(0, 1).log_softmax(-1), txt,
                    self.pre_gru_repeats * vid_len.view(-1),
                    txt_len.view(-1)
                ).detach().cpu().numpy()

                loss_list.append(loss)
                pred_txt = dataset.ctc_decode(y)
                truth_txt = [
                    dataset.arr2txt(txt[_], start=1)
                    for _ in range(txt.size(0))
                ]

                wer.extend(dataset.wer(pred_txt, truth_txt))
                cer.extend(dataset.cer(pred_txt, truth_txt))

                if i_iter % opt.display == 0:
                    # seconds per iteration -> hours remaining
                    v = 1.0 * (time.time() - tic) / (i_iter + 1)
                    eta = v * (len(loader) - i_iter) / 3600.0

                    self.log_pred_texts(pred_txt, truth_txt, sub_samples=10)
                    print('test_iter={},eta={},wer={},cer={}'.format(
                        i_iter, eta, np.array(wer).mean(),
                        np.array(cer).mean()
                    ))
                    print(161 * '-')

            return (
                np.array(loss_list).mean(), np.array(wer).mean(),
                np.array(cer).mean()
            )

    def extract_char_output(self, input_sample):
        """
        extract output character sequence from input_sample
        output character sequence is text if char_map is CharMap.letters
        output character sequence is phonemes if char_map is CharMap.phonemes

        :return: tuple (txt, txt_len) of CUDA tensors
        :raises ValueError: if self.char_map is not a supported CharMap
        """
        if self.char_map == CharMap.letters:
            txt = input_sample.get('txt').cuda()
            txt_len = input_sample.get('txt_len').cuda()
        elif self.char_map == CharMap.phonemes:
            txt = input_sample.get('phonemes').cuda()
            txt_len = input_sample.get('phonemes_len').cuda()
        elif self.char_map == CharMap.cmu_phonemes:
            txt = input_sample.get('cmu_phonemes').cuda()
            txt_len = input_sample.get('cmu_phonemes_len').cuda()
        else:
            raise ValueError(f'UNSUPPORTED CHAR_MAP: {self.char_map}')

        return txt, txt_len

    def train(self):
        """Run the full training loop.

        Iterates opt.max_epoch epochs, skipping batches that produce
        NaN/inf losses or NaN gradients, logging metrics every
        opt.display iterations and validating every opt.test_step
        iterations.
        """
        self.load_datasets()
        self.create_model()

        dataset = self.train_dataset
        loader = self.dataset2dataloader(
            dataset, num_workers=self.num_workers
        )
        optimizer = optim.RMSprop(
            self.model.parameters(), lr=opt.base_lr
        )

        print('num_train_data:{}'.format(len(dataset.data)))

        crit = nn.CTCLoss(zero_infinity=True)
        tic = time.time()

        train_wer = []
        self.best_test_loss = float('inf')
        log_scalar = functools.partial(self.log_scalar, label='train')

        for epoch in range(opt.max_epoch):
            print(f'RUNNING EPOCH {epoch}')

            pbar = tqdm(loader)
            for (i_iter, input_sample) in enumerate(pbar):
                PauseChecker.check()

                self.model.train()
                vid = input_sample.get('vid').cuda()
                vid_len = input_sample.get('vid_len').cuda()
                txt, txt_len = self.extract_char_output(input_sample)

                optimizer.zero_grad()
                y = self.net(vid)
                assert not contains_nan_or_inf(y)
                # CTC needs input sequences comfortably longer than the
                # targets; pre_gru_repeats upsamples the input lengths.
                assert (
                    self.pre_gru_repeats * vid_len.view(-1) >
                    2 * txt_len.view(-1)
                ).all()

                loss = crit(
                    y.transpose(0, 1).log_softmax(-1), txt,
                    self.pre_gru_repeats * vid_len.view(-1),
                    txt_len.view(-1)
                )

                if contains_nan_or_inf(loss):
                    print(f'LOSS IS INVALID. SKIPPING {i_iter}')
                    continue

                loss.backward()

                # Skip the update if any gradient went NaN. Parameters
                # that received no gradient have p.grad is None, which
                # torch.isnan cannot accept — guard against that.
                if any(
                    p.grad is not None and torch.isnan(p.grad).any()
                    for p in self.model.parameters()
                ):
                    optimizer.zero_grad()
                    print('SKIPPING NAN GRADS')
                    continue

                if opt.is_optimize:
                    optimizer.step()

                assert not contains_nan_or_inf(self.model.conv1.weight)
                tot_iter = i_iter + epoch * len(loader)
                pred_txt = dataset.ctc_decode(y)
                truth_txt = [
                    dataset.arr2txt(txt[_], start=1)
                    for _ in range(txt.size(0))
                ]
                train_wer.extend(dataset.wer(pred_txt, truth_txt))

                if tot_iter % opt.display == 0:
                    # seconds per iteration -> hours remaining (epoch)
                    v = 1.0 * (time.time() - tic) / (tot_iter + 1)
                    eta = (len(loader) - i_iter) * v / 3600.0
                    wer = np.array(train_wer).mean()

                    log_scalar('loss', loss, tot_iter)
                    log_scalar('wer', wer, tot_iter)

                    self.log_pred_texts(pred_txt, truth_txt, sub_samples=3)
                    print('epoch={},tot_iter={},eta={},loss={},train_wer={}'
                        .format(
                            epoch, tot_iter, eta, loss,
                            np.array(train_wer).mean()
                        )
                    )
                    print(161 * '-')

                if (tot_iter > 0) and (tot_iter % opt.test_step == 0):
                    self.run_test(tot_iter, optimizer)

    @staticmethod
    def log_pred_texts(pred_txt, truth_txt, pad=80, sub_samples=None):
        """Print predicted vs ground-truth transcripts side by side.

        :param pad: column width for each side of the table
        :param sub_samples: when given, only the first N pairs are shown
        """
        line_length = 2 * pad + 1
        print(line_length * '-')
        print('{:<{pad}}|{:>{pad}}'.format(
            'predict', 'truth', pad=pad
        ))

        print(line_length * '-')
        zipped_samples = list(zip(pred_txt, truth_txt))
        if sub_samples is not None:
            zipped_samples = zipped_samples[:sub_samples]

        for (predict, truth) in zipped_samples:
            print('{:<{pad}}|{:>{pad}}'.format(
                predict, truth, pad=pad
            ))

        print(line_length * '-')

    def run_test(self, tot_iter, optimizer):
        """Validate, log metrics, and checkpoint when test loss improves.

        When opt.is_optimize is False the process exits after one
        validation pass (evaluation-only mode).
        """
        log_scalar = functools.partial(self.log_scalar, label='test')

        (loss, wer, cer) = self.test()
        print('i_iter={},lr={},loss={},wer={},cer={}'.format(
            tot_iter, show_lr(optimizer), loss, wer, cer
        ))
        log_scalar('loss', loss, tot_iter)
        log_scalar('wer', wer, tot_iter)
        log_scalar('cer', cer, tot_iter)

        if loss < self.best_test_loss:
            print(f'NEW BEST LOSS: {loss}')
            self.best_test_loss = loss

            savename = 'I{}-L{:.4f}-W{:.4f}-C{:.4f}'.format(
                tot_iter, loss, wer, cer
            )

            # drop dots so the metrics don't look like file extensions
            savename = savename.replace('.', '') + '.pt'
            savepath = os.path.join(self.weights_dir, savename)

            (save_dir, name) = os.path.split(savepath)
            # exist_ok avoids the check-then-create race of the previous
            # os.path.exists() guard
            os.makedirs(save_dir, exist_ok=True)

            torch.save(self.model.state_dict(), savepath)
            print(f'best model saved at {savepath}')

        if not opt.is_optimize:
            exit()

    def predict_sample(self, input_sample):
        """Predict the transcript for a single dataset sample dict."""
        self.model.eval()
        vid = input_sample.get('vid').cuda()
        return self.predict_video(vid)

    def predict_video(self, video):
        """Predict the transcript for one video tensor.

        NOTE(review): assumes `video` carries no batch dimension — it is
        unsqueezed to batch size 1 before the forward pass; confirm
        against callers. Also requires self.train_dataset to be loaded
        (load_datasets) for ctc_decode.
        """
        video = video.cuda()
        vid = video.unsqueeze(0)
        y = self.net(vid)
        pred_txt = self.train_dataset.ctc_decode(y)
        return pred_txt