# EVREAL / app.py
# ercanburak: Merge branch 'main' of https://huggingface.co/spaces/ercanburak/EVREAL
# (commit 330bc14)
import os
import subprocess
import streamlit as st
from utils import get_configs, get_display_names, get_path_for_viz, get_text_str, get_meta_path
# A "?header=false" query parameter suppresses the title, the links and the
# citation/acknowledgement/contact sections rendered further below.
query_params = st.experimental_get_query_params()
disable_header = query_params.get("header", [None])[0] == "false"

if not disable_header:
    st.title("EVREAL - Event-based Video Reconstruction Evaluation and Analysis Library")

    paper_link = "https://arxiv.org/abs/2305.00434"
    code_link = "https://github.com/ercanburak/EVREAL"
    page_link = "https://ercanburak.github.io/evreal.html"

    # One markdown line per external resource, in a fixed order.
    for link_label, link_url in (("Paper", paper_link), ("Code", code_link), ("Page", page_link)):
        st.markdown("**{}**: {}".format(link_label, link_url), unsafe_allow_html=True)

# NOTE(review): not referenced anywhere in the visible code; judging by its
# name it exists only for the hosting platform -- confirm before removing.
dummy_string_to_make_huggingface_happy = "ercanburak/evreal_model"
@st.cache_data(show_spinner="Retrieving results...")
def retrieve_results(selected_dataset, selected_sequence, selected_models, selected_metrics, selected_visualizations):
    """Tile the requested result videos into one grid video and return its bytes.

    The grid has one column per selected model, plus a ground-truth column
    whenever at least one ground-truth visualization applies. Model columns
    stack, top to bottom: the reconstruction, the 'both' visualizations, the
    selected metrics and the 'model_only' visualizations. The ground-truth
    column stacks its frame-based rows (only if the dataset has frames)
    followed by its event-based rows.

    The grid is assembled by running ``ffmpeg`` through the shell; the result
    is written to ``output.mp4`` in the current working directory.

    Args:
        selected_dataset: Dataset config dict; ``'has_frames'``, ``'width'``
            and ``'height'`` are read here, and the dict is forwarded to the
            path helpers.
        selected_sequence: Sequence identifier forwarded to the path helpers.
        selected_models: List of model config dicts (``'name'`` and
            ``'display_name'`` are read here). The list is not modified.
        selected_metrics: List of metric config dicts (``'display_name'`` is
            read here).
        selected_visualizations: List of visualization config dicts
            (``'viz_type'``, ``'gt_type'`` and ``'display_name'`` are read
            here).

    Returns:
        The bytes of the generated ``output.mp4``, or ``None`` when there is
        nothing to render or ffmpeg exits with a non-zero status.

    Raises:
        ValueError: If a required input video does not exist on disk.
    """
    # The meta video is switched off by this hard-coded flag; the branches
    # guarded by it below are currently never taken.
    meta_enabled = False
    meta_path = get_meta_path(base_data_dir, selected_dataset, selected_sequence)
    if meta_enabled and not os.path.isfile(meta_path):
        raise ValueError("Meta file not found: {}".format(meta_path))

    gt_only_viz = [viz for viz in selected_visualizations if viz['viz_type'] == 'gt_only']
    model_only_viz = [viz for viz in selected_visualizations if viz['viz_type'] == 'model_only']
    both_viz = [viz for viz in selected_visualizations if viz['viz_type'] == 'both']

    recon_viz = {"name": "recon", "display_name": "Reconstruction", "viz_type": "both", "gt_type": "frame"}
    ground_truth = {"name": "gt", "display_name": "Ground Truth", "model_id": "groundtruth"}

    # Rows shown under every model column; the reconstruction always comes first.
    model_viz = [recon_viz] + both_viz + selected_metrics + model_only_viz
    num_model_rows = len(model_viz)

    # Rows shown under the ground-truth column: frame-based rows need a
    # dataset with frames, event-based rows are always eligible.
    gt_viz = []
    if selected_dataset['has_frames']:
        gt_viz.append(recon_viz)
        gt_viz.extend([viz for viz in both_viz if viz['gt_type'] == 'frame'])
        gt_viz.extend([viz for viz in gt_only_viz if viz['gt_type'] == 'frame'])
    gt_viz.extend([viz for viz in both_viz if viz['gt_type'] == 'event'])
    gt_viz.extend([viz for viz in gt_only_viz if viz['gt_type'] == 'event'])
    num_gt_rows = len(gt_viz)

    num_rows = max(num_model_rows, num_gt_rows)

    # Copy before extending so the caller's list is left untouched.
    selected_models = list(selected_models)
    if gt_viz:
        selected_models.append(ground_truth)

    padding = 2
    font_size = 20
    meta_width = 250
    meta_height = 70
    num_cols = len(selected_models)

    # Per-input filter fragments: a crop, then a white border of `padding`
    # pixels on every side.
    crop_str = "crop=trunc(iw/2)*2-2:trunc(ih/2)*2,"
    pad_str = "pad=ceil(iw/2)*2+{}:ceil(ih/2)*2+{}:{}:{}:white".format(padding*2, padding*2, padding, padding)

    w = selected_dataset["width"]
    h = selected_dataset["height"]

    # The base font size is scaled linearly with the dataset width (20 at width 240).
    font_size_scale = w / 240.0
    font_size = int(font_size * font_size_scale)

    input_filter_parts = []
    xstack_input_parts = []
    layout_parts = []
    video_paths = []
    # row_heights[r] holds the "h<idx>" token of the first video placed in row r.
    row_heights = [""]*num_rows
    # Input indices of the ground-truth videos placed so far, in row order.
    gt_viz_indices = []

    # Width of the strip left of the first column that holds the row labels,
    # sized from the longest label among rows 1.. (row 0 is labelled on top).
    if len(model_viz) > 1:
        left_pad = int(font_size*0.8) * max(len(viz['display_name']) for viz in model_viz[1:]) + padding*2
    else:
        left_pad = 0

    if meta_enabled:  # add meta video
        if left_pad < meta_width:
            left_pad = meta_width
        video_paths.append(meta_path)
        xstack_input_parts.append("[0:v]")
        meta_h_offset = (h - meta_height) / 2
        layout_parts.append("0_{}".format(meta_h_offset))

    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            vid_idx = len(video_paths)
            cur_model = selected_models[col_idx]

            # Pick the video for this cell; cells past the end of a column's
            # own row list are simply left empty.
            if cur_model['name'] == "gt":
                if row_idx < len(gt_viz):
                    video_path = get_path_for_viz(base_data_dir, selected_dataset, selected_sequence, cur_model, gt_viz[row_idx])
                    if not os.path.isfile(video_path):
                        raise ValueError("Could not find video: " + video_path)
                    gt_viz_indices.append(vid_idx)
                else:
                    continue
            else:
                if row_idx < len(model_viz):
                    video_path = get_path_for_viz(base_data_dir, selected_dataset, selected_sequence, cur_model, model_viz[row_idx])
                    if not os.path.isfile(video_path):
                        raise ValueError("Could not find video: " + video_path)
                else:
                    continue

            if row_heights[row_idx] == "":
                row_heights[row_idx] = "h{}".format(vid_idx)

            if row_idx == 0:
                # Top row: add a strip above the video carrying the column title.
                pad_height = font_size+padding*2
                pad_txt_str = ",pad={}:{}:0:{}:white".format(w+padding*2, h+font_size+padding*4, pad_height)
                text_str = get_text_str(pad_height, w, cur_model['display_name'], font_size)
                pad_txt_str = pad_txt_str + "," + text_str
            elif col_idx == 0:
                # First column of later rows: add the left strip; it carries a
                # row label only when this row exists in the model row list.
                pad_txt_str = ",pad={}:ih:{}:0:white".format(w + left_pad + padding * 2, left_pad)
                if len(model_viz) > row_idx > 0:
                    text_str = get_text_str("h", left_pad, model_viz[row_idx]['display_name'], font_size)
                    pad_txt_str = pad_txt_str + "," + text_str
            else:
                pad_txt_str = ""

            input_filter_part = "[{}:v]scale={}:-1,{}{}{}[v{}]".format(vid_idx, w, crop_str, pad_str, pad_txt_str, vid_idx)
            input_filter_parts.append(input_filter_part)
            xstack_input_part = "[v{}]".format(vid_idx)
            xstack_input_parts.append(xstack_input_part)
            video_paths.append(video_path)

            # Horizontal offset of this cell inside the xstack layout.
            if row_idx == 0 or col_idx > 0:
                layout_w_parts = [str(left_pad)] + ["w{}".format(i) for i in range(col_idx)]
                layout_w = "+".join(layout_w_parts)
            else:
                # Only reachable with col_idx == 0 on a later row: that cell
                # starts at the left edge because its own padding provides the
                # label strip.
                layout_w = "0"

            # Vertical offset: the ground-truth column accumulates the heights
            # of its own earlier videos, model columns use the per-row heights.
            if cur_model['name'] == "gt":
                layout_h = "+".join(["h{}".format(i) for i in gt_viz_indices[:-1]]) if row_idx > 0 else "0"
            else:
                layout_h = "+".join(row_heights[:row_idx]) if row_idx > 0 else "0"

            layout_part = layout_w + "_" + layout_h
            layout_parts.append(layout_part)

    # Nothing was selected for rendering: report failure without invoking ffmpeg.
    if not video_paths:
        return None

    inputs_str = " ".join(["-i " + video_path for video_path in video_paths])
    num_inputs = len(video_paths)

    input_scaling_str = ";".join(input_filter_parts)
    xstack_input_str = "".join(xstack_input_parts)
    layout_str = "|".join(layout_parts)

    # opt = "-c:v libx264 -preset veryslow -crf 18 -c:a copy"
    opt = ""
    # opt_fill = ":fill=black"
    opt_fill = ":fill=white"
    # opt_fill = ""
    if num_inputs > 1:
        ffmpeg_command_str = "ffmpeg -y " + inputs_str + " -filter_complex \"" + input_scaling_str + ";" + xstack_input_str + "xstack=inputs=" + str(num_inputs) + ":layout=" + layout_str + opt_fill + "\"" + opt + " output.mp4"
    else:
        # Single input: there is no xstack stage, so cut the filter string at
        # its last "[" to drop the trailing output label.
        idx = input_scaling_str.rfind("[")
        input_scaling_str = input_scaling_str[:idx]
        ffmpeg_command_str = "ffmpeg -y " + inputs_str + " -filter_complex \"" + input_scaling_str + "\"" + opt + " output.mp4"

    print(ffmpeg_command_str)
    # NOTE(review): the command is a shell string built from config-derived
    # paths; paths containing spaces or shell metacharacters would break it.
    ret = subprocess.call(ffmpeg_command_str, shell=True)
    if ret != 0:
        return None

    # Read the result through a context manager so the handle is always closed.
    with open('output.mp4', 'rb') as video_file:
        return video_file.read()
