|
|
|
import torch
|
|
|
|
from detectron2.layers import nonzero_tuple
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = ["subsample_labels"]
|
|
|
|
|
|
def subsample_labels(
    labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int
):
    """
    Return `num_samples` (or fewer, if not enough found)
    random samples from `labels` which is a mixture of positives & negatives.
    It will try to return as many positives as possible without
    exceeding `positive_fraction * num_samples`, and then try to
    fill the remaining slots with negatives.

    This function only returns indices into `labels`; `labels` itself is
    not modified.

    Args:
        labels (Tensor): (N, ) label vector with values:
            * -1: ignore
            * bg_label: background ("negative") class
            * otherwise: one or more foreground ("positive") classes
        num_samples (int): The maximum total number of indices (positives plus
            negatives) to return. Entries labeled -1 are never sampled.
        positive_fraction (float): Target fraction of positives, expected to be
            in [0, 1]. The number of sampled positives (labels that are neither
            -1 nor `bg_label`) is
            `min(num_positives, int(positive_fraction * num_samples))`. The number
            of negatives sampled is `min(num_negatives, num_samples - num_positives_sampled)`.
            In other words, if there are not enough positives, the sample is filled with
            negatives. If there are also not enough negatives, then as many elements are
            sampled as is possible. Both counts are clamped to be non-negative, so
            out-of-range arguments (e.g. `positive_fraction > 1` or a negative
            `num_samples`) never sample a negative number of elements.
        bg_label (int): label index of background ("negative") class.

    Returns:
        pos_idx, neg_idx (Tensor):
            1D vector of indices. The total length of both is `num_samples` or fewer
            (for `positive_fraction` in [0, 1]).
    """
    # Candidate indices: foreground is anything that is neither ignored (-1)
    # nor background; background is exactly `bg_label`.
    positive = nonzero_tuple((labels != -1) & (labels != bg_label))[0]
    negative = nonzero_tuple(labels == bg_label)[0]

    # Positive quota, limited by how many positives actually exist.
    # Clamp at 0: a negative count used as a slice bound below would select
    # "all but the last k" elements instead of none.
    num_pos = int(num_samples * positive_fraction)
    num_pos = max(min(positive.numel(), num_pos), 0)

    # Negatives fill whatever is left of the budget, limited by availability.
    # The remainder is negative when `num_pos` exceeds `num_samples`
    # (positive_fraction > 1), hence the same clamp.
    num_neg = num_samples - num_pos
    num_neg = max(min(negative.numel(), num_neg), 0)

    # Randomly select the positive and negative examples: shuffle positions on
    # the same device as the candidates and keep the first `num_*` of them.
    perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
    perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

    pos_idx = positive[perm1]
    neg_idx = negative[perm2]
    return pos_idx, neg_idx
|
|
|