franzi2505 commited on
Commit
d6001c4
1 Parent(s): c9750d3

Update PanopticQuality.py

Browse files
Files changed (1) hide show
  1. PanopticQuality.py +3 -3
PanopticQuality.py CHANGED
@@ -109,6 +109,7 @@ class PQMetric(evaluate.Metric):
109
  DEFAULT_STUFF = ["WATER", "SKY", "LAND", "CONSTRUCTION", "ICE", "OWN_BOAT"]
110
 
111
  self.label2id = label2id if label2id is not None else DEFAULT_LABEL2ID
 
112
  self.stuff = stuff if stuff is not None else DEFAULT_STUFF
113
  self.per_class = per_class
114
  self.split_sq_rq = split_sq_rq
@@ -171,7 +172,6 @@ class PQMetric(evaluate.Metric):
171
  fn = self.pq_metric.metric.false_negatives.clone()
172
  iou = self.pq_metric.metric.iou_sum.clone()
173
 
174
- id2label = {id: label for label, id in self.label2id.items()}
175
  things_stuffs = sorted(self.pq_metric.things) + sorted(self.pq_metric.stuffs)
176
 
177
  # compute scores
@@ -182,10 +182,10 @@ class PQMetric(evaluate.Metric):
182
  if self.per_class:
183
  if not self.split_sq_rq:
184
  result = result.T
185
- result_dict["scores"] = {id2label[numeric_label]: result[i].tolist() \
186
  for i, numeric_label in enumerate(things_stuffs)}
187
  result_dict["scores"].update({"ALL": result.mean(axis=0).tolist()})
188
- result_dict["numbers"] = {id2label[numeric_label]: [tp[i].item(), fp[i].item(), fn[i].item(), iou[i].item()] \
189
  for i, numeric_label in enumerate(things_stuffs)}
190
  result_dict["numbers"].update({"ALL": [tp.sum().item(), fp.sum().item(), fn.sum().item(), iou.sum().item()]})
191
  else:
 
109
  DEFAULT_STUFF = ["WATER", "SKY", "LAND", "CONSTRUCTION", "ICE", "OWN_BOAT"]
110
 
111
  self.label2id = label2id if label2id is not None else DEFAULT_LABEL2ID
112
+ self.id2label = {id: label for label, id in self.label2id.items()}
113
  self.stuff = stuff if stuff is not None else DEFAULT_STUFF
114
  self.per_class = per_class
115
  self.split_sq_rq = split_sq_rq
 
172
  fn = self.pq_metric.metric.false_negatives.clone()
173
  iou = self.pq_metric.metric.iou_sum.clone()
174
 
 
175
  things_stuffs = sorted(self.pq_metric.things) + sorted(self.pq_metric.stuffs)
176
 
177
  # compute scores
 
182
  if self.per_class:
183
  if not self.split_sq_rq:
184
  result = result.T
185
+ result_dict["scores"] = {self.id2label[numeric_label]: result[i].tolist() \
186
  for i, numeric_label in enumerate(things_stuffs)}
187
  result_dict["scores"].update({"ALL": result.mean(axis=0).tolist()})
188
+ result_dict["numbers"] = {self.id2label[numeric_label]: [tp[i].item(), fp[i].item(), fn[i].item(), iou[i].item()] \
189
  for i, numeric_label in enumerate(things_stuffs)}
190
  result_dict["numbers"].update({"ALL": [tp.sum().item(), fp.sum().item(), fn.sum().item(), iou.sum().item()]})
191
  else: