EVREAL / app.py
ercanburak's picture
add crop since some input videos have black lines at rightmost column
7ce8504
raw
history blame
9.92 kB
import glob
import os
import shlex
import subprocess

import streamlit as st

from utils import get_configs, get_display_names, get_path_for_viz, get_video_height, get_text_str
# st.header("EVREAL - Event-based Video Reconstruction Evaluation and Analysis Library")
#
# paper_link = "https://arxiv.org/abs/2305.00434"
# code_link = "https://github.com/ercanburak/EVREAL"
# page_link = "https://ercanburak.github.io/evreal.html"
# instructions_video = "https://www.youtube.com/watch?v="
#
# st.markdown("Paper: " + paper_link, unsafe_allow_html=True)
# st.markdown("Code: " + code_link, unsafe_allow_html=True)
# st.markdown("Page: " + page_link, unsafe_allow_html=True)
# st.markdown("Please see this video for instructions on how to use this tool: " + instructions_video, unsafe_allow_html=True)
@st.cache_data(show_spinner="Retrieving results...")
def retrieve_results(selected_dataset, selected_sequence, selected_models, selected_metrics, selected_visualizations):
    """Tile the selected result videos into one grid video with ffmpeg.

    Layout of the grid:

    * Columns are the selected models. An extra "Ground Truth" column is
      appended when the dataset provides at least one ground-truth row.
    * Model rows are: the reconstruction, the selected ``both``-type
      visualizations, the selected metrics, then the ``model_only``
      visualizations.
    * Ground-truth rows are: the frame-based rows (only if the dataset has
      frames), then the event-based rows.

    Each input is scaled to the dataset width and cropped to even
    dimensions, then padded with a white border. Tiles in the first row get
    their column name drawn above them. Tiles in the first column of later
    rows get their row name drawn to their left. All tiles are combined with
    ffmpeg's ``xstack`` filter.

    Side effects: deletes ``temp_data/temp_*.mp4``, prints the ffmpeg command,
    and writes ``output.mp4`` in the current working directory. Reads the
    module-level global ``base_data_dir``.

    Args:
        selected_dataset: dataset config dict. ``width``, ``height`` and
            ``has_frames`` are read here; it is also passed to
            ``get_path_for_viz``.
        selected_sequence: sequence identifier passed to ``get_path_for_viz``.
        selected_models: list of model config dicts (``name`` and
            ``display_name`` are read here). The list is not modified.
        selected_metrics: list of metric config dicts, rendered as model rows.
        selected_visualizations: list of viz config dicts with ``viz_type``
            (``'gt_only'``, ``'model_only'`` or ``'both'``) and ``gt_type``
            (``'frame'`` or ``'event'``).

    Returns:
        The bytes of ``output.mp4``, or ``None`` if there is nothing to render
        or ffmpeg exits with a non-zero status.
        NOTE(review): ``st.cache_data`` presumably caches a ``None`` result
        too, so a failed run is not retried for the same inputs. Confirm.

    Raises:
        ValueError: if an expected input video file does not exist.
    """
    gt_only_viz = [viz for viz in selected_visualizations if viz['viz_type'] == 'gt_only']
    model_only_viz = [viz for viz in selected_visualizations if viz['viz_type'] == 'model_only']
    both_viz = [viz for viz in selected_visualizations if viz['viz_type'] == 'both']
    recon_viz = {"name": "recon", "display_name": "Reconstruction", "viz_type": "both", "gt_type": "frame"}
    ground_truth = {"name": "gt", "display_name": "Ground Truth", "model_id": "groundtruth"}

    # Rows rendered in every model column; the first row is always the reconstruction.
    model_viz = [recon_viz] + both_viz + selected_metrics + model_only_viz
    num_model_rows = len(model_viz)

    # Rows rendered in the ground-truth column: frame-based rows only when the
    # dataset has frames, followed by event-based rows.
    gt_viz = []
    if selected_dataset['has_frames']:
        gt_viz.append(recon_viz)
        gt_viz.extend([viz for viz in both_viz if viz['gt_type'] == 'frame'])
        gt_viz.extend([viz for viz in gt_only_viz if viz['gt_type'] == 'frame'])
    gt_viz.extend([viz for viz in both_viz if viz['gt_type'] == 'event'])
    gt_viz.extend([viz for viz in gt_only_viz if viz['gt_type'] == 'event'])
    num_gt_rows = len(gt_viz)
    num_rows = max(num_model_rows, num_gt_rows)

    # Work on a copy so the caller's list is never mutated. Otherwise a cache
    # hit, which skips this body, would behave differently from a cache miss.
    selected_models = list(selected_models)
    if gt_viz:
        selected_models.append(ground_truth)

    padding = 2
    font_size = 20
    num_cols = len(selected_models)
    # The crop forces even dimensions and shaves 2 px off the width. ffmpeg's
    # crop is centred by default, so this also drops the rightmost column,
    # where some input videos have a black line.
    crop_str = "crop=trunc(iw/2)*2-2:trunc(ih/2)*2,"
    # Then add a white border of `padding` pixels on every side.
    pad_str = "pad=ceil(iw/2)*2+{}:ceil(ih/2)*2+{}:{}:{}:white".format(padding*2, padding*2, padding, padding)

    # Remove leftover temporary files from previous runs.
    for f in glob.glob('temp_data/temp_*.mp4'):
        os.remove(f)

    w = selected_dataset["width"]
    h = selected_dataset["height"]
    input_filter_parts = []
    xstack_input_parts = []
    layout_parts = []
    video_paths = []
    # Height expression ("h<input index>") of the first tile placed in each row.
    row_heights = [""] * num_rows
    # Input indices of the tiles placed in the ground-truth column so far.
    gt_viz_indices = []
    # Horizontal space reserved for the row labels (rows after the first).
    if len(model_viz) > 1:
        left_pad = (font_size*0.7)*max(len(viz['display_name']) for viz in model_viz[1:]) + padding*2
    else:
        left_pad = 0

    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            vid_idx = len(video_paths)
            cur_model = selected_models[col_idx]
            if cur_model['name'] == "gt":
                if row_idx >= len(gt_viz):
                    continue
                video_path = get_path_for_viz(base_data_dir, selected_dataset, selected_sequence, cur_model,
                                              gt_viz[row_idx])
                if not os.path.isfile(video_path):
                    raise ValueError("Could not find video: " + video_path)
                gt_viz_indices.append(vid_idx)
            else:
                if row_idx >= len(model_viz):
                    continue
                video_path = get_path_for_viz(base_data_dir, selected_dataset, selected_sequence, cur_model,
                                              model_viz[row_idx])
                if not os.path.isfile(video_path):
                    raise ValueError("Could not find video: " + video_path)

            if row_heights[row_idx] == "":
                row_heights[row_idx] = "h{}".format(vid_idx)

            if row_idx == 0:
                # First row: extra space on top, with the column (model) name drawn in it.
                pad_height = font_size + padding*2
                pad_txt_str = ",pad={}:{}:0:{}:white".format(w + padding*2, h + font_size + padding*4, pad_height)
                text_str = get_text_str(pad_height, w, cur_model['display_name'], font_size)
                pad_txt_str = pad_txt_str + "," + text_str
            elif col_idx == 0:
                # First column of later rows: extra space on the left for the row label.
                pad_txt_str = ",pad={}:ih:{}:0:white".format(w + left_pad + padding*2, left_pad)
                if len(model_viz) > row_idx > 0:
                    text_str = get_text_str("h", left_pad, model_viz[row_idx]['display_name'], font_size)
                    pad_txt_str = pad_txt_str + "," + text_str
            else:
                pad_txt_str = ""

            input_filter_parts.append(
                "[{}:v]scale={}:-1,{}{}{}[v{}]".format(vid_idx, w, crop_str, pad_str, pad_txt_str, vid_idx))
            xstack_input_parts.append("[v{}]".format(vid_idx))
            video_paths.append(video_path)

            # x position: the label margin plus the widths of the first-row tiles
            # to the left (inputs 0..col_idx-1 are the first row). The first column
            # of later rows starts at 0 because its own pad already contains the margin.
            if row_idx == 0 or col_idx > 0:
                layout_w = "+".join([str(left_pad)] + ["w{}".format(i) for i in range(col_idx)])
            else:
                layout_w = "0"
            # y position: the sum of the heights of the tiles above. The
            # ground-truth column can have a different set of rows, so it stacks
            # only its own earlier tiles.
            if cur_model['name'] == "gt":
                layout_h = "+".join(["h{}".format(i) for i in gt_viz_indices[:-1]]) if row_idx > 0 else "0"
            else:
                layout_h = "+".join(row_heights[:row_idx]) if row_idx > 0 else "0"
            layout_parts.append(layout_w + "_" + layout_h)

    if not video_paths:
        # No columns to render; ffmpeg cannot produce an output without inputs.
        return None

    # The command runs through a shell, so quote the input paths.
    inputs_str = " ".join("-i " + shlex.quote(video_path) for video_path in video_paths)
    num_inputs = len(video_paths)
    input_scaling_str = ";".join(input_filter_parts)
    xstack_input_str = "".join(xstack_input_parts)
    layout_str = "|".join(layout_parts)
    # Extra output options, e.g. "-c:v libx264 -preset veryslow -crf 18 -c:a copy".
    opt = ""
    # Colour used for canvas areas that no tile covers.
    opt_fill = ":fill=white"
    if num_inputs > 1:
        ffmpeg_command_str = ("ffmpeg -y " + inputs_str + " -filter_complex \"" + input_scaling_str + ";"
                              + xstack_input_str + "xstack=inputs=" + str(num_inputs) + ":layout=" + layout_str
                              + opt_fill + "\"" + opt + " output.mp4")
    else:
        # A single input needs no xstack: drop the trailing "[v0]" output label so
        # the unlabeled filter output goes straight to output.mp4.
        idx = input_scaling_str.rfind("[")
        input_scaling_str = input_scaling_str[:idx]
        ffmpeg_command_str = ("ffmpeg -y " + inputs_str + " -filter_complex \"" + input_scaling_str + "\""
                              + opt + " output.mp4")
    print(ffmpeg_command_str)
    ret = subprocess.call(ffmpeg_command_str, shell=True)
    if ret != 0:
        return None
    with open('output.mp4', 'rb') as video_file:
        return video_file.read()
