|
|
|
import numpy as np |
|
import glob |
|
import time |
|
import cv2 |
|
import yaml |
|
import os |
|
import torch |
|
import re |
|
import string |
|
import copy |
|
import json |
|
import random |
|
import enum |
|
import editdistance |
|
import pronouncing |
|
|
|
from torch.utils.data import Dataset |
|
|
|
import Extractor |
|
import options |
|
from cvtransforms import * |
|
from typing import List, Iterable |
|
from helpers import * |
|
|
|
|
|
class CharMap(str, enum.Enum): |
|
letters = 'letters' |
|
lsr2_text = 'lsr2_text' |
|
phonemes = 'phonemes' |
|
cmu_phonemes = 'cmu_phonemes' |
|
visemes = 'visemes' |
|
|
|
|
|
class Datasets(str, enum.Enum): |
|
GRID = 'GRID' |
|
LRS2 = 'LRS2' |
|
|
|
|
|
class GridDataset(Dataset): |
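    """
    Frame-sequence dataset for the GRID / LRS2 lip-reading corpora.

    Each item pairs pre-extracted video frames with letter, phoneme
    and/or CMU-phoneme annotations, padded to fixed lengths so that
    samples can be batched (e.g. for CTC-style training).
    """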
|
letters = [ |
|
' ', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', |
|
'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', |
|
'T', 'U', 'V', 'W', 'X', 'Y', 'Z' |
|
] |
|
lrs2_chars = [ |
|
' ', "'", '0', '1', '2', '3', '4', '5', '6', '7', '8', |
|
'9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', |
|
'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', |
|
'V', 'W', 'X', 'Y', 'Z' |
|
] |
|
|
|
phonemes = [ |
|
' ', 'AE1', 'AO1', 'D', 'JH', 'Y', 'P', 'AH0', 'OW1', 'G', |
|
'AY1', 'TH', 'IY1', 'CH', 'T', 'AW1', 'F', 'AH1', 'Z', |
|
'R', 'EH1', 'UW1', 'M', 'B', 'W', 'V', 'DH', 'K', 'IH0', |
|
'AA1', 'IH1', 'S', 'EY1', 'N', 'OW0', 'L' |
|
] |
|
|
|
cmu_phonemes = [ |
|
' ', '#', 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', |
|
'AE2', 'AH', 'AH0', 'AH1', 'AH2', 'AO', 'AO0', 'AO1', 'AO2', |
|
'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', |
|
'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', |
|
'ER0', 'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', |
|
'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', |
|
'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', |
|
'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', |
|
'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', |
|
'W', 'Y', 'Z', 'ZH' |
|
] |
|
|
|
phoneme_chars = map_phonemes(phonemes) |
|
cmu_phoneme_chars = map_phonemes(cmu_phonemes) |
|
|
|
def __init__( |
|
self, video_path, alignments_dir, |
|
phonemes_dir, file_list, vid_pad, |
|
image_dir, txt_pad, phase, shared_dict=None, |
|
char_map=CharMap.letters, base_dir='', |
|
frame_doubling=False, sample_all_props=False |
|
): |
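        """
        :param video_path: root directory containing speaker video folders
        :param alignments_dir: directory of word-alignment / transcript files
        :param phonemes_dir: directory of phoneme annotation files
        :param file_list: path to a text file listing one sample per line
        :param vid_pad: fixed number of frames each video is padded to
        :param image_dir: directory of pre-extracted frame images
        :param txt_pad: fixed number of tokens each annotation is padded to
        :param phase: 'train' enables horizontal-flip augmentation
        :param shared_dict: optional cross-process cache for serialized videos
        :param char_map: default annotation type to load
        :param base_dir: prefix joined onto the directory arguments above
        :param frame_doubling: if True, repeat every frame twice
        :param sample_all_props: if True, load every annotation type per sample
        """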
|
self.base_dir = base_dir |
|
self.sample_all_props = sample_all_props |
|
|
|
self.image_dir = os.path.join(base_dir, image_dir) |
|
self.alignments_dir = os.path.join(base_dir, alignments_dir) |
|
self.phonemes_dir = os.path.join(base_dir, phonemes_dir) |
|
self.frame_doubling = frame_doubling |
|
|
|
        if isinstance(file_list, str):
            file_list = os.path.join(base_dir, file_list)

        with open(file_list, 'r') as f:
            file_list = f.readlines()
|
|
|
self.shared_dict = shared_dict |
|
self.char_map = char_map |
|
|
|
self.vid_pad = vid_pad |
|
self.txt_pad = txt_pad |
|
self.phase = phase |
|
|
|
self.videos = [ |
|
os.path.join(video_path, line.strip()) |
|
for line in file_list |
|
] |
|
|
|
self.data = [] |
|
for vid in self.videos: |
|
items = vid.split(os.path.sep) |
|
if len(items) < 2: |
|
print('BAD VID ITEM', items) |
|
raise ValueError |
|
|
|
speaker_name, filename = items[-2], items[-1] |
|
self.data.append((vid, speaker_name, filename)) |
|
|
|
def _fetch_anno_path(self, spk, basename): |
|
return self.fetch_anno_path( |
|
spk=spk, basename=basename, char_map=self.char_map |
|
) |
|
|
|
@classmethod |
|
def text_to_phonemes( |
|
cls, text, as_str=True, char_map=CharMap.phonemes |
|
): |
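        """
        Converts a text sentence into its phoneme sequence using the first
        pronunciation returned by the `pronouncing` CMU dictionary, with a
        ' ' token between words; raises if a word has no dictionary entry.
        """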
|
sentence_phonemes = [] |
|
|
|
words = text.upper().strip().split(' ') |
|
        for word in words:
            # pronouncing returns an empty list for out-of-vocabulary words
            pronunciations = pronouncing.phones_for_word(word)
            if not pronunciations:
                raise KeyError(f'NO PRONUNCIATION FOUND FOR: {word}')

            word_phonemes = pronunciations[0].split(' ')
            sentence_phonemes.extend(word_phonemes)
            sentence_phonemes.append(' ')
|
|
|
if sentence_phonemes[-1] == ' ': |
|
sentence_phonemes = sentence_phonemes[:-1] |
|
|
|
if as_str: |
|
return cls.stringify(sentence_phonemes, char_map=char_map) |
|
else: |
|
return sentence_phonemes |
|
|
|
def fetch_anno_path(self, spk, basename, char_map): |
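        # resolve the annotation file path for a speaker / utterance:
        # letters -> alignments_dir/<spk>/<name>.align
        # lsr2_text -> alignments_dir/<spk>/<name>.txt
        # phonemes -> phonemes_dir/<spk>/<name>.align
        # cmu_phonemes -> phonemes_dir/<spk>/<name>.txt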
|
if char_map == CharMap.letters: |
|
align_path_name = os.path.join( |
|
self.alignments_dir, spk, basename + '.align' |
|
) |
|
return align_path_name |
|
elif char_map == CharMap.lsr2_text: |
|
align_path_name = os.path.join( |
|
self.alignments_dir, spk, basename + '.txt' |
|
) |
|
return align_path_name |
|
elif char_map == CharMap.phonemes: |
|
phonemes_path_name = os.path.join( |
|
self.phonemes_dir, spk, basename + '.align' |
|
) |
|
return phonemes_path_name |
|
elif char_map == CharMap.cmu_phonemes: |
|
phonemes_path_name = os.path.join( |
|
self.phonemes_dir, spk, basename + '.txt' |
|
) |
|
return phonemes_path_name |
|
else: |
|
raise NotImplementedError |
|
|
|
def fetch_anno_text(self, spk, basename, char_map: CharMap): |
|
return self.load_anno_text(self.fetch_anno_path( |
|
spk, basename, char_map=char_map |
|
), char_map=char_map) |
|
|
|
def __getitem__(self, idx): |
|
(vid, spk, name) = self.data[idx] |
|
return self.load_sample( |
|
video_name=vid, speaker_name=spk, |
|
filename=name |
|
) |
|
|
|
def load_random_sample(self, char_map=None): |
|
(vid, spk, name) = random.choice(self.data) |
|
return self.load_sample( |
|
video_name=vid, speaker_name=spk, |
|
filename=name, char_map=char_map |
|
) |
|
|
|
def load_sample( |
|
self, video_name, speaker_name, filename, |
|
char_map=None |
|
): |
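        """
        Loads one padded sample.  Returns the video tensor with shape
        (channels, frames, height, width) plus its unpadded frame count,
        and, depending on `char_map`, the padded letter / phoneme /
        CMU-phoneme index arrays together with their unpadded lengths.
        """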
|
if char_map is None: |
|
char_map = self.char_map |
|
if self.sample_all_props: |
|
char_map = all |
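            # the builtin `all` is used as a sentinel meaning
            # "load every available annotation type for this sample"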
|
|
|
vid = self.load_vid(video_name) |
|
if self.frame_doubling: |
|
vid = np.repeat(vid, repeats=2, axis=0) |
|
|
|
basename, _ = os.path.splitext(filename) |
|
|
|
txt_results, phoneme_results = {}, {} |
|
cmu_phoneme_results = {} |
|
|
|
if (char_map is all) or (char_map == CharMap.letters): |
|
txt_anno, txt_anno_arr = self.fetch_anno_text( |
|
speaker_name, basename, char_map=CharMap.letters |
|
) |
|
|
|
txt_anno_arr_len = txt_anno_arr.shape[0] |
|
txt_anno_arr = self._padding(txt_anno_arr, self.txt_pad) |
|
assert not np.isnan(txt_anno_arr).any() |
|
|
|
txt_anno += [' '] * (options.txt_padding - len(txt_anno)) |
|
txt_results = kwargify( |
|
txt=torch.LongTensor(txt_anno_arr), |
|
txt_len=txt_anno_arr_len, txt_anno=txt_anno |
|
) |
|
|
|
if (char_map is all) or (char_map == CharMap.phonemes): |
|
phoneme_anno, phoneme_anno_arr = self.fetch_anno_text( |
|
speaker_name, basename, char_map=CharMap.phonemes |
|
) |
|
|
|
phoneme_anno_arr_len = phoneme_anno_arr.shape[0] |
|
phoneme_anno_arr = self._padding( |
|
phoneme_anno_arr, self.txt_pad |
|
) |
|
            assert not np.isnan(phoneme_anno_arr).any()
|
|
|
phoneme_results = kwargify( |
|
phonemes=torch.LongTensor(phoneme_anno_arr), |
|
phonemes_len=phoneme_anno_arr_len, |
|
) |
|
|
|
        if (char_map is all) or (char_map == CharMap.cmu_phonemes):
|
phoneme_anno, phoneme_anno_arr = self.fetch_anno_text( |
|
speaker_name, basename, char_map=CharMap.cmu_phonemes |
|
) |
|
|
|
phoneme_anno_arr_len = phoneme_anno_arr.shape[0] |
|
phoneme_anno_arr = self._padding( |
|
phoneme_anno_arr, self.txt_pad |
|
) |
|
            assert not np.isnan(phoneme_anno_arr).any()
|
|
|
cmu_phoneme_results = kwargify( |
|
cmu_phonemes=torch.LongTensor(phoneme_anno_arr), |
|
cmu_phonemes_len=phoneme_anno_arr_len, |
|
) |
|
|
|
if self.phase == 'train': |
|
vid = HorizontalFlip(vid) |
|
|
|
vid = ColorNormalize(vid) |
|
vid_len = vid.shape[0] |
|
vid = self._padding(vid, self.vid_pad) |
|
|
|
""" |
|
if vid_len <= anno_len * 2: |
|
raise ValueError(f'CTC INVALID: {self.data[idx]}') |
|
""" |
|
|
|
assert not np.isnan(vid).any() |
|
|
|
return kwargify( |
|
vid=torch.FloatTensor(vid.transpose(3, 0, 1, 2)), |
|
vid_len=vid_len, **txt_results, **phoneme_results, |
|
**cmu_phoneme_results |
|
) |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
@staticmethod |
|
def serialize(data: np.ndarray): |
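        # cast raw 0-255 pixel frames to a uint8 tensor so cached videos
        # stay compact when placed in shared memory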
|
return torch.from_numpy(data.astype(np.uint8)) |
|
|
|
@staticmethod |
|
def deserialize(data: torch.Tensor): |
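        # recover the float16 frame array from a cached uint8 tensor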
|
return data.numpy().astype(np.float16) |
|
|
|
@staticmethod |
|
def process_vid(video_path: str, to_tensor=True): |
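        """
        Extracts frames directly from a video file via `Extractor`,
        resizes them to 128x64, colour-normalizes them, and optionally
        returns a (channels, frames, height, width) float tensor.
        """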
|
frames = Extractor.extract_frames( |
|
video_path, recycle_landmarks=True, use_gpu=True |
|
) |
|
|
|
        # drop frames where extraction failed, then resize each to 128x64
        array = [f for f in frames if f is not None]
        array = [
            cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4)
            for im in array
        ]
|
|
|
array = np.stack(array, axis=0).astype(np.float16) |
|
vid = ColorNormalize(array) |
|
|
|
if to_tensor: |
|
vid = torch.FloatTensor(vid.transpose(3, 0, 1, 2)) |
|
|
|
return vid |
|
|
|
def load_vid(self, video_path: str) -> np.ndarray: |
|
return self._load_vid(video_path, cache=False) |
|
|
|
def _load_vid(self, video_path: str, cache=True) -> np.ndarray: |
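        """
        Loads the pre-extracted frame images for a video from
        image_dir/<speaker>/<basename>/*.jpg (sorted by frame number),
        optionally caching the stacked array in the shared dict.
        """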
|
if cache and self.shared_dict is not None: |
|
if video_path in self.shared_dict: |
|
return self.deserialize( |
|
self.shared_dict[video_path] |
|
) |
|
|
|
|
|
base_filename = os.path.basename(video_path) |
|
basename, _ = os.path.splitext(base_filename) |
|
speaker_dir = os.path.basename(os.path.dirname(video_path)) |
|
image_dir = f'{self.image_dir}/{speaker_dir}/{basename}' |
|
|
|
files = os.listdir(image_dir) |
|
files = list(filter(lambda file: file.find('.jpg') != -1, files)) |
|
files = sorted(files, key=lambda file: int(os.path.splitext(file)[0])) |
|
array = [cv2.imread(os.path.join(image_dir, file)) for file in files] |
|
array = list(filter(lambda im: im is not None, array)) |
|
array = [ |
|
cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4) |
|
for im in array |
|
] |
|
|
|
try: |
|
array = np.stack(array, axis=0).astype(np.float16) |
|
except ValueError as e: |
|
print(f'BAD VIDEO PATH: {video_path}') |
|
raise e |
|
|
|
if cache and self.shared_dict is not None: |
|
|
|
serialized_data = self.serialize(array) |
|
serialized_data.share_memory_() |
|
self.shared_dict[video_path] = serialized_data |
|
|
|
|
|
return array |
|
|
|
@classmethod |
|
def load_anno(cls, name, char_map): |
|
return cls.load_anno_text(name, char_map)[1] |
|
|
|
@classmethod |
|
def load_anno_text(cls, name, char_map): |
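        # returns both the token sequence and its integer index array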
|
|
|
txt = cls.load_sentence(name, char_map=char_map) |
|
indices = cls.txt2arr(txt, 1, char_map=char_map) |
|
|
|
return txt, indices |
|
|
|
def _load_anno(self, name): |
|
return self.load_anno(name, self.char_map) |
|
|
|
@classmethod |
|
def load_sentence(cls, name, char_map=CharMap.letters) -> List[str]: |
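        """
        Parses an annotation file into a token list: GRID '.align' files
        yield uppercase characters (with 'SIL' / 'SP' entries dropped),
        LRS2 transcripts yield characters from the first line with its
        leading 'Text:' label stripped, and phoneme files yield one
        phoneme token per entry with ' ' tokens between words.
        """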
|
with open(name, 'r') as f: |
|
if char_map == CharMap.letters: |
|
lines = [line.strip().split(' ') for line in f.readlines()] |
|
txt = [line[2] for line in lines] |
|
                txt = list(filter(
                    lambda s: s.upper() not in ('SIL', 'SP'), txt
                ))
|
|
|
all_chars = list(' '.join(txt)) |
|
all_chars = [char.upper() for char in all_chars] |
|
return all_chars |
|
|
|
elif char_map == CharMap.lsr2_text: |
|
text_line = f.readlines()[0] |
|
text_line = text_line[5:].strip() |
|
all_chars = [char.upper() for char in text_line] |
|
return all_chars |
|
|
|
elif char_map in (CharMap.phonemes, CharMap.cmu_phonemes): |
|
all_chars = [] |
|
|
|
for line in f.readlines(): |
|
word_phonemes = line.strip().split(' ') |
|
all_chars.extend(word_phonemes) |
|
all_chars.append(' ') |
|
|
|
if all_chars[-1] == ' ': |
|
all_chars = all_chars[:-1] |
|
|
|
return all_chars |
|
else: |
|
raise ValueError(f'BAD CHAR MAP {char_map}') |
|
|
|
@classmethod |
|
def load_str_sentence(cls, name, char_map=CharMap.letters) -> str: |
|
chars_seq = cls.load_sentence(name=name, char_map=char_map) |
|
return cls.stringify(chars_seq, char_map=char_map) |
|
|
|
@staticmethod |
|
def tokenize_text(text: str, word_tokenize=False) -> List[str]: |
|
""" |
|
:param text: |
|
:param word_tokenize: |
|
whether to tokenize into words or individual characters |
|
:return: |
|
""" |
|
if word_tokenize: |
|
return text.split(' ') |
|
else: |
|
return list(text) |
|
|
|
@staticmethod |
|
def tokenize_phonemes(text: str, word_tokenize=False) -> List[str]: |
|
""" |
|
:param text: |
|
:param word_tokenize: |
|
whether to tokenize into words or individual phonemes |
|
example: |
|
text = 'S-EH1-T G-R-IY1-N IH0-N EH1-L S-IH1-K-S AH0-G-EH1-N' |
|
word-level tokens: |
|
['S-EH1-T', 'G-R-IY1-N', 'IH0-N', 'EH1-L', 'S-IH1-K-S', 'AH0-G-EH1-N'] |
|
phoneme-level tokens: |
|
['S', 'EH1', 'T', ' ', 'G', 'R', 'IY1', 'N', ' ', 'IH0', |
|
'N', ' ', 'EH1', 'L', ' ', 'S', 'IH1', 'K', 'S', ' ', |
|
'AH0', 'G', 'EH1', 'N'] |
|
:return: |
|
""" |
|
if word_tokenize: |
|
return text.split(' ') |
|
else: |
|
words = text.split(' ') |
|
phonemes = [] |
|
|
|
for word in words: |
|
assert not word.startswith('-') |
|
assert not word.endswith('-') |
|
phonemes.extend(word.split('-')) |
|
phonemes.append(' ') |
|
|
|
if phonemes[-1] == ' ': |
|
phonemes = phonemes[:-1] |
|
|
|
return phonemes |
|
|
|
@staticmethod |
|
def _padding(array, length): |
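        # zero-pad along the first (time / token) axis up to `length`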
|
array = [array[_] for _ in range(array.shape[0])] |
|
size = array[0].shape |
|
|
|
for i in range(length - len(array)): |
|
array.append(np.zeros(size)) |
|
|
|
return np.stack(array, axis=0) |
|
|
|
@classmethod |
|
def txt2arr(cls, txt, start, char_map=CharMap.letters): |
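        # map tokens to integer indices offset by `start`
        # (callers pass start=1, leaving index 0 for the CTC blank)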
|
arr = [] |
|
|
|
if char_map == CharMap.letters: |
|
for char in list(txt): |
|
arr.append(cls.letters.index(char) + start) |
|
|
|
elif char_map == CharMap.phonemes: |
|
|
|
for phoneme in txt: |
|
arr.append(cls.phonemes.index(phoneme) + start) |
|
|
|
elif char_map == CharMap.cmu_phonemes: |
|
|
|
for phoneme in txt: |
|
arr.append(cls.cmu_phonemes.index(phoneme) + start) |
|
|
|
elif char_map == CharMap.visemes: |
|
raise NotImplementedError |
|
else: |
|
raise ValueError(f'BAD CHAR MAP: {char_map}') |
|
|
|
return np.array(arr) |
|
|
|
def arr2txt(self, arr, start, char_map=None): |
|
char_map = self.char_map if char_map is None else char_map |
|
return self._arr2txt(arr, start, char_map=char_map) |
|
|
|
@classmethod |
|
def _arr2txt(cls, arr, start, char_map=CharMap.letters): |
|
txt = [] |
|
|
|
for n in arr: |
|
if n >= start: |
|
if char_map == CharMap.letters: |
|
txt.append(cls.letters[n - start]) |
|
elif char_map == CharMap.phonemes: |
|
txt.append(cls.phonemes[n - start]) |
|
elif char_map == CharMap.cmu_phonemes: |
|
txt.append(cls.cmu_phonemes[n - start]) |
|
elif char_map == CharMap.visemes: |
|
raise NotImplementedError |
|
else: |
|
raise ValueError(f'BAD CHAR MAP: {char_map}') |
|
|
|
return cls.stringify(txt, char_map) |
|
|
|
def get_char_mapping(self): |
|
return self.char_mapping(self.char_map) |
|
|
|
@classmethod |
|
def char_mapping(cls, char_map): |
|
if char_map == CharMap.letters: |
|
return cls.letters |
|
elif char_map == CharMap.phonemes: |
|
return cls.phonemes |
|
elif char_map == CharMap.cmu_phonemes: |
|
return cls.cmu_phonemes |
|
elif char_map == CharMap.visemes: |
|
raise NotImplementedError |
|
else: |
|
raise ValueError(f'BAD CHAR MAP: {char_map}') |
|
|
|
def ctc_decode(self, y): |
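        # greedy CTC decoding: argmax over classes, then collapse blanks and
        # repeated tokens; `y` is expected to be (batch, time, num_classes)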
|
y = y.argmax(-1) |
|
return [ |
|
self.ctc_arr2txt(y[_], start=1) |
|
for _ in range(y.size(0)) |
|
] |
|
|
|
def ctc_decode_indices(self, y): |
|
y = y.argmax(-1) |
|
return [ |
|
self.ctc_arr2txt_indices(y[_], start=1)[1] |
|
for _ in range(y.size(0)) |
|
] |
|
|
|
def ctc_arr2txt(self, *args, **kwargs): |
|
sentence, indices = self.ctc_arr2txt_pair(*args, **kwargs) |
|
return sentence |
|
|
|
def ctc_arr2txt_pair( |
|
self, arr, start, char_map=None, |
|
filter_previous=True |
|
): |
|
""" |
|
converts token indices into a string sentence |
|
|
|
:param arr: |
|
array of token indices |
|
:param start: |
|
number of special characters in character set |
|
:param char_map: |
|
character set to use for tokenization |
|
:param filter_previous: |
|
if True, removes consecutive occurrences of an index / token |
|
e.g. THREE becomes THRE, SOON becomes SON |
|
:return: |
|
""" |
|
sentence, indices = self.ctc_arr2txt_indices( |
|
arr=arr, start=start, char_map=char_map, |
|
filter_previous=filter_previous |
|
) |
|
return sentence, indices |
|
|
|
def ctc_arr2txt_indices( |
|
self, arr, start, char_map=None, |
|
filter_previous=True |
|
): |
|
""" |
|
converts token indices into a string sentence |
|
and indices of tokens taken along arr |
|
|
|
:param arr: |
|
array of token indices |
|
:param start: |
|
number of special characters in character set |
|
:param char_map: |
|
character set to use for tokenization |
|
:param filter_previous: |
|
if True, removes consecutive occurrences of an index / token |
|
e.g. THREE becomes THRE, SOON becomes SON |
|
:return: |
|
""" |
|
if char_map is None: |
|
char_map = self.char_map |
|
|
|
previous = -1 |
|
txt, indices = [], [] |
|
char_mapping = self.char_mapping(char_map) |
|
|
|
for k, n in enumerate(arr): |
|
check_consecutive = ( |
|
not filter_previous or previous != n |
|
) |
|
if n >= start: |
|
has_empty_char = ( |
|
len(txt) > 0 and txt[-1] == ' ' and |
|
char_mapping[n - start] == ' ' |
|
) |
|
|
|
if not has_empty_char and check_consecutive: |
|
txt.append(char_mapping[n - start]) |
|
indices.append(k) |
|
|
|
previous = n |
|
|
|
sentence = self.stringify(txt, char_map) |
|
return sentence, indices |
|
|
|
@staticmethod |
|
def stringify(txt, char_map): |
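        # join a token list into a sentence: characters are concatenated
        # directly, phonemes are joined with '-' within words and ' ' between words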
|
if char_map in (CharMap.letters, CharMap.lsr2_text): |
|
return ''.join(txt).strip() |
|
elif char_map in (CharMap.phonemes, CharMap.cmu_phonemes): |
|
sentence = '-'.join(txt).strip() |
|
sentence = sentence.replace('- ', ' ') |
|
sentence = sentence.replace(' -', ' ') |
|
            if sentence.endswith('-'):
                sentence = sentence[:-1]
            if sentence.startswith('-'):
                sentence = sentence[1:]
|
return sentence |
|
else: |
|
raise NotImplementedError |
|
|
|
def _map_chars(self, chars: str): |
|
return self.map_chars(chars, char_map=self.char_map) |
|
|
|
@classmethod |
|
def map_chars(cls, chars: str, char_map: CharMap): |
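        """
        Maps a sentence into the representation used for edit-distance
        metrics: letter sentences are returned unchanged, while phoneme
        sentences are re-encoded so each phoneme becomes a single symbol
        (assuming `map_phonemes` maps every phoneme to one character).
        """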
|
|
|
|
|
|
|
if char_map == CharMap.letters: |
|
return chars |
|
elif char_map in (CharMap.phonemes, CharMap.cmu_phonemes): |
|
            if char_map == CharMap.phonemes:
                phonemes_arr = cls.phonemes
                char_phonemes_arr = cls.phoneme_chars
|
elif char_map == CharMap.cmu_phonemes: |
|
phonemes_arr = cls.cmu_phonemes |
|
char_phonemes_arr = cls.cmu_phoneme_chars |
|
else: |
|
raise ValueError(f'BAD CHAR MAP {char_map}') |
|
|
|
words = chars.split(' ') |
|
char_phonemes = '' |
|
|
|
for word in words: |
|
phonemes = word.split('-') |
|
phonemes = [ |
|
phoneme for phoneme in phonemes |
|
if phoneme.strip() != '' |
|
] |
|
|
|
for phoneme in phonemes: |
|
char_phonemes += char_phonemes_arr[ |
|
phonemes_arr.index(phoneme) |
|
] |
|
|
|
char_phonemes += ' ' |
|
|
|
            return char_phonemes.rstrip()
|
elif char_map == CharMap.visemes: |
|
raise NotImplementedError |
|
else: |
|
raise ValueError(f'BAD CHAR MAP: {char_map}') |
|
|
|
@classmethod |
|
def map_char_lists( |
|
cls, char_lists: Iterable[str], char_map: CharMap |
|
): |
|
return [cls.map_chars( |
|
char_seq, char_map=char_map |
|
) for char_seq in char_lists] |
|
|
|
def wer(self, raw_predict, raw_truth): |
|
return self.get_wer( |
|
raw_predict, raw_truth, char_map=self.char_map |
|
) |
|
|
|
@classmethod |
|
def get_wer(cls, raw_predict, raw_truth, char_map: CharMap): |
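        # per-sample word error rate: edit distance between word sequences,
        # normalized by the ground-truth word count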
|
assert isinstance(raw_predict, Iterable) |
|
assert isinstance(raw_truth, Iterable) |
|
|
|
predict = cls.map_char_lists(raw_predict, char_map=char_map) |
|
truth = cls.map_char_lists(raw_truth, char_map=char_map) |
|
|
|
|
|
word_pairs = [ |
|
(p[0].split(' '), p[1].split(' ')) |
|
for p in zip(predict, truth) |
|
] |
|
wer = [ |
|
1.0 * editdistance.eval(p[0], p[1])/len(p[1]) |
|
for p in word_pairs |
|
] |
|
return wer |
|
|
|
def cer(self, raw_predict, raw_truth): |
|
return self.get_cer( |
|
raw_predict, raw_truth, char_map=self.char_map |
|
) |
|
|
|
@classmethod |
|
def get_cer(cls, raw_predict, raw_truth, char_map: CharMap): |
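        # per-sample character error rate: edit distance between character
        # sequences, normalized by the ground-truth length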
|
assert isinstance(raw_predict, Iterable) |
|
assert isinstance(raw_truth, Iterable) |
|
|
|
predict = cls.map_char_lists(raw_predict, char_map=char_map) |
|
truth = cls.map_char_lists(raw_truth, char_map=char_map) |
|
|
|
cer = [ |
|
1.0 * editdistance.eval(p[0], p[1]) / len(p[1]) |
|
for p in zip(predict, truth) |
|
] |
|
return cer |
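

# Minimal usage sketch (hypothetical paths and padding values, assuming a
# GRID-style directory layout and that `kwargify` from helpers returns a
# plain dict of its keyword arguments; real values come from `options`):
#
#   dataset = GridDataset(
#       video_path='GRID/videos', alignments_dir='GRID/align',
#       phonemes_dir='GRID/phonemes', file_list='GRID/train_list.txt',
#       vid_pad=75, image_dir='GRID/images', txt_pad=200, phase='train',
#   )
#   sample = dataset[0]
#   # sample['vid'] -> (channels, frames, height, width) float tensor
#   # sample['txt'] -> padded LongTensor of letter indices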
|
|