from enum import Enum

import numpy as np
import torch
import torch.distributed as dist

# Label value ignored when computing the loss (matches the default
# `ignore_index` of torch.nn.CrossEntropyLoss).
IGNORE_INDEX = -100
|
|
class Summary(Enum):
    """How an AverageMeter reports itself in a final summary line."""

    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3
|
|
class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record a value observed `n` times (e.g. a batch-mean loss with
        `n` set to the batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def all_reduce(self):
        """Sum `sum` and `count` across all ranks in the process group.

        Both are packed into one tensor so a single collective call
        suffices; `sum` may be a Python scalar or a numpy array of
        per-class statistics.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(self.sum, np.ndarray):
            total = torch.tensor(
                self.sum.tolist() + [self.count],
                dtype=torch.float32,
                device=device,
            )
        else:
            total = torch.tensor(
                [self.sum, self.count], dtype=torch.float32, device=device
            )

        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        if total.shape[0] > 2:
            # Array case: all but the last entry are the summed statistics.
            self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item()
        else:
            self.sum, self.count = total.tolist()
        # The epsilon avoids division by zero when no updates were recorded.
        self.avg = self.sum / (self.count + 1e-5)

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)

    def summary(self):
        if self.summary_type is Summary.NONE:
            fmtstr = ""
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = "{name} {avg:.3f}"
        elif self.summary_type is Summary.SUM:
            fmtstr = "{name} {sum:.3f}"
        elif self.summary_type is Summary.COUNT:
            fmtstr = "{name} {count:.3f}"
        else:
            raise ValueError("invalid summary type %r" % self.summary_type)
        return fmtstr.format(**self.__dict__)
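
# Usage sketch (illustrative, not part of the module API; the meter name and
# numbers below are made up): track a running loss, then optionally combine
# counts across ranks; `all_reduce` requires an initialized process group.
#
#     losses = AverageMeter("loss", ":.4f", Summary.AVERAGE)
#     for step in range(10):
#         losses.update(val=0.5 / (step + 1), n=16)  # n = batch size
#     print(losses)             # e.g. "loss 0.0500 (0.1464)"
#     # losses.all_reduce()     # only inside torch.distributed training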
|
|
def intersectionAndUnionGPU(output, target, K, ignore_index=255):
    """Per-class intersection and union for K-class segmentation, on GPU.

    `output` and `target` are label tensors of the same shape (1D, 2D,
    or 3D), with values in [0, K - 1]. Returns per-class intersection,
    union, and target areas. Note that `output` is modified in place at
    ignored positions.
    """
    assert output.dim() in [1, 2, 3]
    assert output.shape == target.shape
    output = output.view(-1)
    target = target.view(-1)
    # Mask out ignored positions so they cannot count as intersections.
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1)
    area_output = torch.histc(output, bins=K, min=0, max=K - 1)
    area_target = torch.histc(target, bins=K, min=0, max=K - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target
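
# Usage sketch (hedged; shapes and class count are made up). torch.histc
# expects floating-point input, so integer label maps are cast first.
# Accumulate the per-class areas over a dataset, then compute
# mIoU = (intersection / (union + eps)).mean().
#
#     pred = torch.randint(0, 21, (4, 64, 64), device="cuda").float()
#     gt = torch.randint(0, 21, (4, 64, 64), device="cuda").float()
#     inter, union, _ = intersectionAndUnionGPU(pred, gt, K=21)
#     print((inter / (union + 1e-10)).mean().item())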
|
|
class ProgressMeter(object):
    """Prints a batch-progress line composed of the given meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def display_summary(self):
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(" ".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the batch index to the width of the total count, e.g. "[  7/100]".
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"
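
# Usage sketch (hedged; meter names and the batch count are hypothetical):
#
#     batch_time = AverageMeter("time", ":.3f")
#     losses = AverageMeter("loss", ":.4f")
#     progress = ProgressMeter(100, [batch_time, losses], prefix="Epoch[0] ")
#     progress.display(7)         # tab-separated: "Epoch[0] [  7/100] ..."
#     progress.display_summary()  # " * time <avg> loss <avg>"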
|
|
def dict_to_cuda(input_dict):
    """Move all tensor values (and lists of tensors) in a dict to the GPU.

    Modifies `input_dict` in place and also returns it. Lists are assumed
    to be homogeneous: only the first element is checked for being a tensor.
    """
    for k, v in input_dict.items():
        if isinstance(v, torch.Tensor):
            input_dict[k] = v.cuda(non_blocking=True)
        elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
            input_dict[k] = [ele.cuda(non_blocking=True) for ele in v]
    return input_dict
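
# Usage sketch (hedged; the keys are hypothetical and a CUDA device is
# assumed to be available):
#
#     batch = {
#         "images": torch.randn(2, 3, 224, 224),
#         "masks": [torch.zeros(224, 224), torch.ones(224, 224)],
#         "paths": ["a.jpg", "b.jpg"],  # non-tensor values pass through
#     }
#     batch = dict_to_cuda(batch)
#     assert batch["images"].is_cuda and batch["masks"][0].is_cuda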