File size: 4,205 Bytes
24cfc1b 36e374f bbc8456 24cfc1b 9f18b37 3ee2d3e 24cfc1b 3ee2d3e 24cfc1b 36e374f 8ecd497 36e374f bbc8456 24cfc1b 36e374f bbc8456 24cfc1b bbc8456 3ee2d3e 36e374f f811062 24cfc1b 3a82e83 24cfc1b 9ffae4a 24cfc1b 3ee2d3e 36e374f 5bcbb0a bfae168 36e374f 52119aa 3a82e83 bbc8456 3ee2d3e 9ffae4a 3ee2d3e bbc8456 3ee2d3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import functools
import tempfile

import gradio as gr
import numpy as np
import imageio
import tensorflow as tf
from tensorflow import keras

from utils import TubeMaskingGenerator
from utils import read_video, frame_sampling, denormalize, reconstrunction
from utils import IMAGENET_MEAN, IMAGENET_STD, num_frames, patch_size, input_size
from labels import K400_label_map, SSv2_label_map, UCF_label_map
# Dataset tag -> label map imported from `labels`. get_model picks the entry
# whose tag appears in the selected model name.
LABEL_MAPS = dict(
    K400=K400_label_map,
    SSv2=SSv2_label_map,
    UCF=UCF_label_map,
)
# Model identifiers offered in the "Model" dropdown.
ALL_MODELS = [
    'TFVideoMAE_L_K400_16x224',
    'TFVideoMAE_B_SSv2_16x224',
    'TFVideoMAE_B_UCF_16x224',
]

# Gradio example rows, one per model: [video path, model name, mask ratio].
sample_example = [
    [video_path, model_name, ratio]
    for video_path, model_name, ratio in zip(
        ("examples/k400.mp4", "examples/ssv2.avi", "examples/ucf.mp4"),
        ALL_MODELS,
        (0.9, 0.8, 0.7),
    )
]
def tube_mask_generator(mask_ratio):
    """Build a boolean tube-masking tensor for a single clip.

    Args:
        mask_ratio: Fraction of patches to mask, forwarded unchanged to
            ``TubeMaskingGenerator``.

    Returns:
        A ``tf.bool`` tensor holding the generator's mask with a leading
        batch axis of size 1 (the mask's own shape is whatever
        ``TubeMaskingGenerator`` produces; presumably one flag per patch —
        confirm against utils).
    """
    # Patch grid: (time, height, width). The temporal axis is halved —
    # presumably a tubelet depth of 2 frames; confirm against the model config.
    grid = (
        num_frames // 2,
        input_size // patch_size[0],
        input_size // patch_size[1],
    )
    generator = TubeMaskingGenerator(input_size=grid, mask_ratio=mask_ratio)

    # Materialise as int32 first, then convert to bool; the batch axis is
    # added before or after the cast with identical results.
    mask_int = tf.constant(generator(), dtype=tf.int32)
    mask_bool = tf.cast(mask_int, tf.bool)
    return tf.expand_dims(mask_bool, axis=0)
@functools.lru_cache(maxsize=1)
def get_model(model_type):
    """Load the fine-tuned and pre-trained models plus label map for a model.

    Loading two Keras SavedModels from disk on every request is expensive, so
    the most recently requested set is cached and reused while the user keeps
    the same model selected. ``maxsize=1`` keeps at most one set resident
    between calls; switching models briefly holds the old and new sets while
    the new one loads.

    Args:
        model_type: Model identifier (e.g. an entry of ``ALL_MODELS``). The
            models are loaded from the paths ``<model_type>_FT`` and
            ``<model_type>_PT``.

    Returns:
        ``(ft_model, pt_model, label_map)``. ``label_map`` is the inverse of
        the matching ``LABEL_MAPS`` entry, i.e. class index -> label. The
        returned objects are shared across calls; callers must not mutate
        ``label_map``.
    """
    # NOTE: these were previously fetched with huggingface_hub.hf_hub_download
    # (repo_id='innat/videomae'); they are now loaded from local paths.
    ft_model = keras.models.load_model(model_type + '_FT')
    pt_model = keras.models.load_model(model_type + '_PT')

    # The dataset tag embedded in the model name selects the label map;
    # anything not K400/SSv2 falls back to UCF, as before.
    if 'K400' in model_type:
        data_type = 'K400'
    elif 'SSv2' in model_type:
        data_type = 'SSv2'
    else:
        data_type = 'UCF'

    label_map = {v: k for k, v in LABEL_MAPS[data_type].items()}
    return ft_model, pt_model, label_map
