KyanChen's picture
Upload 159 files
1c3eb47
raw
history blame
1.71 kB
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Callable
import torch.distributed as dist
from torchmetrics.detection import MeanAveragePrecision
from torchmetrics.utilities.distributed import gather_all_tensors
from mmpl.registry import METRICS
@METRICS.register_module(force=True)
class PLMeanAveragePrecision(MeanAveragePrecision):
    """MeanAveragePrecision with distributed syncing of list-based state.

    torchmetrics syncs tensor states automatically, but for ``iou_type ==
    "segm"`` the ``detections``/``groundtruths`` states are plain Python
    lists of tuples (RLE-encoded masks), which ``gather_all_tensors`` cannot
    handle. This subclass gathers those lists across ranks with
    ``torch.distributed.all_gather_object`` after the regular tensor sync.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ) -> None:
        # Pure pass-through; kept so the registry can instantiate this class
        # with the same signature as the parent metric.
        super().__init__(*args, **kwargs)

    def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None) -> None:
        """Sync metric state across processes.

        Tensor states are synced by the parent implementation; the
        list-of-tuples states used for segmentation are gathered separately.
        """
        super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group)
        if self.iou_type == "segm":
            self.detections = self._gather_tuple_list(self.detections, process_group)
            self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group)

    @staticmethod
    def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]:
        """Gather a list of tuples from all ranks and merge them.

        Args:
            list_to_gather: This rank's list of picklable tuples.
            process_group: Optional process group to gather within
                (defaults to the world group).

        Returns:
            The merged list, interleaved element-wise across ranks
            (element 0 of every rank, then element 1 of every rank, ...).

        Raises:
            AssertionError: If some rank holds a different number of
                elements than rank 0.
        """
        world_size = dist.get_world_size(group=process_group)
        list_gathered = [None] * world_size
        dist.all_gather_object(list_gathered, list_to_gather, group=process_group)
        for rank in range(1, world_size):
            # NOTE: the original wrapped the assert in a tuple — a non-empty
            # tuple is always truthy, so the check never fired — and compared
            # a length against the list itself instead of its length.
            assert len(list_gathered[rank]) == len(list_gathered[0]), (
                f"Rank{rank} doesn't have the same number of elements as Rank0: "
                f"{list_gathered[rank]} vs. {list_gathered[0]}"
            )
        # Interleave entries so the merged order is index-major across ranks.
        list_merged = []
        for idx in range(len(list_gathered[0])):
            for rank in range(world_size):
                list_merged.append(list_gathered[rank][idx])
        return list_merged