|
import Loader |
|
import numpy as np |
|
import os |
|
|
|
from sklearn.model_selection import train_test_split |
|
|
|
# Share of the filtered clips passed to train_test_split as test_size; that
# held-out part becomes the validation list.
TEST_FRAC = 0.2
# Fixed seed so the shuffle behind the split is reproducible between runs.
RANDOM_SEED = 42

# Collect the candidate clip paths. excluded_speakers presumably drops
# speaker 34 from the listing entirely — confirm against Loader.GridLoader.
loader = Loader.GridLoader()
video_paths = loader.load_video_paths(
    excluded_speakers=[34],
    fetch_all_paths=False,
    verbose=True,
)
|
|
|
|
|
# Clips to drop before splitting, keyed by speaker number -> sentence ids.
# (Per the name these are presumably clips with unusable lip data — confirm.)
# Flattened below into a set of (speaker_no, sentence) tuples for O(1) lookup.
_bad_sentences_by_speaker = {
    1: ('pbio7a', 'bwwuzn', 'prii9a', 'bbizzn', 'lrarzn', 'pbwx1s',
        'lgbf8n', 'brwg8p', 'sbbh4p', 'sran9s', 'brwa4p'),
    2: ('pbwxzs', 'pwbd6s', 'pbib7p'),
    3: ('lgbz9s', 'pbiu6n', 'bgit2n', 'lbij5s', 'bramzn', 'swiu2n',
        'pgwy7s', 'bgbn9a', 'prwq3a', 'lbij7a', 'lrbr3s'),
    4: ('lwiy3n',),
    7: ('bbir1s', 'bbir2p', 'lbad1s', 'lgws5s', 'lrwe6p', 'lwak8p',
        'sbatzn', 'pwii6n', 'lwws1a', 'bgam9s', 'sgio2p'),
    9: ('bwaf6n',),
    17: ('lbib9a',),
}

bad_lip_pairs = {
    (speaker_no, sentence)
    for speaker_no, sentences in _bad_sentences_by_speaker.items()
    for sentence in sentences
}
|
|
|
# Keep only clips whose (speaker number, sentence id) is not blacklisted in
# bad_lip_pairs; each dropped clip is reported on stdout.
new_video_paths = []
for video_path in video_paths:
    speaker_dir, file_name = os.path.split(video_path)
    sentence = os.path.splitext(file_name)[0]
    # The parent directory name is one prefix character followed by the
    # speaker number (presumably "s<N>"); the prefix is dropped before int().
    speaker_no = int(os.path.basename(speaker_dir)[1:])
    if (speaker_no, sentence) not in bad_lip_pairs:
        new_video_paths.append(video_path)
    else:
        print('SKIPPING', video_path)

video_paths = new_video_paths
|
# Shuffle the filtered clips and hold out TEST_FRAC of them for validation.
# A single input array is enough: train_test_split returns one (train, test)
# pair per array, and the indices depend only on the sample count,
# test_size and random_state, so this matches passing the list twice and
# discarding the duplicate pair.
# NOTE(review): the split is per clip, not per speaker, so the same speaker
# can appear in both lists — confirm this is intended for the "unseen" files.
train_paths, validate_paths = train_test_split(
    video_paths,
    test_size=TEST_FRAC,
    random_state=RANDOM_SEED,
)
|
|
|
|
|
def get_speakers(filepaths):
    """Return the set of speaker names appearing in *filepaths*.

    A clip's speaker name is taken to be the name of its parent directory,
    e.g. ``'root/s1/clip'`` -> ``'s1'``.

    Args:
        filepaths: Iterable of clip path strings.

    Returns:
        A set of parent-directory names (empty when *filepaths* is empty).
    """
    # Set comprehension avoids building a throwaway list first.
    return {os.path.basename(os.path.dirname(path)) for path in filepaths}
|
|
|
|
|
# Replace both splits with sorted copies so paths come out in a stable,
# diff-friendly order instead of the shuffle order.
train_paths, validate_paths = sorted(train_paths), sorted(validate_paths)

# Report speaker coverage of the whole set, then size and speakers per split.
print(f'ALL_SPEAKERS {get_speakers(video_paths)}')
print(f'TRAIN_PATHS: {len(train_paths)}')
print(f'TRAIN_SPEAKERS: {get_speakers(train_paths)}')
print(f'VALIDATE_PATHS: {len(validate_paths)}')
print(f'VALIDATE_SPEAKERS: {get_speakers(validate_paths)}')
|
|
|
# Write each split as newline-separated paths (no trailing newline, same as
# before). The with-blocks guarantee the files are flushed and closed even on
# interpreters that do not close objects promptly via refcounting, and the
# output directory is created if it does not exist yet.
os.makedirs('data', exist_ok=True)
with open('data/unseen_train.txt', 'w') as train_file:
    train_file.write('\n'.join(train_paths))
with open('data/unseen_val.txt', 'w') as val_file:
    val_file.write('\n'.join(validate_paths))