Natooz commited on
Commit
325581b
·
verified ·
1 Parent(s): 4f846cc

update call to `mean_squared_error` to comply with sklearn v1.6

Browse files

The `squared` argument was depreciated since v1.4 and has been removed in v1.6, making this evaluation module incompatible with the latest scikit-learn version.
This PR make it call the `root_mean_squared_error` when `squared=True`, solving this error.

Files changed (1) hide show
  1. mse.py +8 -3
mse.py CHANGED
@@ -112,8 +112,13 @@ class Mse(evaluate.Metric):
112
 
113
  def _compute(self, predictions, references, sample_weight=None, multioutput="uniform_average", squared=True):
114
 
115
- mse = mean_squared_error(
116
- references, predictions, sample_weight=sample_weight, multioutput=multioutput, squared=squared
117
- )
 
 
 
 
 
118
 
119
  return {"mse": mse}
 
112
 
113
  def _compute(self, predictions, references, sample_weight=None, multioutput="uniform_average", squared=True):
114
 
115
+ if squared:
116
+ mse = mean_squared_error(
117
+ references, predictions, sample_weight=sample_weight, multioutput=multioutput
118
+ )
119
+ else:
120
+ mse = root_mean_squared_error(
121
+ references, predictions, sample_weight=sample_weight, multioutput=multioutput
122
+ )
123
 
124
  return {"mse": mse}