st.title("Result Analysis Tool")

# Root directory of the result videos; retrieve_results() reads this
# module-level name as a global, so it must stay defined here.
base_data_dir = "data"

dataset_cfg_path = os.path.join("cfg", "dataset")
model_cfg_path = os.path.join("cfg", "model")
metric_cfg_path = os.path.join("cfg", "metric")
viz_cfg_path = os.path.join("cfg", "viz")

datasets = get_configs(dataset_cfg_path)
models = get_configs(model_cfg_path)
metrics = get_configs(metric_cfg_path)
visualizations = get_configs(viz_cfg_path)

dataset_display_names = get_display_names(datasets)
model_display_names = get_display_names(models)
metric_display_names = get_display_names(metrics)
viz_display_names = get_display_names(visualizations)

# Widget selections are mapped back to configs by display name, so the names
# must be unique. Use an explicit raise rather than assert, because assert
# statements are stripped under `python -O`.
for config_kind, display_names in (("Dataset", dataset_display_names),
                                   ("Model", model_display_names),
                                   ("Metric", metric_display_names),
                                   ("Viz", viz_display_names)):
    if len(set(display_names)) != len(display_names):
        raise ValueError(config_kind + " display names are not unique")

col1, col2 = st.columns(2)
with col1:
    selected_dataset_name = st.selectbox('Select dataset', options=dataset_display_names)
    selected_dataset = [dataset for dataset in datasets if dataset['display_name'] == selected_dataset_name][0]
