Update app.py
Browse files
app.py
CHANGED
@@ -56,6 +56,7 @@ def inference(video_file, dataset_type, mask_ratio):
|
|
56 |
frames[None, ...], bool_masked_pos_tf, outputs_pt
|
57 |
)
|
58 |
|
|
|
59 |
input_frame = denormalize(frames)
|
60 |
input_mask = denormalize(mask[0] * frames)
|
61 |
output_frame = denormalize(reconstruct_output)
|
@@ -81,11 +82,10 @@ def main():
|
|
81 |
'./TFVideoMAE_S_K400_16x224_PT'
|
82 |
],
|
83 |
'UCF' : [
|
84 |
-
'
|
85 |
'./TFVideoMAE_S_K400_16x224_PT'
|
86 |
]
|
87 |
}
|
88 |
-
|
89 |
BENCHMARK_DATASETS = ['K400', 'SSv2', 'UCF']
|
90 |
SAMPLE_EXAMPLES = [
|
91 |
["examples/k400.mp4", 'Kintetics-400'],
|
@@ -103,11 +103,11 @@ def main():
|
|
103 |
default=BENCHMARK_DATASETS[0],
|
104 |
label='Dataset',
|
105 |
),
|
106 |
-
gr.
|
107 |
-
0
|
108 |
-
1
|
109 |
-
step=0.
|
110 |
-
default=0.
|
111 |
label='Mask Ratio'
|
112 |
)
|
113 |
],
|
|
|
56 |
frames[None, ...], bool_masked_pos_tf, outputs_pt
|
57 |
)
|
58 |
|
59 |
+
# post process
|
60 |
input_frame = denormalize(frames)
|
61 |
input_mask = denormalize(mask[0] * frames)
|
62 |
output_frame = denormalize(reconstruct_output)
|
|
|
82 |
'./TFVideoMAE_S_K400_16x224_PT'
|
83 |
],
|
84 |
'UCF' : [
|
85 |
+
'./TFVideoMAE_S_K400_16x224_FT',
|
86 |
'./TFVideoMAE_S_K400_16x224_PT'
|
87 |
]
|
88 |
}
|
|
|
89 |
BENCHMARK_DATASETS = ['K400', 'SSv2', 'UCF']
|
90 |
SAMPLE_EXAMPLES = [
|
91 |
["examples/k400.mp4", 'Kintetics-400'],
|
|
|
103 |
default=BENCHMARK_DATASETS[0],
|
104 |
label='Dataset',
|
105 |
),
|
106 |
+
gr.Slider(
|
107 |
+
0,
|
108 |
+
1,
|
109 |
+
step=0.05,
|
110 |
+
default=0.5,
|
111 |
label='Mask Ratio'
|
112 |
)
|
113 |
],
|