Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import List, Union | |
import torch | |
from mmengine import ConfigDict | |
from mmengine.structures import InstanceData | |
from scipy.optimize import linear_sum_assignment | |
from torch.cuda.amp import autocast | |
from mmseg.registry import TASK_UTILS | |
from .base_assigner import BaseAssigner | |
class HungarianAssigner(BaseAssigner):
    """Computes one-to-one matching between prediction masks and ground truth.

    This class uses bipartite matching-based assignment to compute an
    assignment between the prediction masks and the ground truth. The
    assignment result is based on the sum of the configured match costs.
    The Hungarian algorithm is used to find the matching with the minimum
    total cost. The prediction masks that are not matched are classified
    as background.

    Args:
        match_costs (ConfigDict|List[ConfigDict]): Match cost configs. A
            single config is wrapped into a one-element list; every config
            is built into a cost object with ``TASK_UTILS.build``.
    """

    # NOTE(review): no registry decorator is visible on this class; if it is
    # meant to be built from a config by type name, confirm it is registered
    # elsewhere.

    def __init__(
        self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
                                 ConfigDict]
    ) -> None:
        if isinstance(match_costs, dict):
            match_costs = [match_costs]
        elif isinstance(match_costs, (list, tuple)):
            # Reject empty sequences up front; otherwise ``assign`` would
            # have no costs to stack and fail far from the real cause.
            assert len(match_costs) > 0, \
                'match_costs must not be a empty list.'
        self.match_costs = [
            TASK_UTILS.build(match_cost) for match_cost in match_costs
        ]

    def assign(self, pred_instances: InstanceData, gt_instances: InstanceData,
               **kwargs):
        """Computes one-to-one matching based on the summed costs.

        This method assigns each query prediction to a ground truth or to
        background. It first evaluates every configured match cost for the
        query/ground-truth pairs, sums them, and then uses the Hungarian
        algorithm to find the assignment with the minimum total cost.

        Args:
            pred_instances (InstanceData): Instances of model
                predictions. It includes "masks", with shape
                (n, h, w) or (n, l), and "cls", with shape (n, num_classes+1)
            gt_instances (InstanceData): Ground truth of instance
                annotations. It includes "labels", with shape (k, ),
                and "masks", with shape (k, h, w) or (k, l).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            matched_query_inds (Tensor): Row indices of the matched entries
                in the summed cost matrix (the matched queries), on the
                same device as the cost.
            matched_label_inds (Tensor): Column indices of the matched
                entries (the matched ground-truth instances), on the same
                device as the cost.
        """
        cost_list = []
        # Costs are computed with autocast disabled, so a mixed-precision
        # context active in the caller does not alter these computations.
        with autocast(enabled=False):
            for match_cost in self.match_costs:
                cost = match_cost(
                    pred_instances=pred_instances, gt_instances=gt_instances)
                cost_list.append(cost)
            # presumably each cost is (num_queries, num_gt) — TODO confirm
            cost = torch.stack(cost_list).sum(dim=0)

        device = cost.device
        # Hungarian matching runs on the CPU via scipy. Convert explicitly to
        # a float64 NumPy array: NumPy has no bfloat16, so handing a bf16
        # tensor to scipy would raise, and float64 is lossless for every
        # narrower torch float dtype.
        cost = cost.detach().to(device='cpu', dtype=torch.float64).numpy()
        matched_query_inds, matched_label_inds = linear_sum_assignment(cost)
        matched_query_inds = torch.from_numpy(matched_query_inds).to(device)
        matched_label_inds = torch.from_numpy(matched_label_inds).to(device)
        return matched_query_inds, matched_label_inds