lvwerra HF staff commited on
Commit
7b31dac
1 Parent(s): 9b7175f

Update Space (evaluate main: e4a27243)

Browse files
Files changed (2) hide show
  1. mean_iou.py +29 -11
  2. requirements.txt +1 -1
mean_iou.py CHANGED
@@ -13,6 +13,7 @@
13
  # limitations under the License.
14
  """Mean IoU (Intersection-over-Union) metric."""
15
 
 
16
  from typing import Dict, Optional
17
 
18
  import datasets
@@ -273,13 +274,29 @@ def mean_iou(
273
  return metrics
274
 
275
 
 
 
 
 
 
 
 
 
 
 
 
276
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
277
  class MeanIoU(evaluate.Metric):
278
- def _info(self):
 
 
 
 
279
  return evaluate.MetricInfo(
280
  description=_DESCRIPTION,
281
  citation=_CITATION,
282
  inputs_description=_KWARGS_DESCRIPTION,
 
283
  features=datasets.Features(
284
  # 1st Seq - height dim, 2nd - width dim
285
  {
@@ -296,19 +313,20 @@ class MeanIoU(evaluate.Metric):
296
  self,
297
  predictions,
298
  references,
299
- num_labels: int,
300
- ignore_index: bool,
301
- nan_to_num: Optional[int] = None,
302
- label_map: Optional[Dict[int, int]] = None,
303
- reduce_labels: bool = False,
304
  ):
 
 
 
 
 
 
305
  iou_result = mean_iou(
306
  results=predictions,
307
  gt_seg_maps=references,
308
- num_labels=num_labels,
309
- ignore_index=ignore_index,
310
- nan_to_num=nan_to_num,
311
- label_map=label_map,
312
- reduce_labels=reduce_labels,
313
  )
314
  return iou_result
 
13
  # limitations under the License.
14
  """Mean IoU (Intersection-over-Union) metric."""
15
 
16
+ from dataclasses import dataclass
17
  from typing import Dict, Optional
18
 
19
  import datasets
 
274
  return metrics
275
 
276
 
277
+ @dataclass
278
+ class MeanIoUConfig(evaluate.info.Config):
279
+ name: str = "default"
280
+
281
+ num_labels: int = None
282
+ ignore_index: int = None
283
+ nan_to_num: Optional[int] = None
284
+ label_map: Optional[Dict[int, int]] = None
285
+ reduce_labels: bool = False
286
+
287
+
288
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
289
  class MeanIoU(evaluate.Metric):
290
+
291
+ CONFIG_CLASS = MeanIoUConfig
292
+ ALLOWED_CONFIG_NAMES = ["default"]
293
+
294
+ def _info(self, config):
295
  return evaluate.MetricInfo(
296
  description=_DESCRIPTION,
297
  citation=_CITATION,
298
  inputs_description=_KWARGS_DESCRIPTION,
299
+ config=config,
300
  features=datasets.Features(
301
  # 1st Seq - height dim, 2nd - width dim
302
  {
 
313
  self,
314
  predictions,
315
  references,
 
 
 
 
 
316
  ):
317
+
318
+ if self.config.num_labels is None:
319
+ raise ValueError("You have to specify a value for `num_labels`.")
320
+ if self.config.ignore_index is None:
321
+ raise ValueError("You have to specify a value for `ignore_index`.")
322
+
323
  iou_result = mean_iou(
324
  results=predictions,
325
  gt_seg_maps=references,
326
+ num_labels=self.config.num_labels,
327
+ ignore_index=self.config.ignore_index,
328
+ nan_to_num=self.config.nan_to_num,
329
+ label_map=self.config.label_map,
330
+ reduce_labels=self.config.reduce_labels,
331
  )
332
  return iou_result
requirements.txt CHANGED
@@ -1 +1 @@
1
- git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
 
1
+ git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39