import ast
import json

import evaluate
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from evaluate.utils import launch_gradio_widget
from matplotlib.figure import Figure

from ece import ECE

sns.set_style("white")
sns.set_context("paper", font_scale=1)  # 2
# plt.rcParams['figure.figsize'] = [10, 7]
plt.rcParams["figure.dpi"] = 300
# Non-interactive backend; avoids "main thread is not in main loop" errors, see
# https://stackoverflow.com/questions/14694408/runtimeerror-main-thread-is-not-in-main-loop
plt.switch_backend("agg")

# Metric hyper-parameter widgets, in the same order as the arguments of
# compute_and_plot that follow `data`.
sliders = [
    # At least one bin is required; integer steps only.
    gr.Slider(1, 100, value=10, step=1, label="n_bins"),
    # DEV: need to have a double (range) slider; hidden and unused for now.
    gr.Slider(0, 100, value=None, label="bin_range", visible=False),
    gr.Dropdown(choices=["equal-range", "equal-mass"], value="equal-range", label="scheme"),
    gr.Dropdown(choices=["upper-edge", "center"], value="upper-edge", label="proxy"),
    gr.Dropdown(choices=[1, 2, np.inf], value=1, label="p"),
]
slider_defaults = [slider.value for slider in sliders]

# example data
df = dict()
df["predictions"] = [[0.6, 0.2, 0.2], [0, 0.95, 0.05], [0.7, 0.1, 0.2]]
df["references"] = [0, 1, 2]

# Input table: one row per sample, e.g. a probability vector and its label.
# NOTE(review): gr.inputs.* is a deprecated Gradio namespace; kept as-is —
# confirm defaults match before migrating to gr.Dataframe.
component = gr.inputs.Dataframe(
    headers=["predictions", "references"], col_count=2, datatype="number", type="pandas"
)
component.value = [
    [[0.6, 0.2, 0.2], 0],
    [[0.7, 0.1, 0.2], 2],
    [[0, 0.95, 0.05], 1],
]
# Not passed to gr.Interface (see the `examples` note there).
sample_data = [[component] + slider_defaults]  # json.dumps(df)

metric = ECE()
# module = evaluate.load("jordyvl/ece")
# launch_gradio_widget(module)

# Switch inputs and compute_fn


def reliability_plot(results):
    """Build a reliability diagram with a per-bin count panel beneath it.

    Args:
        results: detailed output of ``metric._compute(..., detail=True)``.
            Keys read here: ``"y_bar"`` (one position per bin on the
            confidence axis), ``"p_bar"`` (per-bin bar height; NaN entries
            are skipped), ``"bin_freq"`` (one count per non-NaN entry of
            ``p_bar[:-1]``), ``"accuracy"`` and ``"p_bar_cont"``.

    Returns:
        A ``matplotlib.figure.Figure`` with the diagram on top (2/3 height)
        and the counts below (1/3 height). The figure is created directly,
        not through ``plt.figure()``. It is therefore never registered with
        pyplot, and repeated requests do not accumulate open figures.
    """
    fig = Figure()
    grid = fig.add_gridspec(3, 1)
    ax1 = fig.add_subplot(grid[0:2, 0])
    ax2 = fig.add_subplot(grid[2, 0])

    y_bar = np.asarray(results["y_bar"], dtype=float)
    p_bar = np.asarray(results["p_bar"], dtype=float)
    n_bins = len(y_bar)

    # The "perfect calibration" diagonal runs from 0 to the last bin position.
    bin_range = [0.0, y_bar[-1]]
    ranged = np.linspace(bin_range[0], bin_range[1], n_bins)
    ax1.plot(
        ranged,
        ranged,
        color="darkgreen",
        ls="dotted",
        label="Perfect",
    )

    # Width of bin j = distance from the previous bin position (or from 0 for
    # the first bin). Bars are drawn leftwards from y_bar[j] (negative width,
    # align="edge"), which fits the default "upper-edge" proxy.
    # NOTE(review): for proxy="center" this shifts bars by half a bin —
    # confirm what y_bar holds for that proxy.
    widths = np.diff(y_bar, prepend=bin_range[0])

    # Reliability bars: skip bins whose empirical value is NaN.
    heights = p_bar[:n_bins]
    filled = ~np.isnan(heights)
    ax1.bar(
        y_bar[filled],
        height=heights[filled],
        width=-widths[filled],
        align="edge",
        color="lightblue",
    )

    # Bin counts: bin_freq only covers the non-NaN bins; scatter it back into
    # a full-length array so it lines up with y_bar. Drawn with the same bar
    # geometry as the top panel so each count sits under its own bin.
    non_empty = np.where(~np.isnan(p_bar[:-1]))[0]
    bin_freqs = np.zeros(n_bins)
    bin_freqs[non_empty] = results["bin_freq"]
    ax2.bar(y_bar, bin_freqs, width=-widths, align="edge")

    acc_plt = ax2.axvline(
        x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy"
    )
    conf_plt = ax2.axvline(
        x=results["p_bar_cont"], ls="dotted", lw=3, c="#444", label="Avg. confidence"
    )

    # Bin differences
    ax1.set_ylabel("Conditional Expectation")
    ax1.set_ylim([-0.05, 1.05])  # TODO: make this relative to the bin range
    ax1.legend(loc="lower right")
    ax1.set_title("Reliability Diagram")

    # Bin frequencies
    ax2.set_xlabel("Confidence")
    ax2.set_ylabel("Count")
    ax2.legend(handles=[acc_plt, conf_plt], loc="upper left")  # , ncol=2

    fig.tight_layout()
    return fig


def compute_and_plot(data, n_bins, bin_range, scheme, proxy, p):
    """Compute the ECE for the submitted table and draw its reliability diagram.

    Args:
        data: the input table (a pandas DataFrame from the Dataframe widget,
            or any mapping with "predictions" and "references" columns).
            Prediction cells may be lists or their string form; strings are
            parsed with ``ast.literal_eval``, which only accepts literals.
        n_bins: number of bins; coerced to ``int`` because slider values may
            arrive as floats.
        bin_range: unused (the hidden placeholder slider).
        scheme, proxy, p: forwarded unchanged to the metric.

    Returns:
        A ``(ece_value, figure)`` pair for the Textbox and Plot outputs.
    """
    # DEV: check on invalid datatypes with better warnings
    if isinstance(data, pd.DataFrame):
        # Drop rows with any missing cell; returns a copy so the caller's
        # frame is left untouched.
        data = data.dropna()
    predictions = [
        prediction if isinstance(prediction, list) else ast.literal_eval(prediction)
        for prediction in data["predictions"]
    ]
    references = list(data["references"])
    results = metric._compute(
        predictions,
        references,
        n_bins=int(n_bins),
        # bin_range=None,  # not needed
        scheme=scheme,
        proxy=proxy,
        p=p,
        detail=True,
    )
    plot = reliability_plot(results)
    return results["ECE"], plot


outputs = [gr.Textbox(label="ECE"), gr.Plot(label="Reliability diagram")]

iface = gr.Interface(
    fn=compute_and_plot,
    inputs=[component] + sliders,
    outputs=outputs,
    description=metric.info.description,
    article=metric.info.citation,
    # examples=sample_data,
    # -> ValueError: Examples argument must either be a directory or a nested
    #    list, where each sublist represents a set of inputs.
)
iface.launch()