def display_citation():
    """Render the "Citation" section: heading, a one-line request and the BibTeX entry."""
    bibtex_block = """
```
@inproceedings{ercan2023evreal,
title={{EVREAL}: Towards a Comprehensive Benchmark and Analysis Suite for Event-based Video Reconstruction},
author={Ercan, Burak and Eker, Onur and Erdem, Aykut and Erdem, Erkut},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
month={June},
year={2023},
pages={3942-3951}}
```
"""
    # Emitted top to bottom, one markdown element per entry.
    section_parts = (
        "## Citation",
        "If you find this tool useful, please cite the following paper:",
        bibtex_block,
    )
    for markdown_text in section_parts:
        st.markdown(markdown_text)
def display_acknowledgements():
    """Render the "Acknowledgements" heading followed by the funding statement."""
    heading = "## Acknowledgements"
    funding_statement = "This work was supported in part by KUIS AI Center Research Award, TUBITAK-1001 Program Award No. 121E454, and BAGEP 2021 Award of the Science Academy to A. Erdem."
    for markdown_text in (heading, funding_statement):
        st.markdown(markdown_text)
def display_footer():
    """Render the "Contact" heading followed by the contact line."""
    heading = "## Contact"
    contact_line = "For questions and comments, please contact [Burak Ercan](mailto:[email protected])."
    for markdown_text in (heading, contact_line):
        st.markdown(markdown_text)
