File size: 2,328 Bytes
df07554 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import os.path
import sys
base_dir = '..'
sys.path.append(base_dir)
from Trainer import Trainer
from TranslatorTrainer import TranslatorTrainer
from dataset import GridDataset, CharMap
# --- Run configuration -------------------------------------------------------
# Forwarded to TranslatorTrainer as `word_tokenize`; also selects which
# translator checkpoint is loaded below.
WORD_TOKENIZE = False
# whether to filter out consecutive phonemes
# (not referenced anywhere else in this script)
PHONEME_FILTER_PREV = False
# Forwarded to phoneme_translator.translate() as `beam_size`.
BEAM_SIZE = 0

# Checkpoint for the phoneme-level lip-reading model.
# lipnet_weights = 'weights/phoneme-231201-0052/I198000-L00048-W00018-C00005.pt'
lipnet_weights = 'weights/phoneme-231201-2218/I119000-L00001-W00000-C00000.pt'

# Checkpoint for the phoneme -> text translator; the choice depends on
# WORD_TOKENIZE. The commented-out entries are alternative checkpoints.
if WORD_TOKENIZE:
    translator_weights = 'weights/translate-231204-1652/I160-L00047-W00000.pt'
else:
    # translator_weights = 'weights/translate-231202-1509/I1560-L00000-W00000.pt'
    # translator_weights = 'weights/translate-231204-1709/I220-L00042-W00000.pt'
    translator_weights = 'weights/translate-231204-2227/I860-L00000-W00000.pt'
# --- Model setup -------------------------------------------------------------
# Lip-reading model configured with the phoneme character map; log writing is
# switched off and no data-loader workers are requested.
lipnet_predictor = Trainer(
    write_logs=False,
    base_dir=base_dir,
    num_workers=0,
    char_map=CharMap.phonemes,
)
# NOTE(review): lipnet_weights is passed as-is, whereas the translator weights
# below are joined with base_dir first — confirm Trainer.load_weights resolves
# the path relative to base_dir itself.
lipnet_predictor.load_weights(lipnet_weights)
lipnet_predictor.load_datasets()
# The test dataset object is what this script uses to preprocess a video file.
dataset = lipnet_predictor.test_dataset

# Phoneme -> text translator, also with log writing switched off.
phoneme_translator = TranslatorTrainer(
    write_logs=False,
    base_dir=base_dir,
    word_tokenize=WORD_TOKENIZE,
)
translator_weights_path = os.path.join(base_dir, translator_weights)
phoneme_translator.load_weights(translator_weights_path)

# The bare string literal below is a no-op expression statement: it holds
# disabled manual checks (text -> phonemes -> text round trip, and translating
# a hand-written phoneme sentence) and is never executed.
"""
new_phonemes = GridDataset.text_to_phonemes("Do you like fries")
print("PRE_REV_TRANSLATE", [new_phonemes])
pred_text = phoneme_translator.translate(new_phonemes)
print("AFT_REV_TRANSLATE", pred_text)
phoneme_sentence = 'B-IH1-N B-L-UW1 AE1-T EH1-F TH-R-IY1 S-UW1-N'
pred_text = phoneme_translator.translate(phoneme_sentence)
print(f'PRED_TEXT: [{pred_text}]')
"""
# --- Single-video inference --------------------------------------------------
# Pipeline: video file -> preprocessed frames -> phoneme sentence -> text.
# video_path = '/home/milselarch/Videos/2023-12-07-11-02-33.mkv'
video_path = '/media/milselarch/47FC4BC577667AAD/LRS2/lrs2_v1/mvlrs_v1/main/5535423430009926848/00002.mp4'
vid = dataset.process_vid(video_path)

# Only the first element of predict_video's result is used
# (presumably one entry per input video — TODO confirm).
pred_phonemes_sentence = lipnet_predictor.predict_video(vid)[0]
print('PRED PHONEMES', pred_phonemes_sentence)

pred_text = phoneme_translator.translate(
    pred_phonemes_sentence, beam_size=BEAM_SIZE
)
print('PRED TEXT =', pred_text)
# No accuracy / WER summary is printed: this script scores nothing against a
# reference, so any such counters would only ever report their initial zeros.