File size: 4,259 Bytes
24cfc1b bbc8456 24cfc1b 3ee2d3e 24cfc1b 3ee2d3e 24cfc1b c01c45a bbc8456 c58787f bbc8456 24cfc1b bbc8456 4640262 bbc8456 4640262 bbc8456 24cfc1b bbc8456 3ee2d3e bbc8456 f811062 24cfc1b 3a82e83 24cfc1b 3ee2d3e bbc8456 f035de9 0458fda f035de9 3ee2d3e bbc8456 3ee2d3e bbc8456 3ee2d3e 3a82e83 bbc8456 3ee2d3e bbc8456 3ee2d3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import functools
import tempfile

import gradio as gr
import numpy as np
import imageio
import tensorflow as tf
from tensorflow import keras

from utils import TubeMaskingGenerator
from utils import read_video, frame_sampling, denormalize, reconstrunction
from utils import IMAGENET_MEAN, IMAGENET_STD, num_frames, patch_size, input_size
from labels import K400_label_map, SSv2_label_map, UCF_label_map
# Checkpoint directories per dataset key:
#   [0] fine-tuned classifier checkpoint (FT),
#   [1] pre-trained checkpoint (PT) used for masked reconstruction.
# NOTE(review): every dataset currently points at the K400 checkpoints; swap in
# dataset-specific paths once those models are available.
MODELS = {
    dataset: [
        './TFVideoMAE_S_K400_16x224_FT',
        './TFVideoMAE_S_K400_16x224_PT',
    ]
    for dataset in ('K400', 'SSv2', 'UCF')
}
# Label map for each dataset key, taken from the ``labels`` module; keys match MODELS.
LABEL_MAPS = dict(
    K400=K400_label_map,
    SSv2=SSv2_label_map,
    UCF=UCF_label_map,
)
def tube_mask_generator(mask_ratio):
    """Build a boolean tube mask with a leading batch axis of size 1.

    The mask grid is ``(num_frames // 2, input_size // patch_size[0],
    input_size // patch_size[1])``; the halving of ``num_frames`` presumably
    reflects a temporal tubelet size of 2 (TODO confirm against the model).

    Args:
        mask_ratio: Forwarded unchanged to ``TubeMaskingGenerator``
            (presumably the fraction of positions to mask).

    Returns:
        A ``tf.bool`` tensor: the generator's mask with a new axis 0.
    """
    grid_shape = (
        num_frames // 2,
        input_size // patch_size[0],
        input_size // patch_size[1],
    )
    generator = TubeMaskingGenerator(input_size=grid_shape, mask_ratio=mask_ratio)

    # Materialise as int32 first and only cast to bool after batching;
    # presumably the generator yields a numeric (non-bool) array.
    as_int = tf.constant(generator(), dtype=tf.int32)
    batched = tf.expand_dims(as_int, axis=0)
    return tf.cast(batched, tf.bool)
@functools.lru_cache(maxsize=None)
def _load_model(path):
    """Load the Keras model at ``path`` once and reuse it on later calls."""
    # Only paths listed in MODELS reach this helper, so the cache stays bounded.
    return keras.models.load_model(path)


def get_model(data_type):
    """Return the fine-tuned model, the pre-trained model and an index->label map.

    Models are loaded from disk on first use and cached per checkpoint path,
    so repeated requests (and datasets sharing a checkpoint) do not reload them.

    Args:
        data_type: Key into ``MODELS`` ('K400', 'SSv2' or 'UCF').

    Returns:
        Tuple ``(ft_model, pt_model, label_map)``. ``label_map`` is the inverse
        of the K400 label map, i.e. it maps each value of ``K400_label_map``
        back to its key.

    Raises:
        KeyError: if ``data_type`` is not a key of ``MODELS``.
    """
    checkpoints = MODELS[data_type]
    ft_model = _load_model(checkpoints[0])
    pt_model = _load_model(checkpoints[1])

    # Every MODELS entry currently points at the K400 checkpoints, so the K400
    # label map is used for all datasets to match those models' outputs.
    # Switch to LABEL_MAPS[data_type] once dataset-specific checkpoints exist.
    label_map = {v: k for k, v in K400_label_map.items()}
    return ft_model, pt_model, label_map
