update call to `mean_squared_error` to comply with sklearn v1.6

#3
by Natooz - opened
Files changed (1) hide show
  1. mse.py +8 -3
mse.py CHANGED
@@ -112,8 +112,13 @@ class Mse(evaluate.Metric):
112
 
113
  def _compute(self, predictions, references, sample_weight=None, multioutput="uniform_average", squared=True):
114
 
115
- mse = mean_squared_error(
116
- references, predictions, sample_weight=sample_weight, multioutput=multioutput, squared=squared
117
- )
 
 
 
 
 
118
 
119
  return {"mse": mse}
 
112
 
113
  def _compute(self, predictions, references, sample_weight=None, multioutput="uniform_average", squared=True):
114
 
115
+ if squared:
116
+ mse = mean_squared_error(
117
+ references, predictions, sample_weight=sample_weight, multioutput=multioutput
118
+ )
119
+ else:
120
+ mse = root_mean_squared_error(
121
+ references, predictions, sample_weight=sample_weight, multioutput=multioutput
122
+ )
123
 
124
  return {"mse": mse}