|
import os |
|
import yaml |
|
import options as opt |
|
|
|
from typing import List, Tuple |
|
from dataset import GridDataset, CharMap, Datasets |
|
from tqdm.auto import tqdm |
|
from helpers import * |
|
|
|
|
|
class GridLoader(object):
    """
    Locates video files on disk and loads their phoneme / text labels.

    Every directory used by the loader is built by joining ``base_dir``
    with the matching directory name from the ``options`` module.
    """

    # speaker folders are named s1 ... s34
    _SPEAKER_NUMBERS = range(1, 35)

    def __init__(self, base_dir=''):
        """
        :param base_dir:
            prefix joined in front of every directory configured in
            ``options``; the default '' leaves those paths unchanged
        """
        self.video_dir = os.path.join(base_dir, opt.video_dir)
        self.alignment_dir = os.path.join(base_dir, opt.alignments_dir)
        self.phonemes_dir = os.path.join(base_dir, opt.phonemes_dir)
        self.images_dir = os.path.join(base_dir, opt.images_dir)
        # set by load_video_paths to the list of usable video paths
        self.usable_video_filepaths = None

    def _iter_speaker_videos(self, excluded_speakers, ext, verbose=False):
        """
        Walk the speaker folders under ``self.video_dir`` and yield
        ``(speaker_no, speaker_dirname, base_name)`` for every file whose
        name ends with ``.{ext}``.

        :param excluded_speakers:
            container of speaker numbers to skip (None means skip nobody)
        :param ext: video file extension, without the leading dot
        :param verbose: whether to log each skipped speaker
        """
        if excluded_speakers is None:
            excluded_speakers = ()

        suffix = f'.{ext}'

        for speaker_no in self._SPEAKER_NUMBERS:
            if speaker_no in excluded_speakers:
                if verbose:
                    print(f'SKIPPING SPEAKER NO {speaker_no}')

                continue

            speaker_dirname = f's{speaker_no}'
            speaker_dir = os.path.join(self.video_dir, speaker_dirname)
            # isdir (not exists) so a stray regular file named like a
            # speaker folder is skipped instead of crashing os.listdir
            if not os.path.isdir(speaker_dir):
                continue

            for video_filename in os.listdir(speaker_dir):
                if not video_filename.endswith(suffix):
                    continue

                base_name = os.path.splitext(video_filename)[0]
                yield speaker_no, speaker_dirname, base_name

    def load_video_paths(
        self, verbose=False, blacklist=frozenset({}),
        ext='mpg', fetch_all_paths=False, excluded_speakers=None,
        verify_phonemes_length=False
    ) -> List[str]:
        """
        Scan the speaker folders for video files and collect their paths.

        The usable paths are also stored on ``self.usable_video_filepaths``.

        :param verbose:
            whether to show logs (skipped speakers, CTC exclusions and a
            final summary of usable / unusable / CTC-excluded counts)
        :param blacklist:
            set of video filepaths to exclude from the usable paths
        :param ext: video file extension, must be 'mpg' or 'mp4'
        :param fetch_all_paths:
            if True return every discovered video path of the scanned
            speakers (collected before the blacklist and phoneme checks)
            instead of only the usable ones
        :param excluded_speakers:
            speaker numbers whose folders are skipped entirely
        :param verify_phonemes_length:
            if True a video is only usable when is_phoneme_extractable
            accepts it
        :return:
            list of video paths (see ``fetch_all_paths``)
        """
        if excluded_speakers is None:
            excluded_speakers = set()

        assert ext in ('mpg', 'mp4')
        usable_video_filepaths = []
        # videos rejected by is_phoneme_extractable for a reason other
        # than the CTC length constraint (frames dir / phoneme file missing)
        videos_without_alignment = []
        all_video_filepaths = []
        ctc_exclusions = 0

        discovered = self._iter_speaker_videos(
            excluded_speakers, ext, verbose=verbose
        )

        for speaker_no, speaker_dirname, base_name in discovered:
            images_dir = os.path.join(
                self.images_dir, speaker_dirname, base_name
            )
            video_path = os.path.join(
                self.video_dir, speaker_dirname, f'{base_name}.{ext}'
            )
            all_video_filepaths.append(video_path)

            if video_path in blacklist:
                continue

            if verify_phonemes_length:
                extractable, ctc_invalid = self.is_phoneme_extractable(
                    speaker_no, base_name, images_dir=images_dir,
                    verbose=verbose
                )

                if ctc_invalid:
                    ctc_exclusions += 1
                if not extractable:
                    if not ctc_invalid:
                        videos_without_alignment.append(video_path)

                    continue

            # previously nothing was ever appended, so the method always
            # returned an empty list
            usable_video_filepaths.append(video_path)

        if verbose:
            num_usable_videos = len(usable_video_filepaths)
            num_unusable_videos = len(videos_without_alignment)

            print(f'videos with alignment: {num_usable_videos}')
            print(f'videos without alignment: {num_unusable_videos}')
            print(f'CTC EXCLUSIONS: {ctc_exclusions}')

        self.usable_video_filepaths = usable_video_filepaths

        if fetch_all_paths:
            return all_video_filepaths

        return usable_video_filepaths

    def is_phoneme_extractable(
        self, speaker_no, base_name, images_dir,
        verbose=False
    ) -> Tuple[bool, bool]:
        """
        Check whether a video can be used for phoneme prediction.

        :param speaker_no: speaker number (its folder is ``s{speaker_no}``)
        :param base_name: video filename without extension
        :param images_dir:
            directory holding the extracted ``.jpg`` frames of the video
        :param verbose: whether to log videos rejected by the CTC check
        :return:
            two boolean values:
            the first whether the video is suitable
            to be included in the dataset for phoneme prediction
            the second bool determines whether the video was rejected
            because the extracted images and phonemes length violate the
            CTC loss constraints (video / input length must be more
            than twice the length of phoneme sequence / output)
        """
        speaker_dirname = f's{speaker_no}'
        phonemes_path = os.path.join(
            self.phonemes_dir, speaker_dirname,
            f'{base_name}.align'
        )

        # no extracted frames for this video
        if not os.path.isdir(images_dir):
            return False, False

        try:
            phonemes = GridDataset.load_sentence(
                phonemes_path, CharMap.phonemes
            )
        except FileNotFoundError:
            # no phoneme file for this video
            return False, False

        image_names = [
            filename for filename in os.listdir(images_dir)
            if filename.endswith('.jpg')
        ]

        vid_len = len(image_names)
        num_phonemes = len(phonemes)

        if vid_len <= num_phonemes * 2:
            # the number of frames must be more than twice the number of
            # phonemes; the original author's note says CTCLoss returns
            # nan otherwise, therefore such videos are excluded
            if verbose:
                print(f'CTC EXCLUDE: {speaker_no, base_name}')
                print(images_dir, vid_len, num_phonemes)

            return False, True

        return True, False

    def get_grid_sentence_pairs(
        self, excluded_speakers, ext='mpg', verbose=False
    ) -> List[Tuple[int, str]]:
        """
        List every video found under the speaker folders.

        :param excluded_speakers: speaker numbers to skip
        :param ext: video file extension, without the leading dot
        :param verbose: whether to log each skipped speaker
        :return:
            list of (speaker_no, base_name) pairs where base_name is
            the video filename without its extension
        """
        discovered = self._iter_speaker_videos(
            excluded_speakers, ext, verbose=verbose
        )
        return [
            (speaker_no, base_name)
            for speaker_no, _, base_name in discovered
        ]

    def get_lsr2_sentence_pairs(self, ext='mp4') -> List[Tuple[str, str]]:
        """
        List every video found one level below ``self.video_dir``.

        :param ext: video file extension, without the leading dot
        :return:
            list of (group_dir, base_name) pairs where group_dir is the
            group folder path (``self.video_dir`` joined with the folder
            name) and base_name is the filename without its extension
        """
        sentence_pairs = []
        suffix = f'.{ext}'

        for group_dirname in os.listdir(self.video_dir):
            group_dir = os.path.join(self.video_dir, group_dirname)

            # skip plain files sitting next to the group folders
            if not os.path.isdir(group_dir):
                continue

            # list the joined path; listing the bare folder name would
            # resolve it against the current working directory
            for video_filename in os.listdir(group_dir):
                if not video_filename.endswith(suffix):
                    continue

                base_name = os.path.splitext(video_filename)[0]
                sentence_pairs.append((group_dir, base_name))

        return sentence_pairs

    def _build_phonemes_text_map(
        self, labelled_files, phonemes_char_map, text_char_map,
        verbose=False
    ):
        """
        Load the phoneme and text sentence for each entry and group the
        results per char map. Entries with a missing file are skipped.

        :param labelled_files:
            list of (key, dirname, filename) triples; the files are read
            from ``{phonemes_dir}/{dirname}/{filename}`` and
            ``{alignment_dir}/{dirname}/{filename}``
        :param phonemes_char_map: char map used for the phoneme files
        :param text_char_map: char map used for the alignment files
        :param verbose: whether to print the number of unique words
        :return:
            ``{phonemes_char_map: {key: phoneme sentence},
            text_char_map: {key: text sentence}}``
        """
        phoneme_map, text_map = {}, {}
        unique_words = set()

        for key, dirname, filename in tqdm(labelled_files):
            phonemes_path = os.path.join(
                self.phonemes_dir, dirname, filename
            )
            alignments_path = os.path.join(
                self.alignment_dir, dirname, filename
            )

            try:
                phonemes_sentence = GridDataset.load_str_sentence(
                    phonemes_path, char_map=phonemes_char_map
                )
                letters_sentence = GridDataset.load_str_sentence(
                    alignments_path, char_map=text_char_map
                )
            except FileNotFoundError:
                continue

            unique_words.update(letters_sentence.split(' '))
            phoneme_map[key] = phonemes_sentence
            text_map[key] = letters_sentence

        if verbose:
            print('UNIQUE_WORDS', len(unique_words))

        return {
            phonemes_char_map: phoneme_map,
            text_char_map: text_map
        }

    def load_lsr2_phonemes_text_map(
        self, phonemes_char_map: CharMap = CharMap.cmu_phonemes,
        text_char_map: CharMap = CharMap.lsr2_text,
        ext='mp4', verbose=False,
    ):
        """
        Load phoneme and text sentences for every pair returned by
        get_lsr2_sentence_pairs, reading ``{base_name}.txt`` files.

        :param phonemes_char_map: char map used for the phoneme files
        :param text_char_map: char map used for the alignment files
        :param ext: video file extension, must be 'mpg' or 'mp4'
        :param verbose: whether to print the number of unique words
        :return:
            dict keyed by the two char maps, each holding a dict from
            (group_dir, base_name) to the loaded sentence
        """
        assert ext in ('mpg', 'mp4')
        sentence_pairs = self.get_lsr2_sentence_pairs(ext=ext)

        # NOTE(review): group_dir already contains self.video_dir, so the
        # label paths only end up below phonemes_dir / alignment_dir when
        # video_dir is relative (os.path.join discards earlier parts when
        # a later part is absolute) - confirm this is intended
        labelled_files = [
            (sentence_pair, sentence_pair[0], f'{sentence_pair[1]}.txt')
            for sentence_pair in sentence_pairs
        ]
        return self._build_phonemes_text_map(
            labelled_files, phonemes_char_map, text_char_map,
            verbose=verbose
        )

    def load_grid_phonemes_text_map(
        self, phonemes_char_map: CharMap = CharMap.phonemes,
        text_char_map: CharMap = CharMap.letters,
        excluded_speakers=None, verbose=False, ext='mpg'
    ):
        """
        Load phoneme and text sentences for every pair returned by
        get_grid_sentence_pairs, reading ``s{speaker_no}/{base_name}.align``
        files.

        :param phonemes_char_map: char map used for the phoneme files
        :param text_char_map: char map used for the alignment files
        :param excluded_speakers: speaker numbers to skip
        :param verbose:
            whether to log skipped speakers and the number of unique words
        :param ext: video file extension, must be 'mpg' or 'mp4'
        :return:
            dict keyed by the two char maps, each holding a dict from
            (speaker_no, base_name) to the loaded sentence
        """
        if excluded_speakers is None:
            excluded_speakers = set()

        assert ext in ('mpg', 'mp4')
        speaker_sentence_pairs = self.get_grid_sentence_pairs(
            ext=ext, excluded_speakers=excluded_speakers,
            verbose=verbose
        )

        labelled_files = [
            (pair, f's{pair[0]}', f'{pair[1]}.align')
            for pair in speaker_sentence_pairs
        ]
        return self._build_phonemes_text_map(
            labelled_files, phonemes_char_map, text_char_map,
            verbose=verbose
        )
|
|
|
|
|
if __name__ == '__main__':
    # Manual run: scan the directories configured in `options` as-is
    # (base_dir defaults to '') and print the verbose logs.
    GridLoader().load_video_paths(verbose=True)