_
File size: 540 Bytes
da3eeba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from torchvision.ops.boxes import box_iou


def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    """Greedy non-maximum suppression over axis-aligned boxes.

    Repeatedly takes the highest-scoring remaining box, keeps it, and discards
    every remaining box whose IoU with it is greater than ``iou_threshold``.

    Args:
        bboxes: Boxes of shape ``(N, 4)`` in ``(x1, y1, x2, y2)`` format, as
            required by ``torchvision.ops.box_iou``.
        scores: One score per box, shape ``(N,)``.
        iou_threshold: Boxes whose IoU with an already-kept box is strictly
            greater than this value are suppressed.

    Returns:
        int64 tensor of indices into ``bboxes`` for the kept boxes, ordered by
        decreasing score, on ``bboxes.device``. Empty (still int64) when
        ``N == 0``.
    """
    # Stable sort: equal scores are visited in original-index order, so the
    # result is deterministic across runs and devices.
    order = torch.sort(scores, descending=True, stable=True).indices
    keep = []

    while order.numel() > 0:
        best = order[0]
        keep.append(int(best))

        rest = order[1:]
        if rest.numel() == 0:
            break

        ious = box_iou(bboxes[best].unsqueeze(0), bboxes[rest])[0]
        # Suppress only on IoU > threshold. Written as a negation so that a NaN
        # IoU (two zero-area boxes give 0/0) keeps the box, matching
        # torchvision.ops.nms, which suppresses only when `iou > threshold`.
        order = rest[~(ious > iou_threshold)]

    # Explicit dtype: torch.tensor([]) would otherwise default to float32,
    # which cannot be used as an index tensor.
    return torch.tensor(keep, dtype=torch.int64, device=bboxes.device)