# Hugging Face Space status when this file was captured: Runtime error
import evaluate | |
import numpy as np | |
import pandas as pd | |
import ast | |
import json | |
import gradio as gr | |
from evaluate.utils import launch_gradio_widget | |
from ece import ECE | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
# Global plotting configuration: seaborn "white" style at paper scale, and
# high-resolution (300 dpi) figures.
sns.set_style("white")
sns.set_context("paper", font_scale=1)
plt.rcParams["figure.dpi"] = 300

# Use the non-interactive Agg backend; presumably needed because Gradio calls
# the handlers from worker threads, see
# https://stackoverflow.com/questions/14694408/runtimeerror-main-thread-is-not-in-main-loop
plt.switch_backend("agg")
# Metric hyper-parameter controls, in the same order as the parameters that
# follow `data` in compute_and_plot(data, n_bins, bin_range, scheme, proxy, p).
sliders = [
    # At least one bin is required. step=1 is explicit because Gradio derives a
    # fractional step from the (1, 100) span otherwise, and n_bins must be integral.
    gr.Slider(1, 100, value=10, step=1, label="n_bins"),
    # DEV: needs a two-handle (range) slider; hidden until one is available.
    gr.Slider(0, 100, value=None, label="bin_range", visible=False),
    gr.Dropdown(choices=["equal-range", "equal-mass"], value="equal-range", label="scheme"),
    gr.Dropdown(choices=["upper-edge", "center"], value="upper-edge", label="proxy"),
    gr.Dropdown(choices=[1, 2, np.inf], value=1, label="p"),
]
# Initial value of every control, in the same order as `sliders`.
slider_defaults = [slider.value for slider in sliders]
# Example data: three samples, each a probability vector over three classes
# plus the sample's reference label.
df = dict()
df["predictions"] = [[0.6, 0.2, 0.2], [0, 0.95, 0.05], [0.7, 0.1, 0.2]]
df["references"] = [0, 1, 2]

# The same three samples as table rows: [predictions, references].
example_rows = [
    [[0.6, 0.2, 0.2], 0],
    [[0.7, 0.1, 0.2], 2],
    [[0, 0.95, 0.05], 1],
]

# Input table. Uses the top-level gr.Dataframe (the gr.inputs namespace is
# deprecated in Gradio 3 and removed in Gradio 4) and passes the initial rows
# through the constructor instead of overwriting `.value` afterwards, so Gradio
# processes them like any other value.
component = gr.Dataframe(
    headers=["predictions", "references"],
    col_count=2,
    datatype="number",
    type="pandas",
    value=example_rows,
)

# One example (table rows followed by every control default), intended for
# gr.Interface(examples=...). Examples carry input values, not component objects.
sample_data = [[example_rows] + slider_defaults]  # json.dumps(df)
# Calibration metric instance (ECE from the `ece` module) used by the handler.
metric = ECE()

# Alternative kept for reference: load the published module from the Hub and
# let `evaluate` build its stock widget:
#     module = evaluate.load("jordyvl/ece")
#     launch_gradio_widget(module)