if not disable_header:
    st.header("Result Analysis Tool")

# Root directory handed to the path helpers when locating result videos.
base_data_dir = "data"

# One configuration directory per kind of selectable item.
dataset_cfg_path = os.path.join("cfg", "dataset")
model_cfg_path = os.path.join("cfg", "model")
metric_cfg_path = os.path.join("cfg", "metric")
viz_cfg_path = os.path.join("cfg", "viz")

datasets = get_configs(dataset_cfg_path)
models = get_configs(model_cfg_path)
metrics = get_configs(metric_cfg_path)
visualizations = get_configs(viz_cfg_path)

dataset_display_names = get_display_names(datasets)
model_display_names = get_display_names(models)
metric_display_names = get_display_names(metrics)
viz_display_names = get_display_names(visualizations)

# Widget selections are mapped back to configs through their display names,
# so duplicates within one kind are rejected up front.
for config_kind, display_names in (
    ("Dataset", dataset_display_names),
    ("Model", model_display_names),
    ("Metric", metric_display_names),
    ("Viz", viz_display_names),
):
    assert len(set(display_names)) == len(display_names), "{} display names are not unique".format(config_kind)

col1, col2 = st.columns(2)

# Preselected entries; the first option is used when a default is unavailable.
default_dataset = "ECD"
default_sequence = "dynamic_6dof"

