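# Evaluate the two-stage lip reading pipeline: a phoneme-level LipNet predictor
# followed by a phoneme-to-text translator, scoring exact-match accuracy and
# word error rate (WER) over randomly drawn test samples.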
import os.path
import sys

base_dir = '..'
sys.path.append(base_dir)

from Trainer import Trainer
from TranslatorTrainer import TranslatorTrainer
from dataset import GridDataset, CharMap

# whether the transformer translator uses word-level or character-level tokenization
WORD_TOKENIZE = False
# whether to filter out repeated consecutive phonemes when decoding target sequences
PHONEME_FILTER_PREV = False
# beam width used by the translator (0 presumably falls back to greedy decoding)
BEAM_SIZE = 0

# lipnet_weights = 'weights/phoneme-231201-0052/I198000-L00048-W00018-C00005.pt'
# lipnet_weights = 'weights/phoneme-231201-2218/I119000-L00001-W00000-C00000.pt'
lipnet_weights = 'saved-weights/phonemes-231207-2130/I283000-L00683-W01012-C00765.pt'

if WORD_TOKENIZE:
    translator_weights = 'saved-weights/translate-231204-1652/I160-L00047-W00000.pt'
else:
    translator_weights = 'saved-weights/translate-231204-2227/I860-L00000-W00000.pt'
    # translator_weights = 'weights/translate-231202-1509/I1560-L00000-W00000.pt'
    # translator_weights = 'weights/translate-231204-1709/I220-L00042-W00000.pt'

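# Phoneme-level LipNet predictor; its test split provides the evaluation samples.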
lipnet_predictor = Trainer(
    write_logs=False, base_dir=base_dir,
    num_workers=0, char_map=CharMap.phonemes
)
lipnet_predictor.load_weights(lipnet_weights)
lipnet_predictor.load_datasets()
dataset = lipnet_predictor.test_dataset

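# Phoneme-to-text translator; tokenization level follows WORD_TOKENIZE.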
phoneme_translator = TranslatorTrainer(
    write_logs=False, base_dir=base_dir, word_tokenize=WORD_TOKENIZE
)
phoneme_translator.load_weights(os.path.join(
    base_dir, translator_weights
))

"""
new_phonemes = GridDataset.text_to_phonemes("Do you like fries")
print("PRE_REV_TRANSLATE", [new_phonemes])
pred_text = phoneme_translator.translate(new_phonemes)
print("AFT_REV_TRANSLATE", pred_text)

phoneme_sentence = 'B-IH1-N B-L-UW1 AE1-T EH1-F TH-R-IY1 S-UW1-N'
pred_text = phoneme_translator.translate(phoneme_sentence)
print(f'PRED_TEXT: [{pred_text}]')
"""

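# Running totals for exact-match and WER statistics over the evaluation run.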
total_samples = 1000
total_wer = 0
num_correct = 0
num_phonemes_correct = 0

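# For each random test sample: decode the target transcripts, predict phonemes
# with the LipNet model, translate them to text, and score against the targets.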
for k in range(total_samples):
    # char_map=all passes the Python builtin, presumably as a catch-all so the
    # sample carries both phoneme and letter transcripts (both are used below)
    sample = dataset.load_random_sample(char_map=all)
    tgt_phonemes = sample['phonemes']
    tgt_text = sample['txt']

    target_phonemes_sentence = dataset.ctc_arr2txt(
        tgt_phonemes, start=1, char_map=CharMap.phonemes,
        filter_previous=PHONEME_FILTER_PREV
    )
    target_sentence = dataset.ctc_arr2txt(
        tgt_text, start=1, char_map=CharMap.letters,
        filter_previous=False
    )

    pred_phonemes_sentence = lipnet_predictor.predict_sample(sample)[0]
    pred_text = phoneme_translator.translate(
        pred_phonemes_sentence, beam_size=BEAM_SIZE
    )
    match_phonemes = pred_phonemes_sentence == target_phonemes_sentence
    wer = dataset.get_wer(
        [pred_text], [target_sentence], char_map=CharMap.letters
    )[0]

    total_wer += wer

    correct = pred_text == target_sentence
    if correct:
        num_correct += 1
    if match_phonemes:
        num_phonemes_correct += 1

    print(
        f'PRED-PHONEMES [{k}]',
        [pred_phonemes_sentence, target_phonemes_sentence],
        [pred_text, target_sentence], correct, wer
    )

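# Report exact-match accuracy (text and phoneme level) and the average WER.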
avg_wer = total_wer / total_samples
print(f'{num_correct}/{total_samples} samples correct')
print(f'{num_phonemes_correct}/{total_samples} phoneme samples correct')
print(f'average WER: {avg_wer}')