def inference(video_file, data_type, mask_ratio):
    """Classify a video and render a masked-reconstruction comparison GIF.

    Args:
        video_file: Input video as received from the Gradio Video component.
        data_type: Dataset key used to pick checkpoints and labels (see MODELS).
        mask_ratio: Passed to ``tube_mask_generator`` to build the tube mask.

    Returns:
        Tuple ``(confidences, combined_gif)``.
        ``confidences`` maps each label to its softmax probability, inserted in
        descending probability order.
        ``combined_gif`` is the path of a freshly written GIF. Each frame shows
        the input, the masked input and the reconstruction stacked horizontally.
    """
    print('---------------------------')
    print(video_file)
    print(data_type)
    print(mask_ratio)
    print('---------------------------')

    # get sample data
    container = read_video(video_file)
    frames = frame_sampling(container, num_frames=num_frames)
    batch = frames[None, ...]  # add a leading batch axis of size 1

    # get models
    bool_masked_pos_tf = tube_mask_generator(mask_ratio)
    ft_model, pt_model, label_map = get_model(data_type)
    ft_model.trainable = False
    pt_model.trainable = False

    # inference on fine-tune model
    outputs_ft = ft_model(batch, training=False)
    probabilities = tf.nn.softmax(outputs_ft).numpy().squeeze(0)
    confidences = {
        label_map[i]: float(probabilities[i])
        for i in np.argsort(probabilities)[::-1]
    }

    # inference on pre-trained model
    outputs_pt = pt_model(batch, bool_masked_pos_tf, training=False)
    reconstruct_output, mask = reconstrunction(batch, bool_masked_pos_tf, outputs_pt)

    # post process: input | masked input | reconstruction, side by side
    input_frame = denormalize(frames)
    input_mask = denormalize(mask[0] * frames)
    output_frame = denormalize(reconstruct_output)
    gif_frames = [
        np.hstack([frame_a, frame_b, frame_c])
        for frame_a, frame_b, frame_c in zip(input_frame, input_mask, output_frame)
    ]

    # Write to a unique file per request: a fixed 'combined.gif' lets
    # concurrent requests overwrite each other's output before it is served.
    # The handle is closed before writing so this also works on Windows.
    with tempfile.NamedTemporaryFile(suffix='.gif', delete=False) as tmp:
        combined_gif = tmp.name
    imageio.mimsave(combined_gif, gif_frames, duration=300, loop=0)
    return confidences, combined_gif
def main():
    """Build the Gradio interface around ``inference`` and launch it."""
    datasets = ['K400', 'SSv2', 'UCF']
    # Each example is [video path, dataset, mask ratio]; the UCF clip is
    # paired with the UCF option rather than with SSv2.
    sample_example = [
        ["examples/k400.mp4", datasets[0], 0.9],
        ["examples/k400.mp4", datasets[1], 0.8],
        ["examples/ucf.mp4", datasets[2], 0.7],
    ]
    iface = gr.Interface(
        fn=inference,
        inputs=[
            gr.Video(type="file", label="Input Video"),
            gr.Radio(
                datasets,
                type='value',
                default=datasets[0],
                label='Dataset',
            ),
            gr.Slider(
                0.5,
                1.0,
                step=0.1,
                default=0.5,
                label='Mask Ratio'
            )
        ],
        outputs=[
            gr.Label(num_top_classes=3, label='scores'),
            gr.Image(type="filepath", label='reconstructed')
        ],
        examples=sample_example,
        title="VideoMAE",
        description="Keras reimplementation of <a href='https://github.com/innat/VideoMAE'>VideoMAE</a> is presented here."
    )
    iface.launch()
# Launch the demo only when run as a script, not on import.
if __name__ == '__main__':
    main()