Spaces:
Runtime error
Runtime error
File size: 4,423 Bytes
d22cc84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import functools
import tempfile
import zipfile

import gradio as gr
import numpy as np
import imageio
import tensorflow as tf
from tensorflow import keras

from utils import TubeMaskingGenerator
from utils import read_video, frame_sampling, denormalize, reconstrunction
from utils import IMAGENET_MEAN, IMAGENET_STD, num_frames, patch_size, input_size
from labels import K400_label_map, SSv2_label_map, UCF_label_map
# Dataset tag -> label mapping, as imported from the local ``labels`` module.
LABEL_MAPS = dict(
    K400=K400_label_map,
    SSv2=SSv2_label_map,
    UCF=UCF_label_map,
)

# Checkpoint names offered in the model dropdown. ``get_model`` uses each name
# as the prefix of the ``_FT.zip`` / ``_PT.zip`` release assets it downloads.
ALL_MODELS = [
    'TFVideoMAE_L_K400_16x224',
    'TFVideoMAE_B_SSv2_16x224',
    'TFVideoMAE_B_UCF_16x224',
]

# Gradio example rows: [video path, model name, mask ratio], one per model.
sample_example = [
    [video_path, model_name, ratio]
    for video_path, model_name, ratio in zip(
        ("examples/k400.mp4", "examples/ssv2.mp4", "examples/ucf.mp4"),
        ALL_MODELS,
        (0.9, 0.8, 0.7),
    )
]
def tube_mask_generator(mask_ratio):
    """Build a random tube mask for the pre-trained (masked-autoencoder) model.

    Args:
        mask_ratio: Forwarded unchanged to ``TubeMaskingGenerator``.

    Returns:
        A ``tf.bool`` tensor: the generator's output cast to bool, with a
        leading batch axis of size 1 added. The remaining shape is whatever
        ``TubeMaskingGenerator`` emits (presumably one flag per token; confirm
        in ``utils``).
    """
    # Token grid (time, height, width). Time is halved, presumably because each
    # token spans two consecutive frames (tubelet size 2). TODO confirm.
    grid = (
        num_frames // 2,
        input_size // patch_size[0],
        input_size // patch_size[1],
    )
    generator = TubeMaskingGenerator(input_size=grid, mask_ratio=mask_ratio)

    # Materialise as int32 first, then cast, rather than building a bool tensor
    # straight from the generator output.
    mask_int = tf.constant(generator(), dtype=tf.int32)
    return tf.cast(mask_int[tf.newaxis, ...], tf.bool)
def _load_release_model(model_type, suffix):
    """Download ``{model_type}_{suffix}.zip`` from the v1.1 release, unzip it into the CWD and load it."""
    name = f'{model_type}_{suffix}'
    # get_file caches the download under the Keras cache dir; the archive is
    # expected to unpack to a directory called ``name`` that load_model reads.
    archive_path = keras.utils.get_file(
        origin=f'https://github.com/innat/VideoMAE/releases/download/v1.1/{name}.zip',
    )
    with zipfile.ZipFile(archive_path, 'r') as zip_ref:
        zip_ref.extractall('./')
    return keras.models.load_model(name)


@functools.lru_cache(maxsize=1)
def get_model(model_type):
    """Fetch and load the fine-tuned and pre-trained models for *model_type*.

    Results for the most recently requested ``model_type`` are memoised, so
    repeated requests for the same checkpoint skip re-extracting the zips and
    re-loading both Keras models. ``maxsize=1`` keeps only one model pair in
    memory at a time.

    Args:
        model_type: Checkpoint name, e.g. one of ``ALL_MODELS``.

    Returns:
        ``(ft_model, pt_model, label_map)``:
        - ``ft_model`` is loaded from ``{model_type}_FT``.
        - ``pt_model`` is loaded from ``{model_type}_PT``.
        - ``label_map`` is the matching ``LABEL_MAPS`` entry with keys and
          values swapped (presumably class index -> class name; confirm
          against ``labels``).

        The returned objects are shared between calls and must not be
        mutated by callers.
    """
    ft_model = _load_release_model(model_type, 'FT')
    pt_model = _load_release_model(model_type, 'PT')

    # The dataset tag is embedded in the checkpoint name; anything that is
    # not K400 or SSv2 falls back to UCF (same as before).
    if 'K400' in model_type:
        data_type = 'K400'
    elif 'SSv2' in model_type:
        data_type = 'SSv2'
    else:
        data_type = 'UCF'

    label_map = {v: k for k, v in LABEL_MAPS[data_type].items()}
    return ft_model, pt_model, label_map
