TTP / mmseg /models /assigners /hungarian_assigner.py
KyanChen's picture
Upload 1861 files
3b96cb1
raw
history blame
3.42 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union
import torch
from mmengine import ConfigDict
from mmengine.structures import InstanceData
from scipy.optimize import linear_sum_assignment
from torch.cuda.amp import autocast
from mmseg.registry import TASK_UTILS
from .base_assigner import BaseAssigner
@TASK_UTILS.register_module()
class HungarianAssigner(BaseAssigner):
"""Computes one-to-one matching between prediction masks and ground truth.
This class uses bipartite matching-based assignment to computes an
assignment between the prediction masks and the ground truth. The
assignment result is based on the weighted sum of match costs. The
Hungarian algorithm is used to calculate the best matching with the
minimum cost. The prediction masks that are not matched are classified
as background.
Args:
match_costs (ConfigDict|List[ConfigDict]): Match cost configs.
"""
def __init__(
self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
ConfigDict]
) -> None:
if isinstance(match_costs, dict):
match_costs = [match_costs]
elif isinstance(match_costs, list):
assert len(match_costs) > 0, \
'match_costs must not be a empty list.'
self.match_costs = [
TASK_UTILS.build(match_cost) for match_cost in match_costs
]
def assign(self, pred_instances: InstanceData, gt_instances: InstanceData,
**kwargs):
"""Computes one-to-one matching based on the weighted costs.
This method assign each query prediction to a ground truth or
background. The assignment first calculates the cost for each
category assigned to each query mask, and then uses the
Hungarian algorithm to calculate the minimum cost as the best
match.
Args:
pred_instances (InstanceData): Instances of model
predictions. It includes "masks", with shape
(n, h, w) or (n, l), and "cls", with shape (n, num_classes+1)
gt_instances (InstanceData): Ground truth of instance
annotations. It includes "labels", with shape (k, ),
and "masks", with shape (k, h, w) or (k, l).
Returns:
matched_quiery_inds (Tensor): The indexes of matched quieres.
matched_label_inds (Tensor): The indexes of matched labels.
"""
# compute weighted cost
cost_list = []
with autocast(enabled=False):
for match_cost in self.match_costs:
cost = match_cost(
pred_instances=pred_instances, gt_instances=gt_instances)
cost_list.append(cost)
cost = torch.stack(cost_list).sum(dim=0)
device = cost.device
# do Hungarian matching on CPU using linear_sum_assignment
cost = cost.detach().cpu()
if linear_sum_assignment is None:
raise ImportError('Please run "pip install scipy" '
'to install scipy first.')
matched_quiery_inds, matched_label_inds = linear_sum_assignment(cost)
matched_quiery_inds = torch.from_numpy(matched_quiery_inds).to(device)
matched_label_inds = torch.from_numpy(matched_label_inds).to(device)
return matched_quiery_inds, matched_label_inds