jordyvl commited on
Commit
efa98d8
1 Parent(s): 3f722df

app and ece done

Browse files
Files changed (2) hide show
  1. app.py +16 -39
  2. ece.py +4 -3
app.py CHANGED
@@ -7,6 +7,13 @@ import gradio as gr
7
  from evaluate.utils import launch_gradio_widget
8
  from ece import ECE
9
 
 
 
 
 
 
 
 
10
 
11
  sliders = [
12
  gr.Slider(0, 100, value=10, label="n_bins"),
@@ -44,16 +51,6 @@ Switch inputs and compute_fn
44
  """
45
 
46
  def reliability_plot(results):
47
- #CE, calibrated_acc, empirical_acc, weights_ece
48
- #{"ECE": ECE[0], "y_bar": ECE[1], "p_bar": ECE[2], "bin_freq": ECE[3]}
49
- import matplotlib.pyplot as plt
50
- import seaborn as sns
51
- sns.set_style('white')
52
- sns.set_context("paper", font_scale=1) # 2
53
- # plt.rcParams['figure.figsize'] = [10, 7]
54
- plt.rcParams['figure.dpi'] = 300
55
-
56
-
57
  fig = plt.figure()
58
  ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
59
  ax2 = plt.subplot2grid((3, 1), (2, 0))
@@ -65,9 +62,10 @@ def reliability_plot(results):
65
  ] # np.linspace(0, 1, n_bins)
66
  # if upper edge then minus binsize; same for center [but half]
67
 
 
68
  ax1.plot(
69
- np.linspace(bin_range[0], bin_range[1], n_bins),
70
- np.linspace(bin_range[0], bin_range[1], n_bins),
71
  color="darkgreen",
72
  ls="dotted",
73
  label="Perfect",
@@ -79,7 +77,7 @@ def reliability_plot(results):
79
  bin_freqs[anindices] = results["bin_freq"]
80
  ax2.hist(results["y_bar"], results["y_bar"], weights=bin_freqs)
81
 
82
- widths = np.diff(results["y_bar"])
83
  for j, bin in enumerate(results["y_bar"]):
84
  perfect = results["y_bar"][j]
85
  empirical = results["p_bar"][j]
@@ -87,7 +85,7 @@ def reliability_plot(results):
87
  if np.isnan(empirical):
88
  continue
89
 
90
- ax1.bar([perfect], height=[empirical], width=-widths[j], align="edge", color="lightblue")
91
 
92
  if perfect == empirical:
93
  continue
@@ -137,10 +135,10 @@ def compute_and_plot(data, n_bins, bin_range, scheme, proxy, p):
137
  )
138
 
139
  plot = reliability_plot(results)
140
- return results["ECE"], plt.gcf()
141
 
142
 
143
- outputs = [gr.outputs.Textbox(label="ECE"), gr.outputs.Plot(label="Reliability diagram")]
144
 
145
  iface = gr.Interface(
146
  fn=compute_and_plot,
@@ -148,26 +146,5 @@ iface = gr.Interface(
148
  outputs=outputs,
149
  description=metric.info.description,
150
  article=metric.info.citation,
151
- # examples=sample_data
152
- )
153
-
154
- # ValueError: Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.
155
-
156
- iface.launch()
157
-
158
- # dict = {"ECE": ECE[0], "y_bar": ECE[1], "p_bar": ECE[2], "bin_freq": ECE[3]}
159
-
160
- # references=[0, 1, 2], predictions=)
161
- # https://gradio.app/getting_started/#multiple-inputs-and-outputs
162
- ## fix with sliders for all kwargs
163
-
164
-
165
- """
166
- DEV: #might be nice to also plot reliability diagram
167
- have sliders for kwargs :)
168
-
169
-
170
- metric = ECE()
171
-
172
-
173
- """
 
7
  from evaluate.utils import launch_gradio_widget
8
  from ece import ECE
9
 
10
+ import matplotlib.pyplot as plt
11
+ import seaborn as sns
12
+ sns.set_style('white')
13
+ sns.set_context("paper", font_scale=1) # 2
14
+ # plt.rcParams['figure.figsize'] = [10, 7]
15
+ plt.rcParams['figure.dpi'] = 300
16
+ plt.switch_backend('agg') #; https://stackoverflow.com/questions/14694408/runtimeerror-main-thread-is-not-in-main-loop
17
 
18
  sliders = [
19
  gr.Slider(0, 100, value=10, label="n_bins"),
 
51
  """
