import os
import sys
import time

sys.path.append('../models')

import torch
import functools
import numpy as np
import options as opt

from torch import nn, optim
from tqdm.auto import tqdm

from PauseChecker import PauseChecker
from Trainer import Trainer
from models.LipNetPlus import LipNetPlus
from TranslatorTrainer import TranslatorTrainer
from dataset import GridDataset, CharMap, Datasets
from helpers import contains_nan_or_inf
from models.PhonemeTransformer import *
from helpers import *


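# TransformerTrainer trains the seq2seq transformer head of LipNetPlus on top
# of a frozen conv/GRU front-end: the front-end produces per-frame features, a
# CTC-style decode picks out the frames that carry symbols, and their
# embeddings feed a transformer trained with teacher forcing to emit text.
# PAD_IDX / BOS_IDX / EOS_IDX / UNK_IDX, create_mask and show_lr are assumed to
# come from the wildcard imports above (models.PhonemeTransformer, helpers).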
class TransformerTrainer(Trainer, TranslatorTrainer):
    def __init__(
        self, batch_size=opt.batch_size, word_tokenize=False,
        dataset_type: Datasets = opt.dataset, embeds_size=256,
        vocab_files=None, write_logs=True,
        input_char_map=CharMap.phonemes,
        output_char_map=CharMap.letters,
        name='embeds-transformer-v2',
        **kwargs
    ):
        super().__init__(**kwargs, name=name)

        self.batch_size = batch_size
        self.word_tokenize = word_tokenize
        self.input_char_map = input_char_map
        self.output_char_map = output_char_map
        self.dataset_type = dataset_type
        self.embeds_size = embeds_size

        self.text_tokenizer = functools.partial(
            GridDataset.tokenize_text, word_tokenize=word_tokenize
        )
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )

        if vocab_files is None:
            vocabs = self.load_vocabs(self.base_dir)
            self.phonemes_vocab, self.text_vocab = vocabs
        else:
            phonemes_vocab_path, text_vocab_path = vocab_files
            self.phonemes_vocab = torch.load(phonemes_vocab_path)
            self.text_vocab = torch.load(text_vocab_path)

        self.model = None
        self.optimizer = None
        self.best_test_loss = float('inf')
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

        """
        self.phonemes_encoder = self.sequential_transforms(
            GridDataset.tokenize_phonemes, self.phonemes_vocab,
            self.tensor_transform
        )
        """
        self.text_encoder = self.sequential_transforms(
            self.text_tokenizer, self.text_vocab,
            self.tensor_transform
        )

        if write_logs:
            self.init_tensorboard()

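    # create_model lazily builds the LipNetPlus backbone. Note the two sizes:
    # output_classes comes from the dataset's character mapping (presumably
    # the CTC inventory of the front-end), while output_vocab_size is the
    # transformer's target text vocabulary.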
    def create_model(self):
        if self.model is None:
            output_classes = len(self.train_dataset.get_char_mapping())

            self.model = LipNetPlus(
                output_classes=output_classes,
                pre_gru_repeats=self.pre_gru_repeats,
                embeds_size=self.embeds_size,
                output_vocab_size=len(self.text_vocab)
            )
            self.model = self.model.cuda()
        if self.net is None:
            self.net = nn.DataParallel(self.model).cuda()

    def load_datasets(self):
        if self.train_dataset is None:
            self.train_dataset = GridDataset(
                **self.dataset_kwargs, phase='train',
                file_list=opt.train_list,
                sample_all_props=True
            )
        if self.test_dataset is None:
            self.test_dataset = GridDataset(
                **self.dataset_kwargs, phase='test',
                file_list=opt.val_list,
                sample_all_props=True
            )

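    # train() runs teacher-forced training of the transformer head. Each step
    # evaluates the conv/GRU front-end under torch.no_grad(), collapses its
    # CTC-style predictions into per-symbol source embeddings via
    # make_transformer_embeds, and optimises the transformer with a
    # cross-entropy loss that ignores PAD positions.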
    def train(self):
        self.load_datasets()
        self.create_model()

        dataset = self.train_dataset
        loader = self.dataset2dataloader(
            dataset, num_workers=self.num_workers
        )
        """
        optimizer = optim.Adam(
            self.model.parameters(), lr=opt.base_lr,
            weight_decay=0., amsgrad=True
        )
        """
        optimizer = optim.RMSprop(
            self.model.parameters(), lr=opt.base_lr
        )

        print('num_train_data:{}'.format(len(dataset.data)))

        tic = time.time()

        self.best_test_loss = float('inf')
        log_scalar = functools.partial(self.log_scalar, label='train')

        for epoch in range(opt.max_epoch):
            print(f'RUNNING EPOCH {epoch}')
            train_wer = []

            pbar = tqdm(loader)
            for (i_iter, input_sample) in enumerate(pbar):
                PauseChecker.check()

                self.model.train()
                vid = input_sample.get('vid').cuda()

                batch_arr_sentences = input_sample['txt_anno']
                batch_arr_sentences = np.array(batch_arr_sentences)

                _, batch_size = batch_arr_sentences.shape
                batch_sentences = [
                    ''.join(batch_arr_sentences[:, k]).strip()
                    for k in range(batch_size)
                ]

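                # batch_sentences now holds the ground-truth text per sample
                # (txt_anno arrives as per-position character tuples, hence
                # the (text_len, batch) array). collate_tgt_fn encodes and
                # pads them into a (tgt_len, batch) tensor; tgt[:-1] is the
                # teacher-forced decoder input and tgt[1:] later serves as
                # the label.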
                tgt = self.collate_tgt_fn(batch_sentences)
                tgt = tgt.to(self.device)
                tgt_input = tgt[:-1, :]

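                # The conv/GRU front-end runs without gradients: y holds its
                # CTC-style frame predictions, and only the layers applied
                # after this block receive updates from the optimiser.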
                with torch.no_grad():
                    gru_output = self.model.forward_gru(vid)
                    y = self.model.predict_from_gru_out(gru_output)

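                # Collapse the frame-level outputs into one embedding per
                # decoded symbol, plus a parallel index array whose PAD_IDX
                # entries mark padding positions for the transformer.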
                src_embeds = self.model.make_src_embeds(gru_output)
                transformer_out = self.make_transformer_embeds(
                    dataset, src_embeds, y, batch_size=batch_size
                )

                transformer_src_embeds, src_idx_arr = transformer_out
                transformer_src_embeds = transformer_src_embeds.to(self.device)
                src_idx_arr = src_idx_arr.to(self.device)
                max_seq_len, batch_size = src_idx_arr.shape

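                # create_mask is assumed to follow the standard PyTorch
                # seq2seq recipe: a causal mask for the target plus padding
                # masks derived from PAD_IDX positions in source and target.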
                (
                    src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask
                ) = create_mask(
                    src_idx_arr, tgt_input, self.device
                )

                logits = self.model.seq_forward(
                    transformer_src_embeds, tgt_input, src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask, src_padding_mask
                )

                optimizer.zero_grad()

                tgt_out = tgt[1:, :]
                loss = self.loss_fn(
                    logits.reshape(-1, logits.shape[-1]),
                    tgt_out.reshape(-1)
                )

                tot_iter = i_iter + epoch * len(loader)

                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    probs = torch.softmax(logits, dim=-1)
                    token_indices = torch.argmax(probs, dim=-1)

                gap = ' ' if self.word_tokenize else ''

                pred_sentences = self.batch_indices_to_text(
                    token_indices, batch_size=batch_size, gap=gap
                )
                wer = np.mean(GridDataset.get_wer(
                    pred_sentences, batch_sentences,
                    char_map=self.output_char_map
                ))
                train_wer.append(wer)

                if tot_iter % opt.display == 0:
                    v = 1.0 * (time.time() - tic) / (tot_iter + 1)
                    eta = (len(loader) - i_iter) * v / 3600.0
                    wer = np.array(train_wer).mean()

                    log_scalar('loss', loss, tot_iter)
                    log_scalar('wer', wer, tot_iter)
                    self.log_pred_texts(
                        pred_sentences, batch_sentences, sub_samples=3
                    )

                    print('epoch={},tot_iter={},eta={},loss={},train_wer={}'
                          .format(
                              epoch, tot_iter, eta, loss,
                              np.array(train_wer).mean()
                          ))
                    print(161 * '-')

                if (tot_iter > -1) and (tot_iter % opt.test_step == 0):
                    self.run_test(tot_iter, optimizer)

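    # make_transformer_embeds converts frame-level GRU outputs into a padded
    # source sequence for the transformer. ctc_decode_indices is expected to
    # return, per sample, the frame positions of the collapsed (non-blank,
    # de-duplicated) CTC symbols; the embeddings at those frames are framed by
    # BOS/EOS embeddings and padded with the PAD embedding, while src_idx_mask
    # records which positions are real tokens (UNK_IDX) versus padding
    # (PAD_IDX) so create_mask can derive the source padding mask.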
    def make_transformer_embeds(
            self, dataset, src_embeds, y, batch_size
    ):
        batch_indices = dataset.ctc_decode_indices(y)
        filter_batch_embeds = []

        pad_embed = self.model.src_tok_emb(
            torch.IntTensor([PAD_IDX]).to(self.device)
        )
        begin_embed = self.model.src_tok_emb(
            torch.IntTensor([BOS_IDX]).to(self.device)
        )
        end_embed = self.model.src_tok_emb(
            torch.IntTensor([EOS_IDX]).to(self.device)
        )
        max_sentence_len = max([len(x) for x in batch_indices])

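        # Buffer layout: a (max_sentence_len + 2, batch_size, embed_dim)
        # tensor initialised to the PAD embedding, plus an int mask over the
        # first two dims initialised to PAD_IDX. For each sample, row 0
        # becomes BOS, rows 1..n the selected character embeddings, and
        # row n + 1 EOS.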
        # expand() returns a view whose elements all alias the same memory and
        # therefore cannot be written to in place; clone() materialises a
        # writable buffer while keeping the gradient path to the PAD embedding.
        transformer_src_embeds = pad_embed.expand(
            max_sentence_len + 2, batch_size, pad_embed.shape[1]
        ).clone()

        src_idx_mask = torch.full(
            transformer_src_embeds.shape[:2], PAD_IDX,
            dtype=torch.int
        )

        for k, sentence_indices in enumerate(batch_indices):
            filter_sentence_embeds = []
            for sentence_index in sentence_indices:
                filter_sentence_embeds.append(
                    src_embeds[sentence_index][k]
                )

            sentence_length = len(filter_sentence_embeds)
            filter_batch_embeds.append(filter_sentence_embeds)

            transformer_src_embeds[0][k] = begin_embed
            src_idx_mask[0][k] = UNK_IDX

            for i, char_embed in enumerate(filter_sentence_embeds):
                transformer_src_embeds[i + 1][k] = char_embed
                src_idx_mask[i + 1][k] = UNK_IDX

            transformer_src_embeds[sentence_length + 1][k] = end_embed
            src_idx_mask[sentence_length + 1][k] = UNK_IDX

        return transformer_src_embeds, src_idx_mask

    @staticmethod
    def log_pred_texts(
            pred_txt, truth_txt, pad=80, sub_samples=None
    ):
        line_length = 2 * pad + 1
        print(line_length * '-')
        print('{:<{pad}}|{:>{pad}}'.format(
            'predict', 'truth', pad=pad
        ))

        print(line_length * '-')
        zipped_samples = list(zip(pred_txt, truth_txt))
        if sub_samples is not None:
            zipped_samples = zipped_samples[:sub_samples]

        for (predict, truth) in zipped_samples:
            print('{:<{pad}}|{:>{pad}}'.format(
                predict, truth, pad=pad
            ))

        print(line_length * '-')

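    # test() mirrors the forward pass of train() on the validation set, but
    # without optimisation: it accumulates the teacher-forced loss together
    # with WER/CER of greedy (argmax) predictions and returns their means.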
    def test(self):
        dataset = self.test_dataset

        with torch.no_grad():
            print('num_test_data:{}'.format(len(dataset.data)))
            self.model.eval()
            loader = self.dataset2dataloader(
                dataset, shuffle=False, num_workers=self.num_workers
            )

            loss_list = []
            wer = []
            cer = []
            tic = time.time()
            print('RUNNING VALIDATION')

            pbar = tqdm(loader)
            for (i_iter, input_sample) in enumerate(pbar):
                PauseChecker.check()

                vid = input_sample.get('vid').cuda()
                batch_arr_sentences = input_sample['txt_anno']
                batch_arr_sentences = np.array(batch_arr_sentences)

                _, batch_size = batch_arr_sentences.shape
                batch_sentences = [
                    ''.join(batch_arr_sentences[:, k]).strip()
                    for k in range(batch_size)
                ]

                tgt = self.collate_tgt_fn(batch_sentences)
                tgt = tgt.to(self.device)
                tgt_input = tgt[:-1, :]

                gru_output = self.model.forward_gru(vid)
                y = self.model.predict_from_gru_out(gru_output)

                src_embeds = self.model.make_src_embeds(gru_output)
                transformer_out = self.make_transformer_embeds(
                    dataset, src_embeds, y, batch_size=batch_size
                )

                transformer_src_embeds, src_idx_arr = transformer_out
                transformer_src_embeds = transformer_src_embeds.to(self.device)
                src_idx_arr = src_idx_arr.to(self.device)
                max_seq_len, batch_size = src_idx_arr.shape

                (
                    src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask
                ) = create_mask(
                    src_idx_arr, tgt_input, self.device
                )

                logits = self.model.seq_forward(
                    transformer_src_embeds, tgt_input, src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask, src_padding_mask
                )

                probs = torch.softmax(logits, dim=-1)
                token_indices = torch.argmax(probs, dim=-1)

                gap = ' ' if self.word_tokenize else ''

                pred_sentences = self.batch_indices_to_text(
                    token_indices, batch_size=batch_size, gap=gap
                )

                tgt_out = tgt[1:, :]
                loss = self.loss_fn(
                    logits.reshape(-1, logits.shape[-1]),
                    tgt_out.reshape(-1)
                )

                loss_item = loss.detach().cpu().numpy()
                loss_list.append(loss_item)

                wer.extend(GridDataset.get_wer(
                    pred_sentences, batch_sentences,
                    char_map=self.output_char_map
                ))
                cer.extend(GridDataset.get_cer(
                    pred_sentences, batch_sentences,
                    char_map=self.output_char_map
                ))

                if i_iter % opt.display == 0:
                    v = 1.0 * (time.time() - tic) / (i_iter + 1)
                    eta = v * (len(loader) - i_iter) / 3600.0

                    self.log_pred_texts(
                        pred_sentences, batch_sentences, sub_samples=10
                    )

                    print('test_iter={},eta={},wer={},cer={}'.format(
                        i_iter, eta, np.array(wer).mean(),
                        np.array(cer).mean()
                    ))
                    print(161 * '-')

            return (
                np.array(loss_list).mean(), np.array(wer).mean(),
                np.array(cer).mean()
            )

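    # run_test logs validation metrics and checkpoints the model whenever the
    # validation loss improves; the filename encodes iteration, loss, WER and
    # CER (with dots stripped). When opt.is_optimize is false, the script
    # stops after the first evaluation.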
    def run_test(self, tot_iter, optimizer):
        log_scalar = functools.partial(self.log_scalar, label='test')

        (loss, wer, cer) = self.test()
        print('i_iter={},lr={},loss={},wer={},cer={}'.format(
            tot_iter, show_lr(optimizer), loss, wer, cer
        ))
        log_scalar('loss', loss, tot_iter)
        log_scalar('wer', wer, tot_iter)
        log_scalar('cer', cer, tot_iter)

        if loss < self.best_test_loss:
            print(f'NEW BEST LOSS: {loss}')
            self.best_test_loss = loss

            savename = 'I{}-L{:.4f}-W{:.4f}-C{:.4f}'.format(
                tot_iter, loss, wer, cer
            )

            savename = savename.replace('.', '') + '.pt'
            savepath = os.path.join(self.weights_dir, savename)

            (save_dir, name) = os.path.split(savepath)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            torch.save(self.model.state_dict(), savepath)
            print(f'best model saved at {savepath}')

        if not opt.is_optimize:
            sys.exit()


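# Entry point: trains on the GRID vocabularies by default; the commented-out
# paths below switch to the lsr2 files. If the options module defines
# 'weights', training warm-starts from that checkpoint.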
if __name__ == '__main__':
    vocab_filepaths = (
        'data/grid_phoneme_vocab.pth',
        'data/grid_text_char_vocab.pth'
    )
    """
    vocab_filepaths = (
        'data/lsr2_phoneme_vocab.pth',
        'data/lsr2_text_char_vocab.pth'
    )
    """

    trainer = TransformerTrainer(
        word_tokenize=False, vocab_files=vocab_filepaths,
        input_char_map=opt.char_map,
        output_char_map=opt.text_char_map
    )

    if hasattr(opt, 'weights'):
        trainer.load_weights(opt.weights)

    trainer.train()