Spaces:
Runtime error
Runtime error
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Callable | |
import torch.distributed as dist | |
from torchmetrics.detection import MeanAveragePrecision | |
from torchmetrics.utilities.distributed import gather_all_tensors | |
from mmpl.registry import METRICS | |
class PLMeanAveragePrecision(MeanAveragePrecision):
    """``MeanAveragePrecision`` whose distributed sync also merges ``segm`` list states.

    After the parent class has synced its states, when ``iou_type == "segm"`` the
    ``detections`` and ``groundtruths`` lists of every process are collected with
    ``torch.distributed.all_gather_object`` and merged into a single list each.
    """

    # NOTE(review): ``METRICS`` is imported in this module but this class is not
    # registered with it here; if the metric is meant to be built from a config,
    # it presumably needs ``@METRICS.register_module()`` — confirm.

    def __init__(
        self,
        *args,
        **kwargs,
    ) -> None:
        # Thin pass-through: all arguments are forwarded unchanged to the parent.
        super().__init__(*args, **kwargs)

    def _sync_dist(
        self,
        dist_sync_fn: Callable = gather_all_tensors,
        process_group: Optional[Any] = None,
    ) -> None:
        """Sync metric states across processes.

        Args:
            dist_sync_fn: Gather function handed to the parent class's sync
                (defaults to ``gather_all_tensors``).
            process_group: Process group forwarded to the parent sync and to the
                list gathering below.
        """
        super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group)
        # NOTE(review): some torchmetrics versions already gather these lists inside
        # ``MeanAveragePrecision._sync_dist``; if the pinned version does, this second
        # gather would duplicate entries — confirm against the installed version.
        if self.iou_type == "segm":
            self.detections = self._gather_tuple_list(self.detections, process_group)
            self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group)

    @staticmethod
    def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]:
        """Gather a per-process list from every rank and merge it into one list.

        Elements are interleaved rank by rank: item 0 of rank 0, item 0 of rank 1,
        ..., then item 1 of every rank, and so on. A rank holding fewer elements
        than the others is skipped once exhausted, so no element is dropped and no
        index runs past the end of a shorter list.

        Args:
            list_to_gather: This process's list; it is sent through
                ``all_gather_object``, which serializes it (so it must be picklable).
            process_group: Process group used for the world size and the gather.

        Returns:
            A single list containing the elements contributed by every rank.
        """
        world_size = dist.get_world_size(group=process_group)
        # all_gather_object fills a pre-sized list with one entry per rank.
        list_gathered: List[Any] = [None] * world_size
        dist.all_gather_object(list_gathered, list_to_gather, group=process_group)

        # Ranks may hold different numbers of elements: walk up to the longest list
        # instead of rank 0's length (which truncated longer ranks or raised
        # IndexError on shorter ones), skipping ranks that have run out.
        longest = max((len(per_rank) for per_rank in list_gathered), default=0)
        return [
            per_rank[idx]
            for idx in range(longest)
            for per_rank in list_gathered
            if idx < len(per_rank)
        ]