"""Run the lip-reading pipeline on a single video file.

A phoneme-level ``Trainer`` predicts a phoneme sentence from the video, and a
``TranslatorTrainer`` then translates that phoneme sentence into text. Both
results are printed to stdout.
"""
import os.path
import sys

base_dir = '..'
# The project modules imported below are resolved via this sys.path entry,
# so it has to be appended before those imports run.
sys.path.append(base_dir)

from Trainer import Trainer
from TranslatorTrainer import TranslatorTrainer
from dataset import GridDataset, CharMap

# Forwarded to TranslatorTrainer(word_tokenize=...) and used to select the
# matching translator checkpoint below.
WORD_TOKENIZE = False
# NOTE(review): not referenced anywhere in this script — confirm whether it
# is still needed.
PHONEME_FILTER_PREV = False
# Passed as beam_size to phoneme_translator.translate().
BEAM_SIZE = 0

lipnet_weights = 'weights/phoneme-231201-2218/I119000-L00001-W00000-C00000.pt'

if WORD_TOKENIZE:
    translator_weights = 'weights/translate-231204-1652/I160-L00047-W00000.pt'
else:
    translator_weights = 'weights/translate-231204-2227/I860-L00000-W00000.pt'

lipnet_predictor = Trainer(
    write_logs=False, base_dir=base_dir,
    num_workers=0, char_map=CharMap.phonemes
)
# NOTE(review): this path is passed without base_dir, whereas the translator
# weights below are joined with base_dir — presumably Trainer.load_weights
# resolves the path itself; confirm.
lipnet_predictor.load_weights(lipnet_weights)
lipnet_predictor.load_datasets()
# The test dataset object is only used here for its process_vid() method.
dataset = lipnet_predictor.test_dataset

phoneme_translator = TranslatorTrainer(
    write_logs=False, base_dir=base_dir, word_tokenize=WORD_TOKENIZE
)
phoneme_translator.load_weights(os.path.join(
    base_dir, translator_weights
))

# Disabled example: translate hand-written phoneme sentences directly,
# bypassing the video model.
#
# new_phonemes = GridDataset.text_to_phonemes("Do you like fries")
# print("PRE_REV_TRANSLATE", [new_phonemes])
# pred_text = phoneme_translator.translate(new_phonemes)
# print("AFT_REV_TRANSLATE", pred_text)
#
# phoneme_sentence = 'B-IH1-N B-L-UW1 AE1-T EH1-F TH-R-IY1 S-UW1-N'
# pred_text = phoneme_translator.translate(phoneme_sentence)
# print(f'PRED_TEXT: [{pred_text}]')

# NOTE(review): nothing in this script updates these counters, so the summary
# printed at the end always reports 0/1000 correct and an average WER of 0.0,
# even though only one video is processed. They look like leftovers from a
# batch evaluation loop — confirm before relying on the summary lines.
total_samples = 1000
total_wer = 0
num_correct = 0
num_phonemes_correct = 0

video_path = '/media/milselarch/47FC4BC577667AAD/LRS2/lrs2_v1/mvlrs_v1/main/5535423430009926848/00002.mp4'

vid = dataset.process_vid(video_path)
# predict_video() returns an indexable result; only its first element is
# used (presumably one prediction per input video — confirm).
pred_phonemes_sentence = lipnet_predictor.predict_video(vid)[0]
print('PRED PHONEMES', pred_phonemes_sentence)

pred_text = phoneme_translator.translate(
    pred_phonemes_sentence, beam_size=BEAM_SIZE
)

avg_wer = total_wer / total_samples
print('PRED TEXT =', pred_text)
print(f'{num_correct}/{total_samples} samples correct')
print(f'{num_phonemes_correct}/{total_samples} phoneme samples correct')
print(f'average WER: {avg_wer}')