with col2:
    selected_sequence = st.selectbox('Select sequence', options=selected_dataset["sequences"].keys())

selected_model_names = st.multiselect('Select multiple methods to compare', model_display_names)
selected_models = [model for model in models if model['display_name'] in selected_model_names]

# Metric selection stays disabled until at least one method is selected.
disable_metrics = len(selected_models) == 0
tooltip_str = "Select at least one method to enable metric selection" if disable_metrics else ""

# Offer only metrics whose `no_ref` flag matches the selected dataset's.
usable_metrics = [metric for metric in metrics if metric['no_ref'] == selected_dataset['no_ref']]
usable_metric_display_names = get_display_names(usable_metrics)
selected_metric_names = st.multiselect('Select metrics to display', usable_metric_display_names,
                                       disabled=disable_metrics, help=tooltip_str)
selected_metrics = [metric for metric in usable_metrics if metric['display_name'] in selected_metric_names]

# Visualizations with gt_type 'frame' are hidden for datasets without frames.
if selected_dataset['has_frames']:
    usable_viz = visualizations
else:
    usable_viz = [viz for viz in visualizations if viz['gt_type'] != 'frame']
usable_viz_display_names = get_display_names(usable_viz)
selected_viz = st.multiselect('Select other visualizations to display', usable_viz_display_names)
# Map back through usable_viz (as is done for metrics) so that only
# visualizations that were actually offered can be selected.
selected_visualizations = [viz for viz in usable_viz if viz['display_name'] in selected_viz]

if not st.button('Get Results'):
    st.stop()

video_bytes = retrieve_results(selected_dataset, selected_sequence, selected_models, selected_metrics,
                               selected_visualizations)
if video_bytes is None:
    st.error("Error while generating video.")
    st.stop()
st.video(video_bytes)