# Instead of that widget, the code below swaps in custom inputs and a custom
# compute function.
def reliability_plot(results):
    """Draw a reliability diagram above a per-bin confidence histogram.

    Args:
        results: detailed metric output (as returned by
            ``metric._compute(..., detail=True)``). Keys read here:
            ``"y_bar"`` (one value per bin, presumably the bin's confidence
            proxy), ``"p_bar"`` (per-bin empirical values; NaN entries are
            skipped), ``"bin_freq"`` (frequencies of the bins whose ``p_bar``
            is not NaN), and ``"accuracy"`` / ``"p_bar_cont"`` (drawn as
            vertical reference lines).

    Returns:
        The matplotlib Figure. It is already removed from pyplot's figure
        registry, so repeated calls do not accumulate open figures, but it can
        still be rendered/saved by the caller.
    """
    fig = plt.figure()
    # Top two thirds: reliability diagram; bottom third: bin frequencies.
    ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2, fig=fig)
    ax2 = plt.subplot2grid((3, 1), (2, 0), fig=fig)

    y_bar = np.asarray(results["y_bar"], dtype=float)
    p_bar = np.asarray(results["p_bar"], dtype=float)
    n_bins = len(y_bar)

    # Identity ("perfectly calibrated") line from 0 to the last bin value.
    ranged = np.linspace(0.0, y_bar[-1], n_bins)
    ax1.plot(ranged, ranged, color="darkgreen", ls="dotted", label="Perfect")

    # Scatter the frequencies of the non-empty bins into a full-length array
    # (empty bins get 0) and draw them as a weighted histogram.
    # NOTE(review): p_bar[:-1] drops the last entry; presumably p_bar carries
    # one more element than y_bar — confirm against the ECE implementation.
    non_empty = np.where(~np.isnan(p_bar[:-1]))[0]
    bin_freqs = np.zeros(n_bins)
    bin_freqs[non_empty] = results["bin_freq"]
    ax2.hist(y_bar, y_bar, weights=bin_freqs)

    # Bar width = distance to the previous bin value (the first bin starts at
    # 0). Bars extend leftwards from their value (negative width with
    # align="edge"), so each bar covers its bin when the value is the bin's
    # upper edge (the app's default "upper-edge" proxy). The previous width,
    # -ranged[j], grew with the bin index: bar 0 had zero width and later bars
    # stacked over their neighbours.
    widths = np.diff(y_bar, prepend=0.0)
    for value, empirical, width in zip(y_bar, p_bar, widths):
        if np.isnan(empirical):
            continue  # bin without samples: nothing to draw
        ax1.bar([value], height=[empirical], width=-width, align="edge", color="lightblue")

    acc_plt = ax2.axvline(
        x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy"
    )
    conf_plt = ax2.axvline(
        x=results["p_bar_cont"], ls="dotted", lw=3, c="#444", label="Avg. confidence"
    )

    # Bin differences
    ax1.set_ylabel("Conditional Expectation")
    ax1.set_ylim([-0.05, 1.05])  # TODO: derive from the bin range
    ax1.legend(loc="lower right")
    ax1.set_title("Reliability Diagram")

    # Bin frequencies
    ax2.set_xlabel("Confidence")
    ax2.set_ylabel("Count")
    ax2.legend(handles=[acc_plt, conf_plt], loc="upper left")

    fig.tight_layout()
    # This runs once per request. Detach the figure from pyplot so the
    # registry does not keep every figure alive; the Figure object itself
    # remains usable for rendering.
    plt.close(fig)
    return fig
def compute_and_plot(data, n_bins, bin_range, scheme, proxy, p):
    """Compute the ECE of the table contents and draw its reliability diagram.

    Args:
        data: table with "predictions" and "references" columns (normally the
            pandas DataFrame produced by the input component). Prediction cells
            that are not already lists are parsed with ``ast.literal_eval``,
            which only accepts Python literals.
        n_bins: number of bins; coerced to ``int`` in case the slider delivers
            a float.
        bin_range: currently unused (the range control is hidden/not wired up).
        scheme, proxy, p: forwarded unchanged to the metric.

    Returns:
        ``(ece, figure)`` for the "ECE" textbox and the reliability-diagram plot.

    Raises:
        ValueError: if no rows remain after dropping rows with missing values.
    """
    # DEV: check on invalid datatypes with better warnings
    if isinstance(data, pd.DataFrame):
        # Rebind instead of dropna(inplace=True) so the caller's frame is untouched.
        data = data.dropna()

    predictions = [
        prediction if isinstance(prediction, list) else ast.literal_eval(prediction)
        for prediction in data["predictions"]
    ]
    references = list(data["references"])
    if not predictions:
        raise ValueError("No complete (predictions, references) rows to evaluate.")

    results = metric._compute(
        predictions,
        references,
        n_bins=int(n_bins),
        # bin_range is not passed: not needed yet
        scheme=scheme,
        proxy=proxy,
        p=p,
        detail=True,
    )
    return results["ECE"], reliability_plot(results)
# Outputs, in the order compute_and_plot returns them: ECE value, then the
# reliability-diagram figure. gr.Textbox replaces the deprecated
# gr.outputs.Textbox (the gr.outputs namespace is removed in Gradio 4).
outputs = [gr.Textbox(label="ECE"), gr.Plot(label="Reliability diagram")]

iface = gr.Interface(
    fn=compute_and_plot,
    inputs=[component] + sliders,
    outputs=outputs,
    description=metric.info.description,
    article=metric.info.citation,
    # examples=sample_data,  # previously raised "ValueError: Examples argument
    #                          must either be a directory or a nested list ..."
)
# Launch separately so `iface` stays bound to the Interface, not to the
# return value of launch().
iface.launch()