import evaluate import json import sys from pathlib import Path import gradio as gr import numpy as np import pandas as pd import ast # from ece import ECE # loads local instead import matplotlib.pyplot as plt import matplotlib.patches as mpatches """ import seaborn as sns sns.set_style('white') sns.set_context("paper", font_scale=1) """ # plt.rcParams['figure.figsize'] = [10, 7] plt.rcParams["figure.dpi"] = 300 plt.switch_backend( "agg" ) # ; https://stackoverflow.com/questions/14694408/runtimeerror-main-thread-is-not-in-main-loop sliders = [ gr.Slider(0, 100, value=10, label="n_bins"), gr.Slider( 0, 100, value=None, label="bin_range", visible=False ), # DEV: need to have a double slider gr.Dropdown(choices=["equal-range", "equal-mass"], value="equal-range", label="scheme"), gr.Dropdown(choices=["upper-edge", "center"], value="upper-edge", label="proxy"), gr.Dropdown(choices=[1, 2, np.inf], value=1, label="p"), ] slider_defaults = [slider.value for slider in sliders] # example data df = dict() df["predictions"] = [[0.6, 0.2, 0.2], [0, 0.95, 0.05], [0.7, 0.1, 0.2]] df["references"] = [0, 1, 2] component = gr.inputs.Dataframe( headers=["predictions", "references"], col_count=2, datatype="number", type="pandas" ) component.value = [ [[0.6, 0.2, 0.2], 0], [[0.7, 0.1, 0.2], 2], [[0, 0.95, 0.05], 1], ] sample_data = [[component] + slider_defaults] ##json.dumps(df) local_path = Path(sys.path[0]) metric = evaluate.load("jordyvl/ece") # ECE() # module = evaluate.load("jordyvl/ece") # launch_gradio_widget(module) """l Switch inputs and compute_fn """ def default_plot(): fig = plt.figure() ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2) ax2 = plt.subplot2grid((3, 1), (2, 0)) ranged = np.linspace(0, 1, 10) ax1.plot( ranged, ranged, color="darkgreen", ls="dotted", label="Perfect", ) # Bin differences ax1.set_ylabel("Conditional Expectation") ax1.set_ylim([0, 1.05]) # respective to bin range ax1.set_title("Reliability Diagram") ax1.set_xlim([-0.05, 1.05]) # respective to bin range # Bin frequencies ax2.set_xlabel("Confidence") ax2.set_ylabel("Count") ax2.legend(loc="upper left") # , ncol=2 ax2.set_xlim([-0.05, 1.05]) # respective to bin range return fig, ax1, ax2 def reliability_plot(results): # DEV: might still need to write tests in case of equal mass binning # DEV: nicer would be to plot like a polygon # see: https://github.com/markus93/fit-on-the-test/blob/main/Experiments_Synthetic/binnings.py def over_under_confidence(results): colors = [] for j, bin in enumerate(results["y_bar"]): perfect = results["y_bar"][j] empirical = results["p_bar"][j] bin_color = ( "limegreen" if np.allclose(perfect, empirical) else "dodgerblue" if empirical < perfect else "orangered" ) colors.append(bin_color) return colors fig, ax1, ax2 = default_plot() # Bin differences bins_with_left_edge = np.insert(results["y_bar"], 0, 0, axis=0) B, bins, patches = ax1.hist( results["y_bar"], weights=np.nan_to_num(results["p_bar"][:-1], copy=True, nan=0), bins=bins_with_left_edge, ) colors = over_under_confidence(results) for b in range(len(B)): patches[b].set_facecolor(colors[b]) # color based on over/underconfidence ax1handles = [ mpatches.Patch(color="orangered", label="Overconfident"), mpatches.Patch(color="limegreen", label="Perfect", linestyle="dotted"), mpatches.Patch(color="dodgerblue", label="Underconfident"), ] # Bin frequencies anindices = np.where(~np.isnan(results["p_bar"][:-1]))[0] n_bins = len(results["y_bar"]) bin_freqs = np.zeros(n_bins) bin_freqs[anindices] = results["bin_freq"] B, newbins, patches = ax2.hist( results["y_bar"], weights=bin_freqs, color="midnightblue", bins=bins_with_left_edge ) acc_plt = ax2.axvline(x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy") conf_plt = ax2.axvline( x=results["p_bar_cont"], ls="dotted", lw=3, c="#444", label="Avg. confidence" ) ax1.legend(loc="lower right", handles=ax1handles) ax2.legend(handles=[acc_plt, conf_plt]) ax1.set_xticks(bins_with_left_edge) ax2.set_xticks(bins_with_left_edge) plt.tight_layout() return fig def compute_and_plot(data, n_bins, bin_range, scheme, proxy, p): # DEV: check on invalid datatypes with better warnings if isinstance(data, pd.DataFrame): data.dropna(inplace=True) predictions = [ ast.literal_eval(prediction) if not isinstance(prediction, list) else prediction for prediction in data["predictions"] ] references = [reference for reference in data["references"]] results = metric._compute( predictions, references, n_bins=n_bins, scheme=scheme, proxy=proxy, p=p, detail=True, ) print(results) plot = reliability_plot(results) return results["ECE"], plot outputs = [gr.outputs.Textbox(label="ECE"), gr.Plot(label="Reliability diagram")] # outputs[1].value = default_plot().__dict__ #Does not work; yet needs to be JSON encoded iface = gr.Interface( fn=compute_and_plot, inputs=[component] + sliders, outputs=outputs, description=metric.info.description, article=evaluate.utils.parse_readme(local_path / "README.md"), title=f"Metric: {metric.name}", # examples=sample_data; # ValueError: Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs. ).launch()