def inference(video_file, model_type, mask_ratio):
    """Classify a video and render a masked-reconstruction preview GIF.

    Args:
        video_file: Video path as delivered by the Gradio video input.
        model_type: Model identifier passed to ``get_model``.
        mask_ratio: Fraction of patches to mask for the reconstruction view.

    Returns:
        ``(confidences, gif_path)``: a dict of label -> probability ordered by
        descending probability, and the path of a GIF whose frames show
        original | masked input | reconstruction side by side.
    """
    # Sample a fixed number of frames from the clip.
    container = read_video(video_file)
    frames = frame_sampling(container, num_frames=num_frames)
    batch = frames[None, ...]  # add a leading batch axis of size 1

    bool_masked_pos_tf = tube_mask_generator(mask_ratio)
    ft_model, pt_model, label_map = get_model(model_type)
    ft_model.trainable = False
    pt_model.trainable = False

    # Fine-tuned classifier: softmax the outputs, then list every label from
    # most to least likely.
    outputs_ft = ft_model(batch, training=False)
    probabilities = tf.nn.softmax(outputs_ft).numpy().squeeze(0)
    confidences = {
        label_map[i]: float(probabilities[i])
        for i in np.argsort(probabilities)[::-1]
    }

    # Pre-trained (masked autoencoder) model: reconstruct the masked patches.
    outputs_pt = pt_model(batch, bool_masked_pos_tf, training=False)
    reconstruct_output, mask = reconstrunction(
        batch, bool_masked_pos_tf, outputs_pt
    )

    # Post-process into displayable frames and tile them horizontally.
    input_frame = denormalize(frames)
    input_mask = denormalize(mask[0] * frames)
    output_frame = denormalize(reconstruct_output)
    combined_frames = [
        np.hstack([original, masked, reconstructed])
        for original, masked, reconstructed in zip(
            input_frame, input_mask, output_frame
        )
    ]

    # Write to a unique temp file: a fixed name in the CWD would let
    # concurrent requests overwrite each other's GIF before Gradio reads it.
    # The file is closed before imageio writes to it so this also works on
    # Windows. Files are not deleted here because Gradio reads the path after
    # this function returns.
    with tempfile.NamedTemporaryFile(suffix='.gif', delete=False) as tmp:
        combined_gif = tmp.name
    # NOTE(review): the unit of `duration` depends on the imageio version /
    # GIF plugin (seconds vs. milliseconds) — confirm against the pinned version.
    imageio.mimsave(combined_gif, combined_frames, duration=300, loop=0)
    return confidences, combined_gif
def main():
    """Build the VideoMAE Gradio interface and launch it."""
    iface = gr.Interface(
        fn=inference,
        inputs=[
            # The video arrives as a file path, which is what read_video takes.
            gr.Video(label="Input Video"),
            gr.Dropdown(
                choices=ALL_MODELS,
                # Gradio >= 3 uses `value` for the initial selection; the old
                # `default` kwarg is ignored or rejected, which left no model
                # selected and made get_model(None) fail.
                value=ALL_MODELS[0],
                label="Model"
            ),
            gr.Slider(
                0.5,
                1.0,
                step=0.1,
                value=0.5,
                label='Mask Ratio'
            )
        ],
        outputs=[
            gr.Label(num_top_classes=3, label='scores'),
            gr.Image(type="filepath", label='reconstructed')
        ],
        examples=sample_example,
        title="VideoMAE",
        description="Keras reimplementation of <a href='https://github.com/innat/VideoMAE'>VideoMAE</a> is presented here."
    )
    iface.launch()
# Launch the demo only when run as a script, not on import.
if __name__ == '__main__':
    main()