Spaces:
Runtime error
Runtime error
File size: 4,730 Bytes
b6d5990 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
'''Module for loading the fakeavceleb dataset from tfrecord format'''
import numpy as np
import tensorflow as tf
from data.augmentation_utils import create_frame_transforms, create_spec_transforms
# Parsing schema for each serialized tf.train.Example in the tfrecord shards.
# Every field is a scalar (shape []) feature; the 'image/encoded' and
# 'WAVEFORM/feature/floats' byte payloads are decoded later in _parse_function.
FEATURE_DESCRIPTION = {
    name: tf.io.FixedLenFeature([], dtype)
    for name, dtype in (
        ('video_path', tf.string),
        ('image/encoded', tf.string),
        ('clip/label/index', tf.int64),
        ('clip/label/text', tf.string),
        ('WAVEFORM/feature/floats', tf.string),
    )
}
@tf.function
def _parse_function(example_proto):
    '''Parse one serialized `tf.train.Example` into (video, spectrogram, label index).

    Args:
        example_proto: scalar string tensor holding one serialized example
            described by FEATURE_DESCRIPTION.

    Returns:
        Tuple of:
          - video: flat uint8 tensor of raw frame bytes (reshaped downstream).
          - spectrogram: flat float32 tensor.
          - label_map: int64 scalar from 'clip/label/index'.
        'video_path' and 'clip/label/text' are parsed but not returned.
    '''
    example = tf.io.parse_single_example(example_proto, FEATURE_DESCRIPTION)
    # Decode pixels as uint8: downstream code normalizes by 255, i.e. expects
    # 0..255. Decoding as int8 would wrap every value above 127 to a negative.
    # NOTE(review): assumes frames were serialized as uint8 bytes -- confirm
    # against the tfrecord writer.
    video = tf.io.decode_raw(example['image/encoded'], tf.uint8)
    spectrogram = tf.io.decode_raw(example['WAVEFORM/feature/floats'], tf.float32)
    label_map = example['clip/label/index']
    return video, spectrogram, label_map
@tf.function
def decode_inputs(video, spectrogram, label_map):
    '''Decode parsed tensors into a (non-augmented) evaluation sample dict.

    Args:
        video: flat tensor of raw frame bytes from `_parse_function`. It is
            reshaped to [10, 256, 256, 3] -- the same layout
            `decode_train_inputs` uses -- and only the first frame is kept.
        spectrogram: flat float32 tensor, passed through unchanged.
        label_map: scalar label index, expanded to shape [1].

    Returns:
        dict with 'video_reshaped' (float32, [3, 256, 256], in [0, 1]),
        'spectrogram', and 'label_map'.
    '''
    if video.dtype == tf.int8:
        # Bytes decoded as int8 wrap values above 127 to negatives; reinterpret
        # the same bits as uint8 to recover the 0..255 pixel range.
        video = tf.bitcast(video, tf.uint8)
    # Reshape with channels last (as the training path does), then move
    # channels first. Reshaping the bytes straight to [10, 3, 256, 256] would
    # scramble pixels instead of transposing them.
    frames = tf.reshape(video, [10, 256, 256, 3])
    frame = tf.transpose(frames[0], perm=[2, 0, 1])
    frame = tf.cast(frame, tf.float32) / 255.0
    label_map = tf.expand_dims(label_map, axis=0)
    return {'video_reshaped': frame, 'spectrogram': spectrogram, 'label_map': label_map}
def decode_train_inputs(video, spectrogram, label_map):
    '''Decode parsed tensors into an augmented training sample dict.

    Args:
        video: flat tensor of raw frame bytes from `_parse_function`,
            reshaped here to [10, 256, 256, 3]; only the first frame is used.
        spectrogram: flat float32 tensor, augmented via `aug_spec_fn`.
        label_map: scalar label index, expanded to shape [1].

    Returns:
        dict with 'video_reshaped' (float32, [3, 256, 256], scaled by 1/255),
        'spectrogram' (augmented), and 'label_map'.
    '''
    # Spectrogram augmentation runs in NumPy via py_function, which drops
    # static shape information; restore it afterwards.
    spectrogram_shape = spectrogram.shape
    spec_augmented = tf.py_function(aug_spec_fn, [spectrogram], tf.float32)
    spec_augmented.set_shape(spectrogram_shape)

    frames = tf.reshape(video, [10, 256, 256, 3])
    frame = frames[0]  # first frame, channels last
    # Augment the raw pixels. aug_img_fn casts its input to uint8, so the
    # frame must NOT be normalized first: values in [0, 1] would truncate to
    # 0. (The uint8 cast also maps int8-decoded bytes back to 0..255.)
    frame_augmented = tf.py_function(aug_img_fn, [frame], tf.uint8)
    # aug_img_fn returns the frame channels-first.
    frame_augmented.set_shape([3, 256, 256])
    # Normalize after augmentation, yielding float32 like the evaluation path.
    frame_augmented = tf.cast(frame_augmented, tf.float32) / 255.0

    label_map = tf.expand_dims(label_map, axis=0)
    return {'video_reshaped': frame_augmented, 'spectrogram': spec_augmented, 'label_map': label_map}
