Spaces:
Runtime error
Runtime error
app and ece done
Browse files
app.py
CHANGED
@@ -7,6 +7,13 @@ import gradio as gr
|
|
7 |
from evaluate.utils import launch_gradio_widget
|
8 |
from ece import ECE
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
sliders = [
|
12 |
gr.Slider(0, 100, value=10, label="n_bins"),
|
@@ -44,16 +51,6 @@ Switch inputs and compute_fn
|
|
44 |
"""
|
45 |
|
46 |
def reliability_plot(results):
|
47 |
-
#CE, calibrated_acc, empirical_acc, weights_ece
|
48 |
-
#{"ECE": ECE[0], "y_bar": ECE[1], "p_bar": ECE[2], "bin_freq": ECE[3]}
|
49 |
-
import matplotlib.pyplot as plt
|
50 |
-
import seaborn as sns
|
51 |
-
sns.set_style('white')
|
52 |
-
sns.set_context("paper", font_scale=1) # 2
|
53 |
-
# plt.rcParams['figure.figsize'] = [10, 7]
|
54 |
-
plt.rcParams['figure.dpi'] = 300
|
55 |
-
|
56 |
-
|
57 |
fig = plt.figure()
|
58 |
ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
|
59 |
ax2 = plt.subplot2grid((3, 1), (2, 0))
|
@@ -65,9 +62,10 @@ def reliability_plot(results):
|
|
65 |
] # np.linspace(0, 1, n_bins)
|
66 |
# if upper edge then minus binsize; same for center [but half]
|
67 |
|
|
|
68 |
ax1.plot(
|
69 |
-
|
70 |
-
|
71 |
color="darkgreen",
|
72 |
ls="dotted",
|
73 |
label="Perfect",
|
@@ -79,7 +77,7 @@ def reliability_plot(results):
|
|
79 |
bin_freqs[anindices] = results["bin_freq"]
|
80 |
ax2.hist(results["y_bar"], results["y_bar"], weights=bin_freqs)
|
81 |
|
82 |
-
widths = np.diff(results["y_bar"])
|
83 |
for j, bin in enumerate(results["y_bar"]):
|
84 |
perfect = results["y_bar"][j]
|
85 |
empirical = results["p_bar"][j]
|
@@ -87,7 +85,7 @@ def reliability_plot(results):
|
|
87 |
if np.isnan(empirical):
|
88 |
continue
|
89 |
|
90 |
-
ax1.bar([perfect], height=[empirical], width=-
|
91 |
|
92 |
if perfect == empirical:
|
93 |
continue
|
@@ -137,10 +135,10 @@ def compute_and_plot(data, n_bins, bin_range, scheme, proxy, p):
|
|
137 |
)
|
138 |
|
139 |
plot = reliability_plot(results)
|
140 |
-
return results["ECE"], plt.gcf()
|
141 |
|
142 |
|
143 |
-
outputs = [gr.outputs.Textbox(label="ECE"), gr.
|
144 |
|
145 |
iface = gr.Interface(
|
146 |
fn=compute_and_plot,
|
@@ -148,26 +146,5 @@ iface = gr.Interface(
|
|
148 |
outputs=outputs,
|
149 |
description=metric.info.description,
|
150 |
article=metric.info.citation,
|
151 |
-
# examples=sample_data
|
152 |
-
)
|
153 |
-
|
154 |
-
# ValueError: Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.
|
155 |
-
|
156 |
-
iface.launch()
|
157 |
-
|
158 |
-
# dict = {"ECE": ECE[0], "y_bar": ECE[1], "p_bar": ECE[2], "bin_freq": ECE[3]}
|
159 |
-
|
160 |
-
# references=[0, 1, 2], predictions=)
|
161 |
-
# https://gradio.app/getting_started/#multiple-inputs-and-outputs
|
162 |
-
## fix with sliders for all kwargs
|
163 |
-
|
164 |
-
|
165 |
-
"""
|
166 |
-
DEV: #might be nice to also plot reliability diagram
|
167 |
-
have sliders for kwargs :)
|
168 |
-
|
169 |
-
|
170 |
-
metric = ECE()
|
171 |
-
|
172 |
-
|
173 |
-
"""
|
|
|
7 |
from evaluate.utils import launch_gradio_widget
|
8 |
from ece import ECE
|
9 |
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import seaborn as sns
|
12 |
+
sns.set_style('white')
|
13 |
+
sns.set_context("paper", font_scale=1) # 2
|
14 |
+
# plt.rcParams['figure.figsize'] = [10, 7]
|
15 |
+
plt.rcParams['figure.dpi'] = 300
|
16 |
+
plt.switch_backend('agg') #; https://stackoverflow.com/questions/14694408/runtimeerror-main-thread-is-not-in-main-loop
|
17 |
|
18 |
sliders = [
|
19 |
gr.Slider(0, 100, value=10, label="n_bins"),
|
|
|
51 |
"""
|
52 |
|
53 |
def reliability_plot(results):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
fig = plt.figure()
|
55 |
ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
|
56 |
ax2 = plt.subplot2grid((3, 1), (2, 0))
|
|
|
62 |
] # np.linspace(0, 1, n_bins)
|
63 |
# if upper edge then minus binsize; same for center [but half]
|
64 |
|
65 |
+
ranged = np.linspace(bin_range[0], bin_range[1], n_bins)
|
66 |
ax1.plot(
|
67 |
+
ranged,
|
68 |
+
ranged,
|
69 |
color="darkgreen",
|
70 |
ls="dotted",
|
71 |
label="Perfect",
|
|
|
77 |
bin_freqs[anindices] = results["bin_freq"]
|
78 |
ax2.hist(results["y_bar"], results["y_bar"], weights=bin_freqs)
|
79 |
|
80 |
+
#widths = np.diff(results["y_bar"])
|
81 |
for j, bin in enumerate(results["y_bar"]):
|
82 |
perfect = results["y_bar"][j]
|
83 |
empirical = results["p_bar"][j]
|
|
|
85 |
if np.isnan(empirical):
|
86 |
continue
|
87 |
|
88 |
+
ax1.bar([perfect], height=[empirical], width=-ranged[j], align="edge", color="lightblue")
|
89 |
|
90 |
if perfect == empirical:
|
91 |
continue
|
|
|
135 |
)
|
136 |
|
137 |
plot = reliability_plot(results)
|
138 |
+
return results["ECE"], plot #plt.gcf()
|
139 |
|
140 |
|
141 |
+
outputs = [gr.outputs.Textbox(label="ECE"), gr.Plot(label="Reliability diagram")]
|
142 |
|
143 |
iface = gr.Interface(
|
144 |
fn=compute_and_plot,
|
|
|
146 |
outputs=outputs,
|
147 |
description=metric.info.description,
|
148 |
article=metric.info.citation,
|
149 |
+
# examples=sample_data; # ValueError: Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.
|
150 |
+
).launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ece.py
CHANGED
@@ -80,7 +80,7 @@ BAD_WORDS_URL = ""
|
|
80 |
def create_bins(n_bins=10, scheme="equal-range", bin_range=None, P=None):
|
81 |
assert scheme in [
|
82 |
"equal-range",
|
83 |
-
"equal-
|
84 |
], f"This binning scheme {scheme} is not implemented yet"
|
85 |
|
86 |
if bin_range is None:
|
@@ -106,8 +106,9 @@ def create_bins(n_bins=10, scheme="equal-range", bin_range=None, P=None):
|
|
106 |
# rightmost entry per equal size group
|
107 |
for cur_group in range(n_bins - 1):
|
108 |
bin_upper_edges += [max(groups[cur_group])]
|
109 |
-
bin_upper_edges += [np.inf] # always +1 for right edges
|
110 |
bins = np.array(bin_upper_edges)
|
|
|
111 |
|
112 |
return bins
|
113 |
|
@@ -201,7 +202,7 @@ def top_1_CE(Y, P, **kwargs):
|
|
201 |
n_bins=kwargs["n_bins"], bin_range=kwargs["bin_range"], scheme=kwargs["scheme"], P=p_max
|
202 |
)
|
203 |
CE = CE_estimate(y_correct, p_max, bins=bins, proxy=kwargs["proxy"], detail=kwargs["detail"])
|
204 |
-
if
|
205 |
return {"ECE": CE[0], "y_bar": CE[1], "p_bar": CE[2], "bin_freq": CE[3], "p_bar_cont": np.mean(p_max,-1), "accuracy": np.mean(y_correct)}
|
206 |
return CE
|
207 |
|
|
|
80 |
def create_bins(n_bins=10, scheme="equal-range", bin_range=None, P=None):
|
81 |
assert scheme in [
|
82 |
"equal-range",
|
83 |
+
"equal-mass",
|
84 |
], f"This binning scheme {scheme} is not implemented yet"
|
85 |
|
86 |
if bin_range is None:
|
|
|
106 |
# rightmost entry per equal size group
|
107 |
for cur_group in range(n_bins - 1):
|
108 |
bin_upper_edges += [max(groups[cur_group])]
|
109 |
+
bin_upper_edges += [1.01] #[np.inf] # always +1 for right edges
|
110 |
bins = np.array(bin_upper_edges)
|
111 |
+
#OverflowError: cannot convert float infinity to integer
|
112 |
|
113 |
return bins
|
114 |
|
|
|
202 |
n_bins=kwargs["n_bins"], bin_range=kwargs["bin_range"], scheme=kwargs["scheme"], P=p_max
|
203 |
)
|
204 |
CE = CE_estimate(y_correct, p_max, bins=bins, proxy=kwargs["proxy"], detail=kwargs["detail"])
|
205 |
+
if kwargs["detail"]:
|
206 |
return {"ECE": CE[0], "y_bar": CE[1], "p_bar": CE[2], "bin_freq": CE[3], "p_bar_cont": np.mean(p_max,-1), "accuracy": np.mean(y_correct)}
|
207 |
return CE
|
208 |
|