52
 
53
  def reliability_plot(results):
 
 
 
 
 
 
 
 
 
 
54
  fig = plt.figure()
55
  ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
56
  ax2 = plt.subplot2grid((3, 1), (2, 0))
 
62
  ] # np.linspace(0, 1, n_bins)
63
  # if upper edge then minus binsize; same for center [but half]
64
 
65
+ ranged = np.linspace(bin_range[0], bin_range[1], n_bins)
66
  ax1.plot(
67
+ ranged,
68
+ ranged,
69
  color="darkgreen",
70
  ls="dotted",
71
  label="Perfect",
 
77
  bin_freqs[anindices] = results["bin_freq"]
78
  ax2.hist(results["y_bar"], results["y_bar"], weights=bin_freqs)
79
 
80
+ #widths = np.diff(results["y_bar"])
81
  for j, bin in enumerate(results["y_bar"]):
82
  perfect = results["y_bar"][j]
83
  empirical = results["p_bar"][j]
 
85
  if np.isnan(empirical):
86
  continue
87
 
88
+ ax1.bar([perfect], height=[empirical], width=-ranged[j], align="edge", color="lightblue")
89
 
90
  if perfect == empirical:
91
  continue
 
135
  )
136
 
137
  plot = reliability_plot(results)
138
+ return results["ECE"], plot #plt.gcf()
139
 
140
 
141
+ outputs = [gr.outputs.Textbox(label="ECE"), gr.Plot(label="Reliability diagram")]
142
 
143
  iface = gr.Interface(
144
  fn=compute_and_plot,
 
146
  outputs=outputs,
147
  description=metric.info.description,
148
  article=metric.info.citation,
149
+ # examples=sample_data; # ValueError: Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.
150
+ ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ece.py CHANGED
@@ -80,7 +80,7 @@ BAD_WORDS_URL = ""
80
  def create_bins(n_bins=10, scheme="equal-range", bin_range=None, P=None):
81
  assert scheme in [
82
  "equal-range",
83
- "equal-masss",
84
  ], f"This binning scheme {scheme} is not implemented yet"
85
 
86
  if bin_range is None:
@@ -106,8 +106,9 @@ def create_bins(n_bins=10, scheme="equal-range", bin_range=None, P=None):
106
  # rightmost entry per equal size group
107
  for cur_group in range(n_bins - 1):
108
  bin_upper_edges += [max(groups[cur_group])]
109
- bin_upper_edges += [np.inf] # always +1 for right edges
110
  bins = np.array(bin_upper_edges)
 
111
 
112
  return bins
113
 
@@ -201,7 +202,7 @@ def top_1_CE(Y, P, **kwargs):
201
  n_bins=kwargs["n_bins"], bin_range=kwargs["bin_range"], scheme=kwargs["scheme"], P=p_max
202
  )
203
  CE = CE_estimate(y_correct, p_max, bins=bins, proxy=kwargs["proxy"], detail=kwargs["detail"])
204
- if self.detail:
205
  return {"ECE": CE[0], "y_bar": CE[1], "p_bar": CE[2], "bin_freq": CE[3], "p_bar_cont": np.mean(p_max,-1), "accuracy": np.mean(y_correct)}
206
  return CE
207
 
 
80
  def create_bins(n_bins=10, scheme="equal-range", bin_range=None, P=None):
81
  assert scheme in [
82
  "equal-range",
83
+ "equal-mass",
84
  ], f"This binning scheme {scheme} is not implemented yet"
85
 
86
  if bin_range is None:
 
106
  # rightmost entry per equal size group
107
  for cur_group in range(n_bins - 1):
108
  bin_upper_edges += [max(groups[cur_group])]
109
+ bin_upper_edges += [1.01] #[np.inf] # always +1 for right edges
110
  bins = np.array(bin_upper_edges)
111
+ #OverflowError: cannot convert float infinity to integer
112
 
113
  return bins
114
 
 
202
  n_bins=kwargs["n_bins"], bin_range=kwargs["bin_range"], scheme=kwargs["scheme"], P=p_max
203
  )
204
  CE = CE_estimate(y_correct, p_max, bins=bins, proxy=kwargs["proxy"], detail=kwargs["detail"])
205
+ if kwargs["detail"]:
206
  return {"ECE": CE[0], "y_bar": CE[1], "p_bar": CE[2], "bin_freq": CE[3], "p_bar_cont": np.mean(p_max,-1), "accuracy": np.mean(y_correct)}
207
  return CE
208