import streamlit as st
from pathlib import Path
import pandas as pd
import altair as alt
st.set_page_config(
page_title="Public Leaderboard",
initial_sidebar_state="collapsed",
layout="wide", # This makes the app use the full width of the screen
)
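# Load one task's precomputed leaderboard CSV, indexed by team name.
# st.cache_data memoizes the result so reruns don't re-read the file.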
@st.cache_data
def load_results(task):
return pd.read_csv(task).set_index("team")
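# Read the "last updated" banner text (presumably written by the scoring job).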
@st.cache_data
def get_updated_time(file="updated.txt"):
    return Path(file).read_text()
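# Tally daily submission counts per status across all tasks. Assumes each
# {task}_submissions.csv has "datetime" and "status_reason" columns.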
@st.cache_data
def get_volume():
tasks = ["task1", "task2","task3", "practice"]
subs = pd.concat(
[pd.read_csv(f"{task}_submissions.csv") for task in tasks], ignore_index=True
)
subs["datetime"] = pd.DatetimeIndex(subs["datetime"])
subs["date"] = subs["datetime"].dt.date
subs = (
subs.groupby(["date", "status_reason"]).size().unstack().fillna(0).reset_index()
)
return subs
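# Illustrative (assumed) shape of a submissions CSV consumed above:
#   datetime,status_reason
#   2025-01-01T12:00:00,SUCCESS
#   2025-01-01T13:30:00,FAILED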
split = "public"
def show_leaderboard(results):
cols = [
"generated_accuracy",
"pristine_accuracy",
"balanced_accuracy",
"fail_rate",
"total_time",
]
column_config = {
"balanced_accuracy": st.column_config.ProgressColumn(
"Balanced Acc",
format="compact",
min_value=0,
pinned=True,
max_value=1.0,
width="medium",
),
"generated_accuracy": st.column_config.ProgressColumn(
"πŸ€– Acc",
format="compact",
min_value=0,
pinned=True,
max_value=1.0,
width="medium",
),
"pristine_accuracy": st.column_config.ProgressColumn(
"πŸ§‘β€πŸŽ€ Acc",
format="compact",
min_value=0,
pinned=True,
max_value=1.0,
width="medium",
),
"fail_rate": st.column_config.NumberColumn(
"❌ Fail Rate",
format="compact",
width="small",
),
"total_time": st.column_config.NumberColumn(
"πŸ•’ Inference Time (s)",
format="compact",
width="small",
),
}
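    # Add a progress-bar config for every per-source column, turning a column
    # named "<pred>_<source>" into an "<emoji> <source>" header.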
labels = {"pristine": "πŸ§‘β€πŸŽ€", "generated": "πŸ€–"}
for c in results[f"{split}_score"].columns:
if "accuracy" in c:
continue
if any(p in c for p in ["generated", "pristine"]):
s = c.split("_")
pred = s[0]
source = " ".join(s[1:])
column_config[c] = st.column_config.ProgressColumn(
labels[pred] + " " + source,
help=c,
format="compact",
min_value=0,
max_value=1.0,
)
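    # Bare string literals below are rendered as markdown by Streamlit "magic".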
"#### Summary"
st.dataframe(results[f"{split}_score"].loc[:, cols], column_config=column_config)
"#### Accuracy on πŸ€– Generated by Source"
cols = [
c
for c in results[f"{split}_score"].columns
if "generated" in c and "accuracy" not in c
]
st.dataframe(results[f"{split}_score"].loc[:, cols], column_config=column_config)
"#### Accuracy on πŸ§‘β€πŸŽ€ Pristine by Source"
cols = [
c
for c in results[f"{split}_score"].columns
if "pristine" in c and "accuracy" not in c
]
st.dataframe(results[f"{split}_score"].loc[:, cols], column_config=column_config)
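# ROC-style scatter: false-alarm rate (1 - pristine accuracy) on x, detection
# rate (generated accuracy) on y, point size encoding inference time.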
def make_roc(results):
results["FA"] = 1.0 - results["pristine_accuracy"]
chart = (
alt.Chart(results)
.mark_circle()
.encode(
x=alt.X("FA:Q", title="1 - πŸ§‘β€πŸŽ€ Acc"),
y=alt.Y("generated_accuracy:Q", title="πŸ€– Acc"),
color="team:N", # Color by categorical field
size=alt.Size(
"total_time:Q", title="πŸ•’ Inference Time", scale=alt.Scale(rangeMin=100)
), # Size by quantitative field
)
.properties(
width=400, height=300, title="Detection vs False Alarm vs Inference Time"
)
)
diag_line = (
alt.Chart(pd.DataFrame(dict(tpr=[0, 1], fpr=[0, 1])))
.mark_line(color="lightgray", strokeDash=[8, 4])
.encode(x="fpr", y="tpr")
)
return chart + diag_line
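# Accuracy vs latency trade-off, with a dashed chance line at 0.5 balanced accuracy.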
def make_acc(results):
# results["FA"] = 1. - results["pristine_accuracy"]
chart = (
alt.Chart(results)
.mark_circle(size=200)
.encode(
x=alt.X("total_time:Q", title="πŸ•’ Inference Time"),
y=alt.Y("balanced_accuracy:Q", title="Balanced Accuracy",scale=alt.Scale(domain=[0.4, 1])),
color="team:N", # Color by categorical field # Size by quantitative field
)
.properties(width=400, height=300, title="Inference Time vs Balanced Accuracy")
)
diag_line = (
alt.Chart(pd.DataFrame(dict(t=[0, results["total_time"].max()], y=[.5, .5])))
.mark_line(color="lightgray", strokeDash=[8, 4])
.encode(x="t", y="y")
)
return chart + diag_line
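# Team-by-source accuracy heatmap with in-cell value labels.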
@st.cache_data
def make_heatmap(results, label="generated", title=""):
    # Select this label's per-source columns and melt the wide frame into
    # long format: one row per (team, source, accuracy).
    df_long = results.set_index("team")
    team_order = df_long.index.tolist()  # keep the leaderboard's row order
    df_long = df_long.loc[:, [c for c in df_long.columns if c.startswith(label) and "accuracy" not in c]]
    df_long.columns = [c.replace(f"{label}_", "") for c in df_long.columns]
    if "none" in df_long.columns:
        df_long = df_long.drop(columns=["none"])
    df_long = df_long.reset_index().melt(id_vars="team", var_name="source", value_name="acc")
# Base chart for rectangles
    base = alt.Chart(df_long).encode(
        x=alt.X("source:O", title="Source", axis=alt.Axis(orient="top", labelAngle=-60)),
        y=alt.Y("team:O", title="Team", sort=team_order),
    )
    # Heatmap rectangles
    heatmap = base.mark_rect().encode(
        color=alt.Color("acc:Q", scale=alt.Scale(scheme="greens"), title=f"{label} Accuracy")
    )
# Text labels
    text = base.mark_text(baseline="middle", fontSize=16).encode(
        text=alt.Text("acc:Q", format=".2f"),
        color=alt.condition(
            alt.datum.acc < 0.5,  # black text on light (low) cells, white on dark (high) cells
            alt.value("black"),
            alt.value("white"),
        ),
    )
# Combine heatmap and text
chart = (heatmap + text).properties(
width=600,
height=500,
title=title
)
return chart
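# Render generated and pristine heatmaps, plus an augmentation heatmap when
# the results include aug_* columns.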
def get_heatmaps(temp):
    h1 = make_heatmap(temp, "generated", title="Accuracy by πŸ€– generated source")
    h2 = make_heatmap(temp, "pristine", title="Accuracy by πŸ§‘β€πŸŽ€ pristine source")
    st.altair_chart(h1, use_container_width=True)
    st.altair_chart(h2, use_container_width=True)
    if temp.columns.str.contains("aug", case=False).any():
        h3 = make_heatmap(temp, "aug", title="Accuracy by πŸ› οΈ augmentation method on πŸ€– generated data only")
        st.altair_chart(h3, use_container_width=True)
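# Lay out one task's results as Tables / Charts / Heatmaps tabs.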
def make_plots_for_task(task, split, best_only):
    # NOTE: best_only is accepted for future filtering but not used yet.
    results = {f"{split}_score": load_results(f"{task}.csv")}
    temp = results[f"{split}_score"].reset_index()
    t1, t2, t3 = st.tabs(["Tables", "Charts", "Heatmaps"])
with t1:
        show_leaderboard(results)
with t2:
st.altair_chart(make_roc(temp) | make_acc(temp), use_container_width=False)
with t3:
get_heatmaps(temp)
split = "public"
updated = get_updated_time()
st.markdown(updated)
# st.markdown("#### Detailed Public Leaderboard")
# st.markdown("[SAFE: Synthetic Audio Forensics Evaluation Challenge](https://stresearch.github.io/SAFE/)")
t1, t2, t3, t4 = st.tabs(["**Task 1**", "**Task 2**", "**Task 3**", "**Submission Volume**"])
with t1:
"Detection of Generated Audio. Audio files are unmodified from the original output from the models or the pristine sources."
make_plots_for_task("task1",split,True)
with t2:
"Detection of Processed Audio. Audio files will be compressed and resampled. Only the geneated files are augmented."
make_plots_for_task("task2",split,True)
with t3:
"Detection of Laundered Audio. Audio files will be laundered to bypass detection. Only the geneated files are laundered,"
make_plots_for_task("task3",split,True)
with t4:
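    # Daily submission volume, stacked by final status.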
subs = get_volume()
st.bar_chart(subs, x="date", y=["SUCCESS", "FAILED"], stack=True)