jordyvl commited on
Commit
daa7784
1 Parent(s): 0c94397

functio al

Browse files
Files changed (1) hide show
  1. ece.py +6 -19
ece.py CHANGED
@@ -75,8 +75,7 @@ def create_bins(n_bins=10, scheme="equal-range", bin_range=None, P=None):
75
  if P is None:
76
  bin_range = [0, 1] # no way to know range
77
  else:
78
- if scheme == "equal-range":
79
- bin_range = [min(P), max(P)]
80
 
81
  if scheme == "equal-range":
82
  bins = np.linspace(bin_range[0], bin_range[1], n_bins + 1) # equal range
@@ -168,13 +167,12 @@ def CE_estimate(y_correct, P, bins=None, p=1, proxy="upper-edge"):
168
 
169
  empirical_acc, bin_edges, bin_assignment = manual_binned_statistic(P, y_correct, bins)
170
  bin_numbers, weights_ece = np.unique(bin_assignment, return_counts=True)
171
- anindices = bin_numbers - 1 # reduce bin counts; left edge; indexes right BY DEFAULT
172
 
173
  # Expected calibration error
174
  if p < np.inf: # L^p-CE
175
  CE = np.average(
176
- abs(empirical_acc[anindices] - calibrated_acc[anindices]) ** p,
177
- weights=weights_ece, # weighted average 1/binfreq
178
  )
179
  elif np.isinf(p): # max-ECE
180
  CE = np.max(abs(empirical_acc[anindices] - calibrated_acc[anindices]))
@@ -193,8 +191,6 @@ def top_1_CE(Y, P, **kwargs):
193
 
194
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
195
  class ECE(evaluate.EvaluationModule):
196
- """TODO: Short description of my evaluation module."""
197
-
198
  """
199
  0. create binning scheme [discretization of f]
200
  1. build histogram P(f(X))
@@ -203,15 +199,11 @@ class ECE(evaluate.EvaluationModule):
203
  4. apply L^p norm distance and weights
204
  """
205
 
206
- # have to add to initialization here?
207
- # create bins using the params
208
- # create proxy
209
-
210
  def __init__(self, n_bins=10, bin_range=None, scheme="equal-range", proxy="upper-edge", p=1):
211
- super().__init__(self)
212
 
213
- self.bin_range = bin_range
214
  self.n_bins = n_bins
 
215
  self.scheme = scheme
216
  self.proxy = proxy
217
  self.p = p
@@ -245,8 +237,7 @@ class ECE(evaluate.EvaluationModule):
245
 
246
  def _compute(self, predictions, references):
247
  """Returns the scores"""
248
-
249
- ECE = top_1_CE(references, predictions)
250
  return {
251
  "ECE": ECE,
252
  }
@@ -274,7 +265,3 @@ def test_ECE():
274
 
275
  if __name__ == "__main__":
276
  test_ECE()
277
-
278
-
279
- # if scheme == "equal-mass":
280
- # raise AssertionError("Need to calculate based on P") #so cannot instantiate yet
 
75
  if P is None:
76
  bin_range = [0, 1] # no way to know range
77
  else:
78
+ bin_range = [min(P), max(P)]
 
79
 
80
  if scheme == "equal-range":
81
  bins = np.linspace(bin_range[0], bin_range[1], n_bins + 1) # equal range
 
167
 
168
  empirical_acc, bin_edges, bin_assignment = manual_binned_statistic(P, y_correct, bins)
169
  bin_numbers, weights_ece = np.unique(bin_assignment, return_counts=True)
170
+ anindices = bin_numbers - 1 # reduce bin counts; left edge; indexes right by default
171
 
172
  # Expected calibration error
173
  if p < np.inf: # L^p-CE
174
  CE = np.average(
175
+ abs(empirical_acc[anindices] - calibrated_acc[anindices]) ** p, weights=weights_ece
 
176
  )
177
  elif np.isinf(p): # max-ECE
178
  CE = np.max(abs(empirical_acc[anindices] - calibrated_acc[anindices]))
 
191
 
192
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
193
  class ECE(evaluate.EvaluationModule):
 
 
194
  """
195
  0. create binning scheme [discretization of f]
196
  1. build histogram P(f(X))
 
199
  4. apply L^p norm distance and weights
200
  """
201
 
 
 
 
 
202
  def __init__(self, n_bins=10, bin_range=None, scheme="equal-range", proxy="upper-edge", p=1):
203
+ #super().__init__(self)
204
 
 
205
  self.n_bins = n_bins
206
+ self.bin_range = bin_range
207
  self.scheme = scheme
208
  self.proxy = proxy
209
  self.p = p
 
237
 
238
  def _compute(self, predictions, references):
239
  """Returns the scores"""
240
+ ECE = top_1_CE(references, predictions, **self.__dict__)
 
241
  return {
242
  "ECE": ECE,
243
  }
 
265
 
266
  if __name__ == "__main__":
267
  test_ECE()