def aug_img_fn(frame):
    '''Apply the frame augmentation pipeline to a single frame.

    Meant to run inside `tf.py_function`, so `frame` is an eager tensor. The
    pixels are cast to uint8 before augmentation, and the last axis of the
    result is moved to the front (channels-last -> channels-first).
    '''
    pixels = frame.numpy().astype(np.uint8)
    augmented = create_frame_transforms(image=pixels)['image']
    return augmented.transpose(2, 0, 1)
def aug_spec_fn(spec):
    '''Apply the spectrogram augmentation pipeline.

    Meant to run inside `tf.py_function`: converts the eager tensor to NumPy,
    passes it as `spec=` to `create_spec_transforms`, and returns the 'spec'
    entry of the result.
    '''
    return create_spec_transforms(spec=spec.numpy())['spec']
class FakeAVCelebDatasetTrain:
    '''Training split: shuffled, augmented, padded batches from tfrecord shards.

    `args` must provide `data_dir` (a file pattern for `tf.io.matching_files`)
    and `batch_size`. The batched `tf.data.Dataset` is stored in `self.samples`.
    '''

    def __init__(self, args):
        self.args = args
        self.samples = self.load_features_from_tfrec()

    def load_features_from_tfrec(self):
        '''Build the training pipeline from the tfrecord files matching `data_dir`.

        Shard order is shuffled, records are shuffled with a 100-element
        buffer, parsed, decoded with augmentation (`decode_train_inputs`), and
        padded-batched (the spectrogram length varies per example).
        '''
        ds = tf.io.matching_files(self.args.data_dir)
        files = tf.random.shuffle(ds)
        shards = tf.data.Dataset.from_tensor_slices(files)
        dataset = shards.interleave(tf.data.TFRecordDataset)
        dataset = dataset.shuffle(buffer_size=100)
        dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.map(decode_train_inputs, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.padded_batch(batch_size=self.args.batch_size)
        return dataset

    def __len__(self):
        '''Return the number of BATCHES (not examples) in one pass.

        Iterates the full pipeline, including augmentation, so it is
        expensive. The previous version passed `data_dir` to
        `load_features_from_tfrec`, which takes no argument, and always
        raised TypeError.
        '''
        cnt = self.samples.reduce(np.int64(0), lambda x, _: x + 1)
        return int(cnt.numpy())
class FakeAVCelebDatasetVal:
    '''Validation split: padded batches from tfrecord shards, without augmentation.

    `args` must provide `data_dir` (a file pattern for `tf.io.matching_files`)
    and `batch_size`. The batched `tf.data.Dataset` is stored in `self.samples`.
    '''

    def __init__(self, args):
        self.args = args
        self.samples = self.load_features_from_tfrec()

    def load_features_from_tfrec(self):
        '''Build the validation pipeline from the tfrecord files matching `data_dir`.

        Shard order is shuffled, records are shuffled with a 100-element
        buffer, parsed, decoded without augmentation (`decode_inputs`), and
        padded-batched (the spectrogram length varies per example).
        '''
        ds = tf.io.matching_files(self.args.data_dir)
        files = tf.random.shuffle(ds)
        shards = tf.data.Dataset.from_tensor_slices(files)
        dataset = shards.interleave(tf.data.TFRecordDataset)
        dataset = dataset.shuffle(buffer_size=100)
        dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.map(decode_inputs, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.padded_batch(batch_size=self.args.batch_size)
        return dataset

    def __len__(self):
        '''Return the number of BATCHES (not examples) in one pass.

        Iterates the full pipeline, so it is expensive. The previous version
        passed `data_dir` to `load_features_from_tfrec`, which takes no
        argument, and always raised TypeError.
        '''
        cnt = self.samples.reduce(np.int64(0), lambda x, _: x + 1)
        return int(cnt.numpy())