def _classify(ft_model, batch, label_map):
    """Run the fine-tuned model on *batch* and return ``{label: probability}``, highest probability first."""
    logits = ft_model(batch, training=False)
    probabilities = tf.nn.softmax(logits).numpy().squeeze(0)
    return {
        label_map[i]: float(probabilities[i])
        for i in np.argsort(probabilities)[::-1]
    }


def _write_side_by_side_gif(*clips):
    """Stack the i-th frame of every clip horizontally and save the result as a GIF at a fresh temp path."""
    combined = [np.hstack(row) for row in zip(*clips)]
    # A unique file per call, so concurrent requests cannot overwrite each
    # other's output (the old code always wrote 'combined.gif' in the CWD).
    # NOTE(review): these temp files are never deleted by this app.
    with tempfile.NamedTemporaryFile(suffix='.gif', delete=False) as tmp:
        gif_path = tmp.name
    imageio.mimsave(gif_path, combined, duration=300, loop=0)
    return gif_path


def inference(video_file, model_type, mask_ratio):
    """Classify a video and visualise the masked-autoencoder reconstruction.

    Args:
        video_file: Video from the ``gr.Video`` input, passed unchanged to
            ``read_video``.
        model_type: Checkpoint name, e.g. one of ``ALL_MODELS``.
        mask_ratio: Masking ratio forwarded to ``tube_mask_generator``.

    Returns:
        ``(confidences, gif_path)``:
        - ``confidences`` maps label -> softmax probability, for ``gr.Label``.
        - ``gif_path`` points to an animated GIF whose frames show, left to
          right, the input, the masked input and the reconstruction.
    """
    container = read_video(video_file)
    frames = frame_sampling(container, num_frames=num_frames)
    batch = frames[None, ...]  # add a leading batch axis of size 1

    bool_masked_pos_tf = tube_mask_generator(mask_ratio)
    ft_model, pt_model, label_map = get_model(model_type)
    ft_model.trainable = False
    pt_model.trainable = False

    confidences = _classify(ft_model, batch, label_map)

    outputs_pt = pt_model(batch, bool_masked_pos_tf, training=False)
    reconstruct_output, mask = reconstrunction(batch, bool_masked_pos_tf, outputs_pt)

    gif_path = _write_side_by_side_gif(
        denormalize(frames),
        denormalize(mask[0] * frames),
        denormalize(reconstruct_output),
    )
    return confidences, gif_path
def main():
    """Build the Gradio interface around ``inference`` and launch it."""
    iface = gr.Interface(
        fn=inference,
        inputs=[
            # Modern gr.Video has no ``type`` argument; it hands the callback
            # a file path.
            gr.Video(label="Input Video"),
            # ``value=`` (not the removed ``default=``) sets the initial
            # selection. Gradio 3 ignored ``default`` with a warning and
            # Gradio 4 rejects it.
            gr.Dropdown(
                choices=ALL_MODELS,
                value="TFVideoMAE_L_K400_16x224",
                label="Model",
            ),
            gr.Slider(
                minimum=0.5,
                maximum=1.0,
                step=0.1,
                value=0.5,
                label='Mask Ratio',
            ),
        ],
        outputs=[
            gr.Label(num_top_classes=3, label='scores'),
            gr.Image(type="filepath", label='reconstructed'),
        ],
        examples=sample_example,
        title="VideoMAE",
        description="Keras reimplementation of <a href='https://github.com/innat/VideoMAE'>VideoMAE</a> is presented here.",
    )
    iface.launch()
# Launch the demo only when run as a script, not when imported.
if __name__ == '__main__':
    main()