import streamlit as st
from pathlib import Path

import altair as alt
import pandas as pd

st.set_page_config(
    page_title="Public Leaderboard",
    initial_sidebar_state="collapsed",
    layout="wide",
)

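# Files this app expects alongside it (inferred from the loaders below):
#   task1.csv, task2.csv, task3.csv         per-task results, one row per team
#   task1_submissions.csv, task2_submissions.csv,
#   task3_submissions.csv, practice_submissions.csv   submission logs
#   updated.txt                             last-updated timestamp
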
@st.cache_data
def load_results(task):
    """Load a per-task results CSV, indexed by team name."""
    return pd.read_csv(task).set_index("team")

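# A minimal sketch of the results CSV this loader assumes; the column names are
# taken from the metrics used further down, and the values are illustrative only:
#
#   team,generated_accuracy,pristine_accuracy,balanced_accuracy,fail_rate,total_time
#   team_a,0.91,0.88,0.895,0.01,123.4
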
@st.cache_data
def get_updated_time(file="updated.txt"):
    """Return the last-updated timestamp stored on disk."""
    return Path(file).read_text()

@st.cache_data
def get_volume():
    """Count submissions per day and status across all tasks."""
    tasks = ["task1", "task2", "task3", "practice"]
    subs = pd.concat(
        [pd.read_csv(f"{task}_submissions.csv") for task in tasks], ignore_index=True
    )
    subs["datetime"] = pd.to_datetime(subs["datetime"])
    subs["date"] = subs["datetime"].dt.date
    # One row per day, one column per status value.
    subs = (
        subs.groupby(["date", "status_reason"]).size().unstack().fillna(0).reset_index()
    )
    return subs

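# Each submissions CSV is assumed to hold one row per submission with at least a
# "datetime" and a "status_reason" column; the unstacked status values (e.g.
# SUCCESS, FAILED) become the series plotted in the Submission Volume tab.
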
split = "public"

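# Per-source metric columns are assumed to be named "<prediction>_<source>"
# (e.g. a hypothetical "generated_vocoder_x" would render as "π€ vocoder x");
# show_leaderboard builds its column config from that convention.
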
def show_leaderboard(results):
    cols = [
        "generated_accuracy",
        "pristine_accuracy",
        "balanced_accuracy",
        "fail_rate",
        "total_time",
    ]

    column_config = {
        "balanced_accuracy": st.column_config.ProgressColumn(
            "Balanced Acc",
            format="compact",
            min_value=0,
            max_value=1.0,
            pinned=True,
            width="medium",
        ),
        "generated_accuracy": st.column_config.ProgressColumn(
            "π€ Acc",
            format="compact",
            min_value=0,
            max_value=1.0,
            pinned=True,
            width="medium",
        ),
        "pristine_accuracy": st.column_config.ProgressColumn(
            "π§βπ€ Acc",
            format="compact",
            min_value=0,
            max_value=1.0,
            pinned=True,
            width="medium",
        ),
        "fail_rate": st.column_config.NumberColumn(
            "β Fail Rate",
            format="compact",
            width="small",
        ),
        "total_time": st.column_config.NumberColumn(
            "π Inference Time (s)",
            format="compact",
            width="small",
        ),
    }

    labels = {"pristine": "π§βπ€", "generated": "π€"}

    # Add a progress column for every per-source metric column.
    for c in results[f"{split}_score"].columns:
        if "accuracy" in c:
            continue
        if any(p in c for p in ["generated", "pristine"]):
            s = c.split("_")
            pred = s[0]
            source = " ".join(s[1:])
            column_config[c] = st.column_config.ProgressColumn(
                labels[pred] + " " + source,
                help=c,
                format="compact",
                min_value=0,
                max_value=1.0,
            )

    "#### Summary"
    st.dataframe(results[f"{split}_score"].loc[:, cols], column_config=column_config)

    "#### Accuracy on π€ Generated by Source"
    cols = [
        c
        for c in results[f"{split}_score"].columns
        if "generated" in c and "accuracy" not in c
    ]
    st.dataframe(results[f"{split}_score"].loc[:, cols], column_config=column_config)

    "#### Accuracy on π§βπ€ Pristine by Source"
    cols = [
        c
        for c in results[f"{split}_score"].columns
        if "pristine" in c and "accuracy" not in c
    ]
    st.dataframe(results[f"{split}_score"].loc[:, cols], column_config=column_config)

def make_roc(results):
    """Scatter π€ detection accuracy against the false-alarm rate, sized by inference time."""
    results["FA"] = 1.0 - results["pristine_accuracy"]
    chart = (
        alt.Chart(results)
        .mark_circle()
        .encode(
            x=alt.X("FA:Q", title="1 - π§βπ€ Acc"),
            y=alt.Y("generated_accuracy:Q", title="π€ Acc"),
            color="team:N",
            size=alt.Size(
                "total_time:Q", title="π Inference Time", scale=alt.Scale(rangeMin=100)
            ),
        )
        .properties(
            width=400, height=300, title="Detection vs False Alarm vs Inference Time"
        )
    )

    # Chance diagonal: a random detector sits on this line.
    diag_line = (
        alt.Chart(pd.DataFrame(dict(tpr=[0, 1], fpr=[0, 1])))
        .mark_line(color="lightgray", strokeDash=[8, 4])
        .encode(x="fpr", y="tpr")
    )

    return chart + diag_line

def make_acc(results):
    """Scatter balanced accuracy against total inference time."""
    chart = (
        alt.Chart(results)
        .mark_circle(size=200)
        .encode(
            x=alt.X("total_time:Q", title="π Inference Time"),
            y=alt.Y(
                "balanced_accuracy:Q",
                title="Balanced Accuracy",
                scale=alt.Scale(domain=[0.4, 1]),
            ),
            color="team:N",
        )
        .properties(width=400, height=300, title="Inference Time vs Balanced Accuracy")
    )

    # Chance line at 0.5 balanced accuracy.
    diag_line = (
        alt.Chart(pd.DataFrame(dict(t=[0, results["total_time"].max()], y=[0.5, 0.5])))
        .mark_line(color="lightgray", strokeDash=[8, 4])
        .encode(x="t", y="y")
    )
    return chart + diag_line

@st.cache_data
def make_heatmap(results, label="generated", title=""):
    """Team-by-source accuracy heatmap over columns prefixed with `label`."""
    df_long = results.set_index("team")
    # Preserve the incoming team order (results has a plain integer index here,
    # so the team names must come from df_long's index, not results.index).
    team_order = df_long.index.tolist()
    df_long = df_long.loc[
        :, [c for c in df_long.columns if c.startswith(label) and "accuracy" not in c]
    ]
    df_long.columns = [c.replace(f"{label}_", "") for c in df_long.columns]

    if "none" in df_long.columns:
        df_long = df_long.drop(columns=["none"])

    # Long format: one row per (team, source) pair for Altair.
    df_long = df_long.reset_index().melt(
        id_vars="team", var_name="source", value_name="acc"
    )

    base = alt.Chart(df_long).encode(
        x=alt.X("source:O", title="Source", axis=alt.Axis(orient="top", labelAngle=-60)),
        y=alt.Y("team:O", title="Team", sort=team_order),
    )

    heatmap = base.mark_rect().encode(
        color=alt.Color("acc:Q", scale=alt.Scale(scheme="greens"), title=f"{label} Accuracy")
    )

    # Print the value in each cell, flipping text color for readability on light cells.
    text = base.mark_text(baseline="middle", fontSize=16).encode(
        text=alt.Text("acc:Q", format=".2f"),
        color=alt.condition(alt.datum.acc < 0.5, alt.value("black"), alt.value("white")),
    )

    return (heatmap + text).properties(width=600, height=500, title=title)

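# Shape of the long-form frame make_heatmap feeds to Altair (source names here
# are illustrative, not real columns):
#
#   team,source,acc
#   team_a,tts_x,0.93
#   team_a,tts_y,0.88
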
def get_heatmaps(temp):
    h1 = make_heatmap(temp, "generated", title="Accuracy by π€ generated source")
    h2 = make_heatmap(temp, "pristine", title="Accuracy by π§βπ€ pristine source")
    st.altair_chart(h1, use_container_width=True)
    st.altair_chart(h2, use_container_width=True)
    # Some tasks also report accuracy per augmentation method.
    if temp.columns.str.contains("aug", case=False).any():
        h3 = make_heatmap(
            temp, "aug", title="Accuracy by π οΈ augmentation method on π€ generated data only"
        )
        st.altair_chart(h3, use_container_width=True)

def make_plots_for_task(task, split):
    results = {f"{split}_score": load_results(f"{task}.csv")}
    temp = results[f"{split}_score"].reset_index()

    t1, t2, t3 = st.tabs(["Tables", "Charts", "Heatmap"])
    with t1:
        show_leaderboard(results)
    with t2:
        # Altair's `|` operator places the two charts side by side.
        st.altair_chart(make_roc(temp) | make_acc(temp), use_container_width=False)
    with t3:
        get_heatmaps(temp)

updated = get_updated_time()
st.markdown(updated)

t1, t2, t3, t4 = st.tabs(["**Task 1**", "**Task 2**", "**Task 3**", "**Submission Volume**"])
with t1:
    "Detection of Generated Audio. Audio files are unmodified from the original model outputs or the pristine sources."
    make_plots_for_task("task1", split)

with t2:
    "Detection of Processed Audio. Audio files will be compressed and resampled. Only the generated files are augmented."
    make_plots_for_task("task2", split)

with t3:
    "Detection of Laundered Audio. Audio files will be laundered to bypass detection. Only the generated files are laundered."
    make_plots_for_task("task3", split)

with t4:
    subs = get_volume()
    # Stacked daily bars of successful vs failed submissions.
    st.bar_chart(subs, x="date", y=["SUCCESS", "FAILED"], stack=True)