Spaces:
Runtime error
Runtime error
Fix to reliability diagram - correct with test
Browse files
app.py
CHANGED
@@ -48,34 +48,31 @@ def reliability_plot(results):
|
|
48 |
# DEV: nicer would be to plot like a polygon
|
49 |
# see: https://github.com/markus93/fit-on-the-test/blob/main/Experiments_Synthetic/binnings.py
|
50 |
|
51 |
-
def over_under_confidence(results):
|
52 |
-
colors = []
|
53 |
-
for j, bin in enumerate(results["y_bar"]):
|
54 |
-
perfect = results["y_bar"][j]
|
55 |
-
empirical = results["p_bar"][j]
|
56 |
-
|
57 |
-
bin_color = (
|
58 |
-
"limegreen"
|
59 |
-
if np.allclose(perfect, empirical)
|
60 |
-
else "dodgerblue"
|
61 |
-
if empirical < perfect
|
62 |
-
else "orangered"
|
63 |
-
)
|
64 |
-
colors.append(bin_color)
|
65 |
-
return colors
|
66 |
-
|
67 |
fig, ax1, ax2 = default_plot()
|
68 |
|
69 |
# Bin differences
|
70 |
bins_with_left_edge = np.insert(results["y_bar"], 0, 0, axis=0)
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
75 |
)
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
ax1handles = [
|
81 |
mpatches.Patch(color="orangered", label="Overconfident"),
|
@@ -84,12 +81,11 @@ def reliability_plot(results):
|
|
84 |
]
|
85 |
|
86 |
# Bin frequencies
|
87 |
-
anindices = np.where(~np.isnan(results["p_bar"]
|
88 |
-
|
89 |
-
bin_freqs = np.zeros(n_bins)
|
90 |
bin_freqs[anindices] = results["bin_freq"]
|
91 |
-
|
92 |
-
|
93 |
)
|
94 |
|
95 |
acc_plt = ax2.axvline(x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy")
|
@@ -148,8 +144,8 @@ component = gr.inputs.Dataframe(
|
|
148 |
)
|
149 |
|
150 |
component.value = [
|
151 |
-
[[0.
|
152 |
-
[[0.
|
153 |
[[0, 0.95, 0.05], 1],
|
154 |
]
|
155 |
sample_data = [[component] + slider_defaults]
|
|
|
48 |
# DEV: nicer would be to plot like a polygon
|
49 |
# see: https://github.com/markus93/fit-on-the-test/blob/main/Experiments_Synthetic/binnings.py
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
fig, ax1, ax2 = default_plot()
|
52 |
|
53 |
# Bin differences
|
54 |
bins_with_left_edge = np.insert(results["y_bar"], 0, 0, axis=0)
|
55 |
+
bins_with_right_edge = np.insert(results["y_bar"], -1, 1.0, axis=0)
|
56 |
+
bins_with_leftright_edge = np.insert(bins_with_left_edge, -1, 1.0, axis=0)
|
57 |
+
weights = np.nan_to_num(results["p_bar"], copy=True, nan=0)
|
58 |
+
|
59 |
+
# NOTE: the histogram API is strange
|
60 |
+
_, _, patches = ax1.hist(
|
61 |
+
bins_with_left_edge,
|
62 |
+
weights=weights,
|
63 |
+
bins=bins_with_leftright_edge,
|
64 |
)
|
65 |
+
for b in range(len(patches)):
|
66 |
+
perfect = bins_with_right_edge[b] # if b != n_bins else
|
67 |
+
empirical = weights[b] # patches[b]._height
|
68 |
+
bin_color = (
|
69 |
+
"limegreen"
|
70 |
+
if perfect == empirical
|
71 |
+
else "dodgerblue"
|
72 |
+
if empirical < perfect
|
73 |
+
else "orangered"
|
74 |
+
)
|
75 |
+
patches[b].set_facecolor(bin_color) # color based on over/underconfidence
|
76 |
|
77 |
ax1handles = [
|
78 |
mpatches.Patch(color="orangered", label="Overconfident"),
|
|
|
81 |
]
|
82 |
|
83 |
# Bin frequencies
|
84 |
+
anindices = np.where(~np.isnan(results["p_bar"]))[0]
|
85 |
+
bin_freqs = np.zeros(len(results["p_bar"]))
|
|
|
86 |
bin_freqs[anindices] = results["bin_freq"]
|
87 |
+
ax2.hist(
|
88 |
+
bins_with_left_edge, weights=bin_freqs, color="midnightblue", bins=bins_with_leftright_edge
|
89 |
)
|
90 |
|
91 |
acc_plt = ax2.axvline(x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy")
|
|
|
144 |
)
|
145 |
|
146 |
component.value = [
|
147 |
+
[[0.6, 0.2, 0.2], 0],
|
148 |
+
[[0.7, 0.1, 0.2], 2],
|
149 |
[[0, 0.95, 0.05], 1],
|
150 |
]
|
151 |
sample_data = [[component] + slider_defaults]
|
ece.py
CHANGED
@@ -21,7 +21,6 @@ import numpy as np
|
|
21 |
from typing import Dict, Optional
|
22 |
|
23 |
|
24 |
-
|
25 |
# TODO: Add BibTeX citation
|
26 |
_CITATION = """\
|
27 |
@InProceedings{huggingface:module,
|
@@ -103,9 +102,9 @@ def create_bins(n_bins=10, scheme="equal-range", bin_range=None, P=None):
|
|
103 |
# rightmost entry per equal size group
|
104 |
for cur_group in range(n_bins - 1):
|
105 |
bin_upper_edges += [max(groups[cur_group])]
|
106 |
-
bin_upper_edges += [1.01]
|
107 |
bins = np.array(bin_upper_edges)
|
108 |
-
#OverflowError: cannot convert float infinity to integer
|
109 |
|
110 |
return bins
|
111 |
|
@@ -200,7 +199,14 @@ def top_1_CE(Y, P, **kwargs):
|
|
200 |
)
|
201 |
CE = CE_estimate(y_correct, p_max, bins=bins, proxy=kwargs["proxy"], detail=kwargs["detail"])
|
202 |
if kwargs["detail"]:
|
203 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
return CE
|
205 |
|
206 |
|
@@ -306,9 +312,18 @@ def test_ECE():
|
|
306 |
print(f"ECE: {res['ECE']}")
|
307 |
|
308 |
res = ECE()._compute(predictions, references, detail=True)
|
309 |
-
import pdb; pdb.set_trace() # breakpoint 25274412 //
|
310 |
-
|
311 |
print(f"ECE: {res['ECE']}")
|
312 |
|
313 |
-
|
314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
from typing import Dict, Optional
|
22 |
|
23 |
|
|
|
24 |
# TODO: Add BibTeX citation
|
25 |
_CITATION = """\
|
26 |
@InProceedings{huggingface:module,
|
|
|
102 |
# rightmost entry per equal size group
|
103 |
for cur_group in range(n_bins - 1):
|
104 |
bin_upper_edges += [max(groups[cur_group])]
|
105 |
+
bin_upper_edges += [1.01] # [np.inf] # always +1 for right edges
|
106 |
bins = np.array(bin_upper_edges)
|
107 |
+
# OverflowError: cannot convert float infinity to integer
|
108 |
|
109 |
return bins
|
110 |
|
|
|
199 |
)
|
200 |
CE = CE_estimate(y_correct, p_max, bins=bins, proxy=kwargs["proxy"], detail=kwargs["detail"])
|
201 |
if kwargs["detail"]:
|
202 |
+
return {
|
203 |
+
"ECE": CE[0],
|
204 |
+
"y_bar": CE[1],
|
205 |
+
"p_bar": CE[2],
|
206 |
+
"bin_freq": CE[3],
|
207 |
+
"p_bar_cont": np.mean(p_max, -1),
|
208 |
+
"accuracy": np.mean(y_correct),
|
209 |
+
}
|
210 |
return CE
|
211 |
|
212 |
|
|
|
312 |
print(f"ECE: {res['ECE']}")
|
313 |
|
314 |
res = ECE()._compute(predictions, references, detail=True)
|
|
|
|
|
315 |
print(f"ECE: {res['ECE']}")
|
316 |
|
317 |
+
|
318 |
+
def test_deterministic():
    """Run the ECE metric on a small fixed input and print the detailed result.

    Three samples over three classes are scored with ``detail=True``; the
    ``"ECE"`` entry is printed first, followed by the full returned object.
    Nothing is asserted: the output is meant to be inspected by eye.
    """
    # One reference label per sample.
    labels = [0, 1, 2]
    # One row of per-class scores per sample.
    # NOTE(review): rows 1 and 3 sum to 1.03 and 1.02, not 1.0; confirm intended.
    scores = [
        [0.63, 0.2, 0.2],
        [0, 0.95, 0.05],
        [0.72, 0.1, 0.2],
    ]
    res = ECE()._compute(predictions=scores, references=labels, detail=True)
    print(f"ECE: {res['ECE']}\n {res}")
|
325 |
+
|
326 |
+
|
327 |
+
if __name__ == "__main__":
    # Script entry point: run the fixed-input check first, then test_ECE.
    for check in (test_deterministic, test_ECE):
        check()
|