with col1:
    if default_dataset in dataset_display_names:
        default_dataset_index = dataset_display_names.index(default_dataset)
    else:
        default_dataset_index = 0
    selected_dataset_name = st.selectbox('Select dataset:', options=dataset_display_names, index=default_dataset_index)
    selected_dataset = [dataset for dataset in datasets if dataset['display_name'] == selected_dataset_name][0]

with col2:
    dataset_sequences = list(selected_dataset["sequences"].keys())
    if default_sequence in dataset_sequences:
        default_sequence_index = dataset_sequences.index(default_sequence)
    else:
        default_sequence_index = 0
    selected_sequence = st.selectbox('Select sequence:', options=dataset_sequences, index=default_sequence_index)

selected_model_names = st.multiselect('Select methods to compare:', model_display_names)
selected_models = [models[model_display_names.index(name)] for name in selected_model_names]

# Metric selection stays disabled until at least one method has been chosen.
disable_metrics = not selected_models
tooltip_str = "Select at least one method to enable metric selection" if disable_metrics else ""

# Offer only the metrics whose 'no_ref' flag equals the dataset's.
usable_metrics = [metric for metric in metrics if metric['no_ref'] == selected_dataset['no_ref']]
usable_metric_display_names = get_display_names(usable_metrics)

selected_metric_names = st.multiselect('Select quantitative metrics to display:', usable_metric_display_names,
                                       disabled=disable_metrics, help=tooltip_str)
selected_metrics = [metrics[metric_display_names.index(name)] for name in selected_metric_names]

# Datasets without frames cannot show visualizations whose ground truth is a frame.
if selected_dataset['has_frames']:
    usable_viz = visualizations
else:
    usable_viz = [viz for viz in visualizations if viz['gt_type'] != 'frame']
usable_viz_display_names = get_display_names(usable_viz)

selected_viz = st.multiselect('Select other visualizations to display:', usable_viz_display_names)
selected_visualizations = [visualizations[viz_display_names.index(name)] for name in selected_viz]

# Sections shown at the bottom of the page unless the header is disabled.
footer_sections = (display_citation, display_acknowledgements, display_footer)

# Until the button is pressed, show only the footer and end this script run.
if not st.button('Get Results'):
    if not disable_header:
        for render_section in footer_sections:
            render_section()
    st.stop()

video_bytes = retrieve_results(selected_dataset, selected_sequence, selected_models, selected_metrics, selected_visualizations)

if video_bytes is None:
    st.error("Error while generating video.")
    st.stop()

st.video(video_bytes)

if selected_metrics:
    st.write("Note: For the selected metrics, the instantaneous values are indicated to the upper right of each subplot, "
             "whereas the average value over the sequence is indicated in parenthesis next to it.")

if not disable_header:
    for render_section in footer_sections:
        render_section()