|
import os |
|
import torch |
|
import collections |
|
from easydict import EasyDict as edict |
|
import torch.distributed as dist |
|
|
|
from torch.nn import Module |
|
from torch.utils.data.sampler import Sampler |
|
import math |
|
import numpy as np |
|
import multiprocessing as mp |
|
import copy |
|
import random |
|
from collections import defaultdict |
|
from core.utils import named_buffers, sync_print, printlog |
|
import subprocess |
|
import socket |
|
from . import comm_ |
|
import torch.cuda.comm |
|
|
|
class DistModule(torch.nn.Module):
    """Wrap a module for synchronous multi-group distributed training.

    At construction the wrapper's parameters and buffers are broadcast through
    ``broadcast_params_multitask``.  Gradients are synchronised explicitly by
    calling :meth:`reduce_gradients`.

    Routing is driven by flags read from each parameter (``task_specific``,
    ``backbone_specific``, ``adapter_specific``, ``neck_specific``,
    ``decoder_specific``); this class only reads them, it never sets them.
    """

    def __init__(self, module, sync=False, task_grp=None, share_backbone_group=None,
                 share_neck_group=None, share_decoder_group=None, share_adapter_group=None,
                 ignore_bcast=None,
                 task_weight=None, task_size=None):
        """Store the configuration and broadcast the initial state.

        Args:
            module: the module to wrap; ``forward`` delegates to it.
            sync: must be truthy (asserted below).
            task_grp, share_backbone_group, share_neck_group,
            share_decoder_group, share_adapter_group: communication groups
                (or ``None``) used for the matching parameter families.
            ignore_bcast: names skipped by the initial broadcast.
            task_weight, task_size: stored on the instance, not used here.
        """
        super(DistModule, self).__init__()
        self.module = module
        self.sync = sync
        self.task_grp = task_grp
        self.share_backbone_group = share_backbone_group
        self.share_neck_group = share_neck_group
        self.share_decoder_group = share_decoder_group
        self.share_adapter_group = share_adapter_group
        self.task_weight = task_weight
        self.task_size = task_size

        # Back-fill for torch versions whose nn.Module lacks named_buffers.
        if not hasattr(torch.nn.Module, 'named_buffers'):
            printlog('registering named_buffers for nn.Module at DistModule')
            torch.nn.Module.named_buffers = named_buffers

        broadcast_params_multitask(self, self.task_grp, self.share_backbone_group,
                                   self.share_neck_group, self.share_decoder_group,
                                   self.share_adapter_group, ignore_bcast)

        assert sync, "Currently, only sync model is supported!"
        # Only reachable when assertions are disabled (python -O).
        if not sync:
            self._grad_accs = {}
            self._reduce_hooks = {}
            self._register_hooks()

    def forward(self, *inputs, **kwargs):
        """Delegate to the wrapped module."""
        return self.module(*inputs, **kwargs)

    def train(self, mode=True):
        """Set training mode on the wrapper and the wrapped module.

        Returns ``self`` like ``torch.nn.Module.train`` so calls can be chained.
        """
        super(DistModule, self).train(mode)
        self.module.train(mode)
        return self

    def _register_hooks(self):
        """Attach a reduction hook to every trainable parameter."""
        for i, (name, p) in enumerate(self.named_parameters()):
            if p.requires_grad:
                # expand_as creates a graph node whose first input is the
                # parameter's gradient-accumulation node; hook onto that.
                p_tmp = p.expand_as(p)
                grad_acc = p_tmp.grad_fn.next_functions[0][0]
                self._reduce_hooks[name] = grad_acc.register_hook(self._make_hook(name, p, i))
                # Keep a reference so the node (and its hook) stays alive.
                self._grad_accs[name] = grad_acc

    def _make_hook(self, name, p, i):
        """Build the hook for parameter ``p``; ``i`` is accepted but unused."""
        if not p.task_specific:
            def hook(*ignore):
                allreduce_async(name, p.grad.data)
        else:
            printlog('{} register hook as task specific'.format(name))

            def hook(*ignore):
                allreduce(p.grad.data, group_idx=self.task_grp)
        return hook

    def reduce_gradients(self, task_specific=False):
        """All-reduce gradients in place; a no-op when ``self.sync`` is falsy.

        Args:
            task_specific: when true, every existing gradient of a trainable
                parameter is reduced over ``self.task_grp``.  Otherwise each
                parameter is routed to the group matching its flags, or all
                gradients are reduced over the default group when no group is
                configured.
        """
        if not self.sync:
            return

        if task_specific:
            for param in self.parameters():
                if param.requires_grad and param.grad is not None:
                    # dist.all_reduce() takes `group`, not `group_idx`; route
                    # through allreduce() which accepts group_idx.
                    allreduce(param.grad.data, group_idx=self.task_grp)
            return

        # share_adapter_group is included so an adapter-only configuration
        # takes the routed path that actually has an adapter branch.
        groups = (self.task_grp, self.share_backbone_group, self.share_neck_group,
                  self.share_decoder_group, self.share_adapter_group)
        if all(group is None for group in groups):
            for param in self.parameters():
                if param.requires_grad and param.grad is not None:
                    dist.all_reduce(param.grad.data)
            return

        for param in self.parameters():
            # Missing gradients are replaced by zeros before reducing
            # (presumably so every rank joins the same collectives).
            if param.grad is None:
                param.grad = param.data * 0
            if not param.requires_grad:
                continue
            if param.task_specific:
                allreduce(param.grad.data, group_idx=self.task_grp)
            elif param.backbone_specific:
                allreduce(param.grad.data, group_idx=self.share_backbone_group)
            elif param.adapter_specific:
                allreduce(param.grad.data, group_idx=self.share_adapter_group)
            elif param.neck_specific:
                allreduce(param.grad.data, group_idx=self.share_neck_group)
            elif param.decoder_specific:
                allreduce(param.grad.data, group_idx=self.share_decoder_group)
            else:
                allreduce(param.grad.data)
|
|
|
class DistModule_Hulk(torch.nn.Module):
    """Wrap a module for synchronous multi-group, multi-modality training.

    At construction the wrapper's parameters and buffers are broadcast through
    ``broadcast_params_unihcpv2``.  Gradients are synchronised explicitly by
    calling :meth:`reduce_gradients`.

    Routing is driven by flags read from each parameter (``task_specific``,
    ``modality_share``, ``backbone_specific``, ``rgb_specific``,
    ``dense_labeling_specific``, ``sparse_labeling_specific``,
    ``text_specific``, ``video_specific``, ``decoder_specific``); this class
    only reads them, it never sets them.
    """

    modality_names = ['rgb', 'dense_labeling', 'text']

    def __init__(self, module, sync=False, task_grp=None, share_backbone_group=None,
                 share_decoder_group=None, share_rgb_group=None, share_dense_labeling_group=None,
                 share_sparse_labeling_group=None, share_text_group=None, share_video_group=None,
                 share_modality_group=None,
                 ignore_bcast=None, task_weight=None, task_size=None, ):
        """Store the configuration and broadcast the initial state.

        Args:
            module: the module to wrap; ``forward`` delegates to it.
            sync: must be truthy (asserted below).
            task_grp and the ``share_*_group`` arguments: communication groups
                (or ``None``) used for the matching parameter families.
            ignore_bcast: names skipped by the initial broadcast.
            task_weight, task_size: stored on the instance, not used here.
        """
        super(DistModule_Hulk, self).__init__()
        self.module = module
        self.sync = sync
        self.task_grp = task_grp
        self.share_modality_group = share_modality_group
        self.share_backbone_group = share_backbone_group

        self.share_decoder_group = share_decoder_group

        self.share_rgb_group = share_rgb_group
        self.share_dense_labeling_group = share_dense_labeling_group
        self.share_sparse_labeling_group = share_sparse_labeling_group
        self.share_text_group = share_text_group
        self.share_video_group = share_video_group

        self.task_weight = task_weight
        self.task_size = task_size

        # Back-fill for torch versions whose nn.Module lacks named_buffers.
        if not hasattr(torch.nn.Module, 'named_buffers'):
            printlog('registering named_buffers for nn.Module at DistModule')
            torch.nn.Module.named_buffers = named_buffers

        broadcast_params_unihcpv2(self, self.task_grp, self.share_backbone_group,
                                  self.share_decoder_group, self.share_rgb_group,
                                  self.share_dense_labeling_group, self.share_sparse_labeling_group,
                                  self.share_text_group, self.share_video_group,
                                  ignore_bcast,
                                  share_modality_group=self.share_modality_group,
                                  )

        assert sync, "Currently, only sync model is supported!"
        # Only reachable when assertions are disabled (python -O).
        if not sync:
            self._grad_accs = {}
            self._reduce_hooks = {}
            self._register_hooks()

    def forward(self, *inputs, **kwargs):
        """Delegate to the wrapped module."""
        return self.module(*inputs, **kwargs)

    def train(self, mode=True):
        """Set training mode on the wrapper and the wrapped module.

        Returns ``self`` like ``torch.nn.Module.train`` so calls can be chained.
        """
        super(DistModule_Hulk, self).train(mode)
        self.module.train(mode)
        return self

    def _register_hooks(self):
        """Attach a reduction hook to every trainable parameter."""
        for i, (name, p) in enumerate(self.named_parameters()):
            if p.requires_grad:
                # expand_as creates a graph node whose first input is the
                # parameter's gradient-accumulation node; hook onto that.
                p_tmp = p.expand_as(p)
                grad_acc = p_tmp.grad_fn.next_functions[0][0]
                self._reduce_hooks[name] = grad_acc.register_hook(self._make_hook(name, p, i))
                # Keep a reference so the node (and its hook) stays alive.
                self._grad_accs[name] = grad_acc

    def _make_hook(self, name, p, i):
        """Build the hook for parameter ``p``; ``i`` is accepted but unused."""
        if not p.task_specific:
            def hook(*ignore):
                allreduce_async(name, p.grad.data)
        else:
            printlog('{} register hook as task specific'.format(name))

            def hook(*ignore):
                allreduce(p.grad.data, group_idx=self.task_grp)
        return hook

    def reduce_gradients(self, task_specific=False):
        """All-reduce gradients in place; a no-op when ``self.sync`` is falsy.

        Args:
            task_specific: when true, every existing gradient of a trainable
                parameter is reduced over ``self.task_grp``.  Otherwise each
                parameter is routed to the group matching its flags, or all
                gradients are reduced over the default group when no group is
                configured.
        """
        if not self.sync:
            return

        if task_specific:
            for param in self.parameters():
                if param.requires_grad and param.grad is not None:
                    # dist.all_reduce() takes `group`, not `group_idx`; route
                    # through allreduce() which accepts group_idx.
                    allreduce(param.grad.data, group_idx=self.task_grp)
            return

        # The sparse-labeling attribute was previously misspelled here, which
        # raised AttributeError whenever the earlier groups were all None.
        groups = (self.task_grp, self.share_backbone_group, self.share_decoder_group,
                  self.share_rgb_group, self.share_dense_labeling_group,
                  self.share_text_group, self.share_sparse_labeling_group,
                  self.share_video_group, self.share_modality_group)
        if all(group is None for group in groups):
            for param in self.parameters():
                if param.requires_grad and param.grad is not None:
                    dist.all_reduce(param.grad.data)
            return

        for param in self.parameters():
            # Missing gradients are replaced by zeros before reducing
            # (presumably so every rank joins the same collectives).
            if param.grad is None:
                param.grad = param.data * 0
            if not param.requires_grad:
                continue
            if param.task_specific:
                allreduce(param.grad.data, group_idx=self.task_grp)
            elif param.modality_share:
                allreduce(param.grad.data, group_idx=self.share_modality_group)
            elif param.backbone_specific:
                allreduce(param.grad.data, group_idx=self.share_backbone_group)
            elif param.rgb_specific:
                allreduce(param.grad.data, group_idx=self.share_rgb_group)
            elif param.dense_labeling_specific:
                allreduce(param.grad.data, group_idx=self.share_dense_labeling_group)
            elif param.sparse_labeling_specific:
                allreduce(param.grad.data, group_idx=self.share_sparse_labeling_group)
            elif param.text_specific:
                allreduce(param.grad.data, group_idx=self.share_text_group)
            elif param.video_specific:
                allreduce(param.grad.data, group_idx=self.share_video_group)
            elif param.decoder_specific:
                allreduce(param.grad.data, group_idx=self.share_decoder_group)
            else:
                allreduce(param.grad.data)
|
|
|
def allreduce(x, group_idx=None, ):
    """All-reduce ``x`` in place via ``dist.all_reduce``.

    A ``group_idx`` equal to ``0`` is treated as "no group" and replaced by
    ``None``; any other value is passed through as the ``group`` argument.
    Returns whatever ``dist.all_reduce`` returns.
    """
    group = None if group_idx == 0 else group_idx
    return dist.all_reduce(x, group=group)
|
|
|
def allreduce_async(name, x, group_idx=None, ):
    """All-reduce ``x`` in place via ``dist.all_reduce``.

    ``name`` is accepted but not used.  The call is made with the default
    arguments of ``dist.all_reduce`` apart from ``group``; a ``group_idx``
    equal to ``0`` is replaced by ``None`` first.
    """
    if group_idx == 0:
        return dist.all_reduce(x, group=None)
    return dist.all_reduce(x, group=group_idx)
|
|
|
def broadcast_params(model, task_grp, ignore):
    """Broadcast every parameter, then every buffer, of ``model``.

    Args:
        model: object exposing ``named_parameters()`` and ``named_buffers()``.
        task_grp: when not ``None``, tensors whose ``task_specific`` attribute
            is truthy are broadcast within this group; all others use the
            default group.  When ``None`` no attribute is consulted.
        ignore: optional container of names to skip.

    Raises:
        RuntimeError: if handling a tensor fails.  The original exception is
            chained as ``__cause__`` (the message text is kept for
            compatibility even when the cause is not a missing attribute).
    """
    # Parameters are fully processed before buffers are enumerated.
    for kind, list_named in (('param', model.named_parameters),
                             ('buffer', model.named_buffers)):
        for name, tensor in list_named():
            if ignore and name in ignore:
                printlog('{} {} ignored in broadcast'.format(kind, name))
                continue
            try:
                if task_grp is not None and tensor.task_specific:
                    printlog('broadcasting task-specific {} {}'.format(kind, name))
                    broadcast(tensor, 0, group_idx=task_grp)
                else:
                    broadcast(tensor, 0)
            except Exception as err:
                # `except Exception` (not a bare except) lets KeyboardInterrupt
                # and SystemExit propagate untouched.
                raise RuntimeError('{} {} does not have task_specific'.format(kind, name)) from err
|
|
|
def broadcast(x, root, group_idx=None):
    """Broadcast ``x`` either on the default group or through ``group_idx``.

    When ``group_idx`` is ``None`` or equal to ``0`` the call goes to
    ``dist.broadcast(x, root, None)``.  Otherwise ``group_idx.broadcast(x, 0)``
    is returned and ``root`` is not used.
    """
    if group_idx is None or group_idx == 0:
        return dist.broadcast(x, root, None)
    return group_idx.broadcast(x, 0)
|
|
|
|
|
def broadcast_params_multitask(model, task_grp, share_backbone_group, share_neck_group, share_decoder_group,
                               share_adapter_group, ignore):
    """Broadcast multi-task model parameters and buffers.

    When at least one group is given, each tensor is routed by the first
    truthy flag among ``task_specific``, ``adapter_specific``,
    ``backbone_specific``, ``neck_specific`` and ``decoder_specific`` to the
    matching group; unflagged tensors are broadcast from rank 0 on the default
    group.  When every group is ``None`` everything is broadcast from rank 0.

    Args:
        model: object exposing ``named_parameters()`` and ``named_buffers()``.
        task_grp, share_backbone_group, share_neck_group, share_decoder_group,
        share_adapter_group: communication groups or ``None``.
        ignore: optional container of names to skip.

    Raises:
        RuntimeError: if broadcasting a tensor fails; the original exception
            is chained as ``__cause__``.
    """

    def _send(tensor, group, message):
        # Same expression as before: caller's global rank minus its rank
        # inside `group`.
        start_rank = -group.rank() + dist.get_rank()
        printlog(message)
        broadcast(tensor, start_rank, group_idx=group)

    # The decoder group used to be tested with `is None`, which inverted its
    # contribution to this decision.
    groups = (task_grp, share_backbone_group, share_neck_group,
              share_decoder_group, share_adapter_group)
    if any(group is not None for group in groups):

        for name, p in model.named_parameters():
            if ignore and name in ignore:
                printlog('param {} ignored in broadcast'.format(name))
                continue

            # Flags are summed; with boolean flags this allows at most one.
            # adapter_specific is deliberately not part of the parameter check.
            assert p.task_specific + p.backbone_specific + p.neck_specific + p.decoder_specific <= 1.5, \
                "param could not be task_specific, backbone_specific, neck_specific, decoder_specific at same time"

            try:
                if p.task_specific:
                    _send(p, task_grp,
                          f'broadcasting task-specific param {name}\tgroup_idx={task_grp}')
                elif p.adapter_specific:
                    _send(p, share_adapter_group,
                          f'broadcasting adapter-specific param {name}\tgroup_idx={share_adapter_group}')
                elif p.backbone_specific:
                    _send(p, share_backbone_group,
                          f'broadcasting backbone-specific param {name}\tgroup_idx={share_backbone_group}')
                elif p.neck_specific:
                    _send(p, share_neck_group,
                          f'broadcasting neck-specific param {name}\tgroup_idx={share_neck_group}')
                elif p.decoder_specific:
                    _send(p, share_decoder_group,
                          f'broadcasting decoder-specific param {name}\tgroup_idx={share_decoder_group}')
                else:
                    printlog(f'broadcasting non-specific param {name}')
                    broadcast(p, 0)
            except Exception as err:
                # No debugger breakpoint here: it would hang non-interactive
                # distributed jobs.  The cause is chained instead.
                raise RuntimeError('param {} does not have task_specific or backbone_specific or neck_specific or decoder_specific'.format(name)) from err

        for name, b in model.named_buffers():
            if ignore and name in ignore:
                printlog('buffer {} ignored in broadcast'.format(name))
                continue
            assert b.task_specific + b.backbone_specific + b.neck_specific + b.decoder_specific + b.adapter_specific <= 1, \
                "buffer could not be task_specific, backbone_specific, neck_specific, decoder_specific at same time"
            try:
                if b.task_specific:
                    _send(b, task_grp, 'broadcasting task-specific buffer {}'.format(name))
                elif b.adapter_specific:
                    _send(b, share_adapter_group,
                          f'broadcasting adapter-specific param {name}\tgroup_idx={share_adapter_group}')
                elif b.backbone_specific:
                    _send(b, share_backbone_group, 'broadcasting backbone-specific buffer {}'.format(name))
                elif b.neck_specific:
                    _send(b, share_neck_group, 'broadcasting neck-specific buffer {}'.format(name))
                elif b.decoder_specific:
                    _send(b, share_decoder_group, 'broadcasting decoder-specific buffer {}'.format(name))
                else:
                    dist.broadcast(b, 0)
            except Exception as err:
                raise RuntimeError('buffer {} does not have task_specific'.format(name)) from err
    else:
        for name, p in model.named_parameters():
            if ignore and name in ignore:
                printlog('param {} ignored in broadcast'.format(name))
                continue
            try:
                dist.broadcast(p, 0)
            except Exception as err:
                raise RuntimeError('param {} does not have task_specific'.format(name)) from err

        for name, b in model.named_buffers():
            if ignore and name in ignore:
                printlog('buffer {} ignored in broadcast'.format(name))
                continue
            try:
                dist.broadcast(b, 0)
            except Exception as err:
                raise RuntimeError('buffer {} does not have task_specific'.format(name)) from err
|
|
|
def broadcast_params_unihcpv2(model, task_grp, share_backbone_group, share_decoder_group,
                              share_rgb_group, share_dense_labeling_group, share_sparse_labeling_group,
                              share_text_group, share_video_group, ignore,
                              share_modality_group=None):
    """Broadcast multi-task, multi-modality model parameters and buffers.

    When at least one group is given, each tensor is routed by the first
    truthy flag among ``task_specific``, ``modality_share``,
    ``backbone_specific``, ``decoder_specific``, ``rgb_specific``,
    ``dense_labeling_specific``, ``sparse_labeling_specific``,
    ``text_specific`` and ``video_specific`` to the matching group; unflagged
    tensors are broadcast from rank 0 on the default group.  When every group
    is ``None`` everything is broadcast from rank 0.

    Args:
        model: object exposing ``named_parameters()`` and ``named_buffers()``.
        task_grp and the ``share_*_group`` arguments: communication groups or
            ``None``.
        ignore: optional container of names to skip.

    Raises:
        RuntimeError: if broadcasting a tensor fails; the original exception
            is chained as ``__cause__``.
    """

    def _send(tensor, group, message):
        # Same expression as before: caller's global rank minus its rank
        # inside `group`.
        start_rank = -group.rank() + dist.get_rank()
        printlog(message)
        broadcast(tensor, start_rank, group_idx=group)

    # The decoder group used to be tested with `is None`, which inverted its
    # contribution to this decision.
    groups = (task_grp, share_backbone_group, share_decoder_group, share_rgb_group,
              share_dense_labeling_group, share_sparse_labeling_group,
              share_text_group, share_video_group, share_modality_group)
    if any(group is not None for group in groups):

        for name, p in model.named_parameters():
            if ignore and name in ignore:
                printlog('param {} ignored in broadcast'.format(name))
                continue

            # Flags are summed; with boolean flags this allows at most one.
            assert p.task_specific + p.modality_share + p.backbone_specific + p.decoder_specific + p.rgb_specific + \
                p.dense_labeling_specific + p.sparse_labeling_specific + p.text_specific + \
                p.video_specific <= 1.5, \
                "param could not be task_specific, backbone_specific, decoder_specific, modality_specific at same time"

            try:
                if p.task_specific:
                    _send(p, task_grp,
                          f'broadcasting task-specific param {name}\tgroup_idx={task_grp}')
                elif p.modality_share:
                    _send(p, share_modality_group,
                          f'broadcasting modality-share param {name}\tgroup_idx={share_modality_group}')
                elif p.backbone_specific:
                    _send(p, share_backbone_group,
                          f'broadcasting backbone-specific param {name}\tgroup_idx={share_backbone_group}')
                elif p.decoder_specific:
                    _send(p, share_decoder_group,
                          f'broadcasting decoder-specific param {name}\tgroup_idx={share_decoder_group}')
                elif p.rgb_specific:
                    _send(p, share_rgb_group,
                          f'broadcasting rgb-specific param {name}\tgroup_idx={share_rgb_group}')
                elif p.dense_labeling_specific:
                    _send(p, share_dense_labeling_group,
                          f'broadcasting dense_labeling-specific param {name}\tgroup_idx={share_dense_labeling_group}')
                elif p.sparse_labeling_specific:
                    _send(p, share_sparse_labeling_group,
                          f'broadcasting sparse_labeling-specific param {name}\tgroup_idx={share_sparse_labeling_group}')
                elif p.text_specific:
                    _send(p, share_text_group,
                          f'broadcasting text-specific param {name}\tgroup_idx={share_text_group}')
                elif p.video_specific:
                    # Previously broadcast over share_rgb_group by mistake.
                    _send(p, share_video_group,
                          f'broadcasting video-specific param {name}\tgroup_idx={share_video_group}')
                else:
                    printlog(f'broadcasting non-specific param {name}')
                    broadcast(p, 0)
            except Exception as err:
                # No debugger breakpoint here: it would hang non-interactive
                # distributed jobs.  The cause is chained instead.
                raise RuntimeError('param {} does not have task_specific or backbone_specific or neck_specific or decoder_specific'.format(name)) from err

        for name, b in model.named_buffers():
            if ignore and name in ignore:
                printlog('buffer {} ignored in broadcast'.format(name))
                continue
            assert b.task_specific + b.modality_share + b.backbone_specific + b.decoder_specific + b.rgb_specific + \
                b.dense_labeling_specific + b.sparse_labeling_specific + b.text_specific + \
                b.video_specific <= 1.5, \
                "buffer could not be task_specific, backbone_specific, decoder_specific, modality_specific at same time"
            try:
                if b.task_specific:
                    _send(b, task_grp, 'broadcasting task-specific buffer {}'.format(name))
                elif b.modality_share:
                    _send(b, share_modality_group, 'broadcasting modality-share buffer {}'.format(name))
                elif b.backbone_specific:
                    _send(b, share_backbone_group, 'broadcasting backbone-specific buffer {}'.format(name))
                elif b.decoder_specific:
                    _send(b, share_decoder_group, 'broadcasting decoder-specific buffer {}'.format(name))
                elif b.rgb_specific:
                    _send(b, share_rgb_group,
                          f'broadcasting rgb-specific param {name}\tgroup_idx={share_rgb_group}')
                elif b.dense_labeling_specific:
                    _send(b, share_dense_labeling_group,
                          f'broadcasting dense_labeling-specific param {name}\tgroup_idx={share_dense_labeling_group}')
                elif b.sparse_labeling_specific:
                    _send(b, share_sparse_labeling_group,
                          f'broadcasting sparse_labeling-specific param {name}\tgroup_idx={share_sparse_labeling_group}')
                elif b.text_specific:
                    _send(b, share_text_group,
                          f'broadcasting text-specific param {name}\tgroup_idx={share_text_group}')
                elif b.video_specific:
                    # Previously broadcast over share_rgb_group by mistake.
                    _send(b, share_video_group,
                          f'broadcasting video-specific param {name}\tgroup_idx={share_video_group}')
                else:
                    dist.broadcast(b, 0)
            except Exception as err:
                raise RuntimeError('buffer {} does not have task_specific'.format(name)) from err
    else:
        for name, p in model.named_parameters():
            if ignore and name in ignore:
                printlog('param {} ignored in broadcast'.format(name))
                continue
            try:
                dist.broadcast(p, 0)
            except Exception as err:
                raise RuntimeError('param {} does not have task_specific'.format(name)) from err

        for name, b in model.named_buffers():
            if ignore and name in ignore:
                printlog('buffer {} ignored in broadcast'.format(name))
                continue
            try:
                dist.broadcast(b, 0)
            except Exception as err:
                raise RuntimeError('buffer {} does not have task_specific'.format(name)) from err
|
|
|
|
|
def find_free_port():
    """Return a port number the OS reported as free at the time of the call.

    A socket is bound to port 0 so the OS picks the port, and is closed
    before returning.  The port is therefore not reserved; another process
    may take it before the caller binds to it.
    """
    with socket.socket() as s:
        s.bind(('', 0))
        return s.getsockname()[1]
|
|
|
def dist_init():
    """Initialise ``torch.distributed`` (NCCL) from the environment.

    Two launch modes are recognised:

    * ``RANK`` and ``WORLD_SIZE`` set: rank, world size and ``LOCAL_RANK``
      are read from the environment.
    * ``SLURM_PROCID`` set: rank and world size come from SLURM variables.
      Process 0 picks a free port and writes it to ``dist_url_<jobid>.txt``
      in the working directory; the other processes wait for that file and
      read the port from it.

    After the process group is up, one group per block of
    ``torch.cuda.device_count()`` consecutive ranks is created and the block
    containing this rank is stored in ``comm_._LOCAL_PROCESS_GROUP``.

    Returns:
        ``(rank, world_size)``, or ``None`` when neither launch mode applies.
    """
    import time

    # Only the SLURM branch creates a rendezvous file.  Initialising this here
    # avoids a NameError during cleanup when RANK/WORLD_SIZE and SLURM_PROCID
    # are both present in the environment.
    hostfile = None

    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        if int(os.environ["RANK"]) == 0:
            print('this task is not running on cluster!')
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        gpu = int(os.environ['LOCAL_RANK'])
        dist_url = 'env://'
        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
        addr = socket.gethostname()

    elif 'SLURM_PROCID' in os.environ:
        proc_id = int(os.environ['SLURM_PROCID'])
        if proc_id == 0:
            print('Init dist using slurm!')
            print("Job Id is {} on {} ".format(os.environ["SLURM_JOBID"], os.environ['SLURM_NODELIST']))
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        # First hostname of the allocation is used as the master address.
        addr = subprocess.getoutput(
            'scontrol show hostname {} | head -n1'.format(node_list))
        jobid = os.environ["SLURM_JOBID"]
        hostfile = "dist_url_" + jobid + ".txt"
        if proc_id == 0:
            tcp_port = str(find_free_port())
            print('write port {} to file: {} '.format(tcp_port, hostfile))
            with open(hostfile, "w") as f:
                f.write(tcp_port)
        else:
            print('read port from file: {}'.format(hostfile))
            while not os.path.exists(hostfile):
                time.sleep(1)
            # Extra delay before reading (presumably to let the writer finish).
            time.sleep(2)
            with open(hostfile, "r") as f:
                tcp_port = f.read()

        os.environ['MASTER_PORT'] = str(tcp_port)
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        dist_url = 'env://'
        world_size = ntasks
        rank = proc_id
        gpu = proc_id % num_gpus
    else:
        print('Not using distributed mode')
        return

    torch.cuda.set_device(gpu)
    dist_backend = 'nccl'

    print('rank: {} addr: {} port: {}'.format(rank, addr, os.environ['MASTER_PORT']))
    torch.distributed.init_process_group(backend=dist_backend, init_method=dist_url,
                                         world_size=world_size, rank=rank)
    torch.distributed.barrier()

    # After the barrier every rank has joined, so the port file can be removed.
    if hostfile is not None and rank == 0 and os.path.isfile(hostfile):
        os.remove(hostfile)

    assert comm_._LOCAL_PROCESS_GROUP is None
    num_gpus = torch.cuda.device_count()
    # NOTE(review): assumes ranks are laid out contiguously per machine and
    # that world_size is a multiple of the local GPU count - confirm.
    num_machines = world_size // num_gpus
    for i in range(num_machines):
        ranks_on_i = list(range(i * num_gpus, (i + 1) * num_gpus))
        print('new_group: {}'.format(ranks_on_i))
        # Every rank calls new_group for every block, in the same order.
        pg = torch.distributed.new_group(ranks_on_i)
        if rank in ranks_on_i:
            comm_._LOCAL_PROCESS_GROUP = pg

    return rank, world_size
|
|
|
class DistributedGivenIterationSampler(Sampler): |
|
def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1,
             shuffle_strategy=0, random_seed=0, imageNumPerClass=4, ret_save_path=None):
    """Build, or load from disk, the index list for this rank.

    Args:
        dataset: must expose ``task_name``; some strategies also read
            ``labels`` and ``len(dataset)``.
        total_iter: number of iterations to generate indices for.
        batch_size: samples per iteration on this rank.
        world_size: defaults to ``dist.get_world_size()``.
        rank: defaults to ``dist.get_rank()``.
        last_iter: last completed iteration (``-1`` for a fresh start).
        shuffle_strategy: 0, 1, 3, 4 and 6 use ``gen_new_list``; 2 uses
            ``gen_s2``; 5 uses ``gen_s5``.
        random_seed: stored for the generators.
        imageNumPerClass: stored for ``gen_s5``.
        ret_save_path: optional directory holding cached index files.

    Raises:
        ValueError: for an unknown ``shuffle_strategy``.
    """
    if world_size is None:
        world_size = dist.get_world_size()
    if rank is None:
        rank = dist.get_rank()
    assert rank < world_size
    sync_print('sampler: rank={}, world_size={}, random_seed={}'.format(rank, world_size, random_seed))
    self.dataset = dataset
    self.total_iter = total_iter
    self.batch_size = batch_size
    self.world_size = world_size
    self.rank = rank
    self.last_iter = last_iter
    self.shuffle_strategy = shuffle_strategy
    self.random_seed = random_seed
    self.imageNumPerClass = imageNumPerClass
    self.ret_save_path = ret_save_path
    self.task_name = self.dataset.task_name

    self.total_size = self.total_iter*self.batch_size

    # Guards __iter__ against being used more than once.
    self.call = 0

    if self.ret_save_path is not None:
        self.this_ret_path = os.path.join(self.ret_save_path, '_'.join([self.task_name, str(self.world_size), str(self.rank)]) + ".pth.tar")
        if os.path.exists(self.this_ret_path):
            # NOTE: torch.load unpickles the file; only point ret_save_path
            # at trusted locations.
            ret_file = torch.load(self.this_ret_path)

            # Reuse the cached indices only when they were produced for the
            # same task, world size and rank.
            if ret_file['task_name'] == self.task_name and ret_file['task_size'] == self.world_size and ret_file['task_rank'] == self.rank:
                printlog(" load task sampler from ------> {}".format(self.this_ret_path))
                self.indices = ret_file['ret_file']
                self.dataset.received_indices = True
                return
        else:
            printlog("sampler file ({}) is not existed, and will be generated now--->".format(self.this_ret_path))

    if self.shuffle_strategy in [0,1,3,4,6]:
        self.indices = self.gen_new_list()
        self.dataset.indices = self.indices
        self.dataset.received_indices = True
    elif self.shuffle_strategy == 2:
        self.indices = self.gen_s2()
    elif self.shuffle_strategy == 5:
        self.indices = self.gen_s5()
    else:
        # `Error` was never defined, so this used to surface as a NameError.
        raise ValueError("Invalid shuffle_strategy!")

    # NOTE(review): this tests the directory, not self.this_ret_path, so
    # nothing is saved once the directory exists - confirm intent.
    if self.ret_save_path is not None and not os.path.exists(self.ret_save_path):
        self.save()
|
|
|
def gen_s2(self):
    """Generate indices for shuffle strategy 2.

    The global NumPy RNG is seeded with ``self.rank``.  Each index is drawn
    by picking a class uniformly at random and then a sample uniformly from
    that class, for every position from ``(last_iter + 1) * batch_size`` up
    to ``total_size``.

    Returns:
        list of dataset indices.
    """
    np.random.seed(self.rank)

    labels = self.dataset.labels
    printlog('using shuffle strategy 2, initializing class map...')

    # label -> sample positions, in order of first appearance.
    class2id = collections.OrderedDict()
    for sample_idx, label in enumerate(labels):
        class2id.setdefault(label, []).append(sample_idx)

    keys = list(class2id)
    np.random.shuffle(keys)
    num_class = len(keys)
    printlog('class map done.')

    indices = []
    first = (self.last_iter + 1) * self.batch_size
    for _ in range(first, self.total_size):
        # Two draws per sample, in this order: class first, then member.
        members = class2id[keys[np.random.randint(0, num_class)]]
        indices.append(members[np.random.randint(0, len(members))])

    return indices
|
|
|
def gen_s5(self):
    """Generate indices for shuffle strategy 5.

    The global NumPy RNG is seeded with ``self.rank``.  A class is drawn at
    every position that is a multiple of ``imageNumPerClass``, and each
    position takes a random sample from the most recently drawn class.
    Positions run from ``(last_iter + 1) * batch_size`` up to ``total_size``.

    Returns:
        list of dataset indices.
    """
    np.random.seed(self.rank)

    labels = self.dataset.labels
    printlog('using shuffle strategy 5, initializing class map...')

    # label -> sample positions, in order of first appearance.
    class2id = collections.OrderedDict()
    for sample_idx, label in enumerate(labels):
        class2id.setdefault(label, []).append(sample_idx)

    keys = list(class2id)
    np.random.shuffle(keys)
    num_class = len(keys)
    printlog('class map done.')
    printlog('{} class with {} samples in a batch!'.format(self.batch_size//self.imageNumPerClass, self.imageNumPerClass))

    indices = []
    class_id = None
    for i in range((self.last_iter+1)*self.batch_size, self.total_size):
        # Also draw on the first pass: when the start position is not a
        # multiple of imageNumPerClass, class_id would otherwise be unbound.
        if class_id is None or i % self.imageNumPerClass == 0:
            class_id = np.random.randint(0, num_class)
        members = class2id[keys[class_id]]
        inner_id = np.random.randint(0, len(members))
        indices.append(members[inner_id])

    return indices
|
|
|
def __iter__(self): |
|
if self.call == 0: |
|
self.call = 1 |
|
return iter(self.indices) |
|
|
|
|
|
|
|
|
|
else: |
|
raise RuntimeError("this sampler is not designed to be called more than once!!") |
|
|
|
def gen_new_list(self):
    """Generate this rank's sample indices for ``self.shuffle_strategy``.

    Strategies handled here:

    * 0 -- seeded with the rank; the dataset indices are repeated up to
      ``total_size`` and every dataset-sized stretch is shuffled on its own.
    * 1 -- seeded with ``random_seed``; one shuffled sequence of
      ``total_size * world_size`` indices is cut into per-rank slices.
    * 3 -- fixed seed 0; classes smaller than the mean class size are first
      padded up to it by resampling with replacement, then handled like 1.
    * 6 -- every class is padded to a multiple of ``imageNumPerClass`` and
      cut into groups of that size; the groups are shuffled and the
      resulting sequence is repeated as often as needed.
    * 7 -- like 6, but instead of repeating the same sequence, further
      passes with freshly drawn padding are added to the pool of groups,
      which is then reshuffled as a whole.
    * 8 -- batches are assembled from ``batch_size // imageNumPerClass``
      classes that still have unused groups; the batches are shuffled.

    Returns:
        numpy.ndarray: the rank's ``total_size`` indices with the first
        ``(last_iter + 1) * batch_size`` entries dropped.

    Raises:
        RuntimeError: if the strategy is none of the above.
        ValueError: for strategy 8 when ``batch_size < imageNumPerClass``.
    """
    strategy = self.shuffle_strategy
    all_size = self.total_size * self.world_size

    def tile_to(arr, size):
        # Truncate, then repeat, so that exactly ``size`` entries remain.
        arr = arr[:size]
        num_repeat = (size - 1) // arr.shape[0] + 1
        return np.tile(arr, num_repeat)[:size]

    def rank_share(arr):
        # Contiguous slice of the global sequence owned by this rank.
        beg = self.total_size * self.rank
        return arr[beg:beg + self.total_size]

    def group_by_label(labels):
        # label -> sample positions, classes kept in first-seen order.
        class2id = collections.OrderedDict()
        for i, l in enumerate(labels):
            class2id.setdefault(l, []).append(i)
        return class2id

    def pad_to_multiple(idxs):
        # Resample (with replacement) until len is a multiple of K.
        # Returns the very same list object when no padding is needed.
        k = self.imageNumPerClass
        if len(idxs) < k:
            idxs = idxs + list(np.random.choice(idxs, size=k - len(idxs), replace=True))
        elif len(idxs) % k != 0:
            add_num = int(len(idxs) // k + 1) * k - len(idxs)
            idxs = idxs + list(np.random.choice(idxs, size=add_num, replace=True))
        assert len(idxs) % k == 0
        return idxs

    def split_groups(idxs):
        k = self.imageNumPerClass
        return [idxs[i:i + k] for i in range(0, len(idxs), k)]

    if strategy == 0:
        np.random.seed(self.rank)
        dataset_len = len(self.dataset)
        indices = tile_to(np.arange(dataset_len), self.total_size)
        # Shuffle each pass over the dataset separately (in place, via views).
        for beg in range(0, self.total_size, dataset_len):
            end = min(beg + dataset_len, self.total_size)
            np.random.shuffle(indices[beg:end])

    elif strategy == 1:
        np.random.seed(self.random_seed)
        indices = tile_to(np.arange(len(self.dataset)), all_size)
        np.random.shuffle(indices)
        indices = rank_share(indices)

    elif strategy == 3:
        np.random.seed(0)
        class2id = group_by_label(self.dataset.labels)
        class_count = [len(x) for x in class2id.values()]
        mean_num = int(np.mean(class_count))

        indices = []
        for v in class2id.values():
            if len(v) < mean_num:
                # Oversample small classes up to the mean class size.
                indices.extend(np.random.choice(v, mean_num - len(v)))
            indices.extend(v)
        indices = np.array(indices)[:all_size]
        printlog('using strategy 3, mean_num: {}, origin_len: {}, balanced_len: {}'.format(mean_num, len(self.dataset), len(indices)))

        indices = tile_to(indices, all_size)
        np.random.shuffle(indices)
        indices = rank_share(indices)

    elif strategy == 6:
        np.random.seed(self.random_seed)
        labels = self.dataset.labels
        printlog('using shuffle strategy 6, initializing class map...')
        class2id = group_by_label(labels)

        mini_indices = []
        for idxs in class2id.values():
            mini_indices.extend(split_groups(pad_to_multiple(idxs)))

        np.random.shuffle(mini_indices)
        indices = np.array(mini_indices).reshape(-1)
        indices = rank_share(tile_to(indices, all_size))

    elif strategy == 7:
        np.random.seed(self.random_seed)
        labels = self.dataset.labels
        printlog('using shuffle strategy 7, initializing class map...')
        class2id = group_by_label(labels)

        mini_indices = []
        for idxs in class2id.values():
            mini_indices.extend(split_groups(pad_to_multiple(idxs)))

        np.random.shuffle(mini_indices)
        indices = np.array(mini_indices).reshape(-1)[:all_size]
        num_repeat = (all_size - 1) // indices.shape[0] + 1

        # One pass already exists, so only ``num_repeat - 1`` more are needed.
        # (The original wrote ``range(num_repeat)-1``, which raises TypeError.)
        for _ in range(num_repeat - 1):
            for idxs in class2id.values():
                mini_indices.extend(split_groups(pad_to_multiple(idxs)))

        # NOTE(review): the source's indentation is lost; the reshuffle is
        # assumed to happen once, after all passes were added -- confirm.
        np.random.shuffle(mini_indices)
        indices = np.array(mini_indices).reshape(-1)[:all_size]
        indices = rank_share(indices)

    elif strategy == 8:
        np.random.seed(self.random_seed)
        labels = self.dataset.labels
        print('using shuffle strategy 8, initializing class map...')
        class2id = group_by_label(labels)

        pids = set()
        cls_pids_idxs = collections.defaultdict(list)
        for pid, idxs in class2id.items():
            idxs = pad_to_multiple(idxs)
            np.random.shuffle(idxs)
            # [pointer to the next unused group, list of groups]
            cls_pids_idxs[pid] = [0, split_groups(idxs)]
            pids.add(pid)

        classnum_per_batch = self.batch_size // self.imageNumPerClass
        if classnum_per_batch <= 0:
            # With zero classes per batch the loop below would never end.
            raise ValueError('batch_size must be at least imageNumPerClass for shuffle strategy 8')

        indices = []
        while len(pids) >= classnum_per_batch:
            sub_pids = []
            for i, pid in enumerate(pids):
                if i >= classnum_per_batch:
                    break
                sub_pids.append(pid)

            batch = []
            for pid in sub_pids:
                ptr, groups = cls_pids_idxs[pid]
                batch.extend(groups[ptr])
                if ptr + 1 >= len(groups):
                    pids.remove(pid)  # class exhausted
                else:
                    cls_pids_idxs[pid][0] += 1
            indices.append(batch)

        np.random.shuffle(indices)
        indices = np.array(indices).reshape(-1)
        indices = rank_share(tile_to(indices, all_size))

    else:
        raise RuntimeError('unknow shuffle strategy')

    assert len(indices) == self.total_size

    # Drop what was already consumed before ``last_iter`` (resume support).
    return indices[(self.last_iter + 1) * self.batch_size:]
|
|
|
def __len__(self):
    """Number of indices left after skipping the first ``last_iter + 1`` batches."""
    already_consumed = (self.last_iter + 1) * self.batch_size
    return self.total_size - already_consumed
|
|
|
def save(self):
    """Write the generated indices plus task metadata to ``self.this_ret_path``."""
    payload = {
        'task_name': self.task_name,
        'task_size': self.world_size,
        'task_rank': self.rank,
        'ret_file': self.indices,
    }
    torch.save(payload, self.this_ret_path)
    printlog("save sampler file ------> {}".format(self.this_ret_path))
|
|
|
|
|
class RandomIdentitySampler(Sampler):
    """Identity-balanced sampler for distributed training.

    Randomly picks N identities and, for each of them, K samples, so that a
    batch holds N*K samples. The global sequence is built from as many
    passes over the identities as needed and every rank takes every
    ``world_size``-th batch-sized block of it.

    Args:
        dataset: object exposing ``labels`` (one identity label per sample).
        total_iter (int): number of batches this rank has to serve.
        batch_size (int): number of samples in a batch (N*K).
        world_size (int, optional): number of ranks; defaults to
            ``dist.get_world_size()``.
        rank (int, optional): this rank; defaults to ``dist.get_rank()``.
        last_iter (int): only used by ``__len__``.
        shuffle_strategy, ret_save_path: accepted but unused by this class.
        random_seed (int): seed applied at the start of every ``__iter__``.
        imageNumPerClass (int): K, samples per identity in a batch.
    """

    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1,
                 shuffle_strategy=0, random_seed=0, imageNumPerClass=4, ret_save_path=None):
        self.batch_size = batch_size
        self.world_size = world_size if world_size is not None else dist.get_world_size()
        self.num_instances = imageNumPerClass
        self.num_pids_per_batch = self.batch_size // self.num_instances
        self.random_seed = random_seed
        self.total_iter = total_iter
        self.total_size = self.total_iter * self.batch_size
        self.last_iter = last_iter
        self.dataset = dataset

        labels = self.dataset.labels
        printlog('using RandomIdentitySampler, initializing class map...')
        # identity label -> positions of its samples
        self.index_dic = collections.defaultdict(list)
        for i, l in enumerate(labels):
            self.index_dic[l].append(i)

        self.pids = np.array(list(self.index_dic.keys()))
        self.rank = rank if rank is not None else dist.get_rank()

    def __iter__(self):
        # Re-seeding makes every iteration reproduce the same sequence.
        np.random.seed(self.random_seed)
        self._seed = int(self.random_seed)
        final_idxs = self.sample_list()
        length = int(math.ceil(len(final_idxs) * 1.0 / self.world_size))

        final_idxs = self.__fetch_current_node_idxs(final_idxs, length)
        return iter(final_idxs)

    def __fetch_current_node_idxs(self, final_idxs, length):
        """Pick this rank's batch-sized blocks out of the global sequence."""
        total_num = len(final_idxs)
        block_num = length // self.batch_size
        index_target = []
        for i in range(0, block_num * self.world_size, self.world_size):
            start = self.batch_size * (self.rank + i)
            stop = min(start + self.batch_size, total_num)
            index_target.extend(range(start, stop))
        # An explicit integer dtype keeps an empty selection indexable.
        index_target_npy = np.array(index_target, dtype=np.int64)
        return list(np.array(final_idxs)[index_target_npy])

    def batch_sample_list(self):
        """One pass over the identities, as a flat list of N*K-sized batches."""
        avai_pids = self.pids.tolist()
        batch_idxs_dict = {}
        batch_indices = []
        while len(avai_pids) >= self.num_pids_per_batch:
            selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist()
            for pid in selected_pids:
                if pid not in batch_idxs_dict or len(batch_idxs_dict[pid]) < self.num_instances:
                    idxs = list(self.index_dic[pid])
                    if len(idxs) < self.num_instances:
                        # Too few samples: draw K of them with replacement.
                        idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist()
                    np.random.shuffle(idxs)
                    batch_idxs_dict[pid] = idxs

                avai_idxs = batch_idxs_dict[pid]
                batch_indices.extend(avai_idxs[:self.num_instances])
                del avai_idxs[:self.num_instances]

                # An identity that cannot fill another K-group is retired.
                if len(avai_idxs) < self.num_instances:
                    avai_pids.remove(pid)

        return batch_indices

    def sample_list(self):
        """Global sequence of ``total_size * world_size`` indices.

        Raises:
            ValueError: if a pass yields nothing (fewer identities than
                ``num_pids_per_batch``), which would otherwise loop forever.
        """
        all_size = self.total_size * self.world_size

        all_indices = []
        while len(all_indices) < all_size:
            one_pass = self.batch_sample_list()
            if not one_pass:
                raise ValueError(
                    'cannot build a batch: {} identities available, {} needed per batch'.format(
                        len(self.pids), self.num_pids_per_batch))
            all_indices.extend(one_pass)

        return all_indices[:all_size]

    def __len__(self):
        # NOTE(review): __iter__ does not skip the first last_iter + 1
        # batches, so this only matches the iterator when last_iter == -1.
        return self.total_size - (self.last_iter + 1) * self.batch_size
|
|
|
|
|
class RandomIdentityBatchSampler(Sampler):
    """Identity-balanced sampler that splits every pass across the ranks.

    Randomly picks N identities and, for each of them, K samples, so that a
    batch holds N*K samples. Unlike a single global sequence, each pass over
    the identities is kept separate and every rank takes its batch-sized
    blocks from each pass in turn.

    Args:
        dataset: object exposing ``labels`` (one identity label per sample).
        total_iter (int): number of batches this rank has to serve.
        batch_size (int): number of samples in a batch (N*K).
        world_size (int, optional): number of ranks; defaults to
            ``dist.get_world_size()``.
        rank (int, optional): this rank; defaults to ``dist.get_rank()``.
        last_iter (int): only used by ``__len__``.
        shuffle_strategy, ret_save_path: accepted but unused by this class.
        random_seed (int): seed applied at the start of every ``__iter__``.
        imageNumPerClass (int): K, samples per identity in a batch.
    """

    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1,
                 shuffle_strategy=0, random_seed=0, imageNumPerClass=4, ret_save_path=None):
        self.batch_size = batch_size
        self.world_size = world_size if world_size is not None else dist.get_world_size()
        self.num_instances = imageNumPerClass
        self.num_pids_per_batch = self.batch_size // self.num_instances
        self.random_seed = random_seed
        self.total_iter = total_iter
        self.total_size = self.total_iter * self.batch_size
        self.last_iter = last_iter
        self.dataset = dataset

        labels = self.dataset.labels
        printlog('using RandomIdentityBatchSampler, initializing class map...')
        # identity label -> positions of its samples
        self.index_dic = collections.defaultdict(list)
        for i, l in enumerate(labels):
            self.index_dic[l].append(i)

        self.pids = np.array(list(self.index_dic.keys()))
        self.rank = rank if rank is not None else dist.get_rank()

    def __iter__(self):
        # Re-seeding makes every iteration reproduce the same sequence.
        np.random.seed(self.random_seed)
        self._seed = int(self.random_seed)
        final_idxs_batches = self.sample_list()
        final_idxs = self.__fetch_current_node_idxs(final_idxs_batches)
        return iter(final_idxs)

    def __fetch_current_node_idxs(self, final_idxs_batches):
        """Collect this rank's blocks from every pass, capped at ``total_size``."""
        res = []
        for final_idxs in final_idxs_batches:
            total_num = len(final_idxs)
            length = int(math.ceil(total_num * 1.0 / self.world_size))
            block_num = length // self.batch_size
            index_target = []
            for i in range(0, block_num * self.world_size, self.world_size):
                start = self.batch_size * (self.rank + i)
                stop = min(start + self.batch_size, total_num)
                index_target.extend(range(start, stop))
            # An explicit integer dtype keeps an empty selection indexable.
            index_target_npy = np.array(index_target, dtype=np.int64)
            res.extend(list(np.array(final_idxs)[index_target_npy]))
        return res[:self.total_size]

    def batch_sample_list(self):
        """One pass over the identities, as a flat list of N*K-sized batches."""
        avai_pids = self.pids.tolist()
        batch_idxs_dict = {}
        batch_indices = []
        while len(avai_pids) >= self.num_pids_per_batch:
            selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist()
            for pid in selected_pids:
                if pid not in batch_idxs_dict or len(batch_idxs_dict[pid]) < self.num_instances:
                    idxs = list(self.index_dic[pid])
                    if len(idxs) < self.num_instances:
                        # Too few samples: draw K of them with replacement.
                        idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist()
                    np.random.shuffle(idxs)
                    batch_idxs_dict[pid] = idxs

                avai_idxs = batch_idxs_dict[pid]
                batch_indices.extend(avai_idxs[:self.num_instances])
                del avai_idxs[:self.num_instances]

                # An identity that cannot fill another K-group is retired.
                if len(avai_idxs) < self.num_instances:
                    avai_pids.remove(pid)

        return batch_indices

    def sample_list(self):
        """Return a list of passes whose combined length exceeds ``total_size * world_size``.

        Each pass is sampled exactly once and that same pass is both counted
        and returned. (Previously a second, independent pass was sampled per
        loop and the counted one thrown away, so the seeded sequence differs
        from older versions.)

        Raises:
            ValueError: if a pass yields nothing (fewer identities than
                ``num_pids_per_batch``), which would otherwise loop forever.
        """
        all_size = self.total_size * self.world_size

        sampled = 0
        all_indices_batch = []
        while sampled <= all_size:
            one_pass = self.batch_sample_list()
            if not one_pass:
                raise ValueError(
                    'cannot build a batch: {} identities available, {} needed per batch'.format(
                        len(self.pids), self.num_pids_per_batch))
            sampled += len(one_pass)
            all_indices_batch.append(one_pass)

        return all_indices_batch

    def __len__(self):
        # NOTE(review): __iter__ does not skip the first last_iter + 1
        # batches, so this only matches the iterator when last_iter == -1.
        return self.total_size - (self.last_iter + 1) * self.batch_size
|
|
|
|
|
class RandomIdentityEpochBatchSampler(Sampler):
    """Identity-balanced sampler that serves one pass (epoch) per iteration.

    Randomly picks N identities and, for each of them, K samples, so that a
    batch holds N*K samples. Every ``__iter__`` samples a single pass over
    the identities and hands this rank every ``world_size``-th batch-sized
    block of it.

    Args:
        dataset: object exposing ``labels`` (one identity label per sample).
        total_iter (int): stored, together with the derived ``total_size``.
        batch_size (int): number of samples in a batch (N*K).
        world_size (int, optional): number of ranks; defaults to
            ``dist.get_world_size()``.
        rank (int, optional): this rank; defaults to ``dist.get_rank()``.
        last_iter, shuffle_strategy, ret_save_path: accepted but not used
            for sampling by this class.
        random_seed (int): base seed; see ``set_epoch``.
        imageNumPerClass (int): K, samples per identity in a batch.
    """

    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1,
                 shuffle_strategy=0, random_seed=0, imageNumPerClass=4, ret_save_path=None):
        self.batch_size = batch_size
        self.world_size = world_size if world_size is not None else dist.get_world_size()
        self.num_instances = imageNumPerClass
        self.num_pids_per_batch = self.batch_size // self.num_instances
        self.random_seed = random_seed
        self._base_seed = random_seed  # seed before any epoch offset
        self.total_iter = total_iter
        self.total_size = self.total_iter * self.batch_size
        self.last_iter = last_iter
        self.dataset = dataset

        labels = self.dataset.labels
        printlog('using RandomIdentityEpochBatchSampler, initializing class map...')
        # identity label -> positions of its samples
        self.index_dic = collections.defaultdict(list)
        for i, l in enumerate(labels):
            self.index_dic[l].append(i)

        self.pids = np.array(list(self.index_dic.keys()))
        self.rank = rank if rank is not None else dist.get_rank()

        # Length estimate used until the first __iter__ replaces it: samples
        # per identity rounded down to a multiple of K (at least K), summed
        # and divided among the ranks.
        self.length = 0
        for pid in self.pids:
            num = len(self.index_dic[pid])
            if num < self.num_instances:
                num = self.num_instances
            self.length += num - num % self.num_instances
        self.length //= self.world_size

    def __iter__(self):
        np.random.seed(self.random_seed)
        self._seed = int(self.random_seed)
        final_idxs_batches = self.sample_list()
        length = int(math.ceil(len(final_idxs_batches) * 1.0 / self.world_size))
        final_idxs = self.__fetch_current_node_idxs(final_idxs_batches, length)
        # From now on __len__ reports what was actually produced.
        self.length = len(final_idxs)
        return iter(final_idxs)

    def set_epoch(self, epoch):
        """Use ``base seed + epoch`` as the seed of the next ``__iter__``.

        The offset is applied to the seed the sampler started with, not to
        the current one, so the result does not depend on how often or in
        which order this method was called before.
        """
        base_seed = getattr(self, '_base_seed', None)
        if base_seed is None:
            # Instance created without __init__: remember the current seed.
            base_seed = self._base_seed = self.random_seed
        self.random_seed = base_seed + epoch

    def __fetch_current_node_idxs(self, final_idxs, length):
        """Pick this rank's batch-sized blocks out of the sampled pass."""
        total_num = len(final_idxs)
        block_num = length // self.batch_size
        index_target = []
        for i in range(0, block_num * self.world_size, self.world_size):
            start = self.batch_size * (self.rank + i)
            stop = min(start + self.batch_size, total_num)
            index_target.extend(range(start, stop))
        # An explicit integer dtype keeps an empty selection indexable.
        index_target_npy = np.array(index_target, dtype=np.int64)
        return list(np.array(final_idxs)[index_target_npy])

    def sample_list(self):
        """One pass over the identities, as a flat list of N*K-sized batches."""
        avai_pids = self.pids.tolist()
        batch_idxs_dict = {}
        batch_indices = []
        while len(avai_pids) >= self.num_pids_per_batch:
            selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist()
            for pid in selected_pids:
                if pid not in batch_idxs_dict or len(batch_idxs_dict[pid]) < self.num_instances:
                    idxs = list(self.index_dic[pid])
                    if len(idxs) < self.num_instances:
                        # Too few samples: draw K of them with replacement.
                        idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist()
                    np.random.shuffle(idxs)
                    batch_idxs_dict[pid] = idxs

                avai_idxs = batch_idxs_dict[pid]
                batch_indices.extend(avai_idxs[:self.num_instances])
                del avai_idxs[:self.num_instances]

                # An identity that cannot fill another K-group is retired.
                if len(avai_idxs) < self.num_instances:
                    avai_pids.remove(pid)

        return batch_indices

    def __len__(self):
        return self.length
|
|
|
|
|
class DistributedGivenSizeSampler(Sampler):
    """Shuffled sampler that serves a fixed number of samples per rank.

    The dataset permutation is extended (by re-using its beginning) or cut
    so that ``total_size = ceil(size / world_size) * world_size`` indices
    exist, where ``size`` is ``given_size`` or the dataset length.

    Args:
        dataset: any object with ``__len__``.
        given_size (int, optional): number of samples over all ranks;
            defaults to ``len(dataset)``.
        dup_shuffle (bool): if False, every ``__iter__`` reshuffles with the
            current epoch as seed and returns this rank's slice. If True,
            one list is generated up front and successive ``__iter__`` calls
            return consecutive ``num_samples``-sized slices of it (the same
            slice on every rank), regenerating it after ``world_size`` calls.
        world_size (int, optional): defaults to ``dist.get_world_size()``.
        rank (int, optional): defaults to ``dist.get_rank()``.
    """

    def __init__(self, dataset, given_size=None, dup_shuffle=False, world_size=None, rank=None):
        if world_size is None:
            world_size = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        assert rank < world_size
        self.dataset = dataset
        self.dup_shuffle = dup_shuffle
        self.world_size = world_size
        self.rank = rank
        self.epoch = 0

        wanted = len(self.dataset) if given_size is None else given_size
        self.num_samples = int(math.ceil(wanted * 1.0 / self.world_size))
        self.total_size = self.num_samples * self.world_size

        if self.dup_shuffle:
            self.offset = 0
            self.g = torch.Generator()
            self.indices = self.gen_new_list(self.g)

    def __iter__(self):
        if self.dup_shuffle:
            if self.offset == self.world_size:
                # Every slice of the current list was served: draw a new one.
                self.indices = self.gen_new_list(self.g)
                self.offset = 0
            beg = self.offset * self.num_samples
            indices = self.indices[beg:beg + self.num_samples]
            self.offset += 1
        else:
            # Same seed on all ranks -> same permutation; each takes its slice.
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = self.gen_new_list(g)

            offset = self.num_samples * self.rank
            indices = indices[offset:offset + self.num_samples]

        assert len(indices) == self.num_samples
        return iter(indices)

    def gen_new_list(self, g):
        """Return ``total_size`` indices drawn from a fresh permutation.

        Args:
            g (torch.Generator): random source for the permutation.

        Returns:
            list[int]: plain Python ints.

        Raises:
            ValueError: if samples are requested from an empty dataset
                (the fill loop could never finish).
        """
        origin_indices = torch.randperm(len(self.dataset), generator=g).tolist()
        if self.total_size > 0 and not origin_indices:
            raise ValueError('cannot draw {} samples from an empty dataset'.format(self.total_size))

        indices = origin_indices[:self.total_size]
        # Wrap around to the start of the permutation until enough exist.
        while len(indices) < self.total_size:
            indices += origin_indices[:self.total_size - len(indices)]
        assert len(indices) == self.total_size
        return indices

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the seed used by the next non-``dup_shuffle`` ``__iter__``."""
        self.epoch = epoch
|
|
|
class DistributedGroupSampler(Sampler):
    """Distributed sampler whose batches only contain samples of one group.

    The dataset's ``flag`` array (one group id per sample) is repeated to
    ``total_iter * batch_size`` entries. Every group is padded to a multiple
    of ``batch_size * num_replicas``, the resulting batch-sized blocks are
    shuffled as units, and each replica takes a contiguous share.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: dataset with a ``flag`` attribute and ``__len__``.
        batch_size (int): samples per batch on one replica.
        num_replicas (optional): number of participating processes; falls
            back to ``world_size`` and then to ``dist.get_world_size()``.
        world_size (optional): alternative spelling of ``num_replicas``.
        rank (optional): rank of this process; defaults to
            ``dist.get_rank()``.
        total_iter (int): number of batches the flag array is sized for.
        random_seed (int, optional): seed of the shuffle; has to be the same
            on all processes. ``None`` is treated as 0.
        last_iter (int): accepted but unused.
    """

    def __init__(self,
                 dataset,
                 batch_size=1,
                 num_replicas=None,
                 world_size=None,
                 rank=None,
                 total_iter=-1,
                 random_seed=0,
                 last_iter=-1):

        if num_replicas is None:
            num_replicas = world_size if world_size is not None else dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        self.dataset = dataset
        self.samples_per_gpu = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = random_seed if random_seed is not None else 0
        self.total_iter = total_iter
        self.batch_size = batch_size

        assert hasattr(self.dataset, 'flag')
        dataset_flag = self.dataset.flag
        # ceil(wanted / len(dataset)) copies are required to reach ``wanted``
        # entries; the previous count could stop one copy short.
        wanted = total_iter * batch_size
        repeats = max((wanted - 1) // len(self.dataset) + 1, 0)
        self.flag = np.tile(dataset_flag, repeats)[:wanted]

        self.group_sizes = np.bincount(self.flag)

        # Per replica: every group rounded up to whole batches.
        self.num_samples = 0
        for size in self.group_sizes:
            self.num_samples += int(
                math.ceil(size * 1.0 / self.samples_per_gpu /
                          self.num_replicas)) * self.samples_per_gpu

        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # Deterministic shuffle, identical on all replicas.
        g = torch.Generator()
        g.manual_seed(self.seed)

        indices = []
        for i, size in enumerate(self.group_sizes):
            if size > 0:
                indice = np.where(self.flag == i)[0]
                assert len(indice) == size
                indice = indice[torch.randperm(int(size), generator=g).numpy()].tolist()
                # Entries missing to fill whole batches on every replica.
                extra = int(
                    math.ceil(
                        size * 1.0 / self.samples_per_gpu / self.num_replicas)
                ) * self.samples_per_gpu * self.num_replicas - len(indice)

                tmp = indice.copy()
                for _ in range(extra // size):
                    indice.extend(tmp)
                indice.extend(tmp[:extra % size])
                indices.extend(indice)

        assert len(indices) == self.total_size

        # Shuffle whole batches so that each keeps a single group.
        batch_order = torch.randperm(
            len(indices) // self.samples_per_gpu, generator=g).tolist()
        indices = [
            indices[j]
            for b in batch_order
            for j in range(b * self.samples_per_gpu, (b + 1) * self.samples_per_gpu)
        ]

        # Positions in the repeated flag array -> real dataset indices.
        indices = [i % len(self.dataset.flag) for i in indices]

        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples
|
|
|
class BatchSampler(Sampler):
    r"""Wraps a group-aware sampler and checks its batches for group purity.

    The indices of ``sampler`` are consumed in runs of ``batch_size``. Each
    full run must only contain indices with the same ``sampler.flag`` value.
    Iterating yields the individual indices (flattened), not lists of them;
    ``__len__`` nevertheless counts batches.

    Args:
        sampler (Sampler or Iterable): Base sampler; must expose ``flag``.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the trailing indices that do not fill
            a whole batch are dropped.

    Raises:
        ValueError: for an invalid ``batch_size`` / ``drop_last``, and during
            iteration when a full batch mixes different flag values.
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.flag = self.sampler.flag

    def __iter__(self):
        ret = []
        batch = []

        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                batch = tuple(batch)
                for i in batch:
                    if self.flag[i] != self.flag[batch[0]]:
                        # Was an interactive IPython breakpoint, which cannot
                        # work in unattended training jobs.
                        raise ValueError(
                            "batch mixes samples of different groups: index {} has flag {}, "
                            "index {} has flag {}".format(
                                i, self.flag[i], batch[0], self.flag[batch[0]]))
                ret.extend(batch)
                batch = []
        if len(batch) > 0 and not self.drop_last:
            ret.extend(batch)

        return iter(ret)

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
|
|
|
class DistributedSequentialSampler(Sampler):
    """Gives every rank one contiguous, unshuffled stretch of the dataset.

    The dataset is cut into stretches of ``ceil(len(dataset) / world_size)``
    indices. The last ranks may get a shorter or an empty stretch.

    Args:
        dataset: any object with ``__len__``.
        world_size (int, optional): defaults to ``dist.get_world_size()``.
        rank (int, optional): defaults to ``dist.get_rank()``.
    """

    def __init__(self, dataset, world_size=None, rank=None):
        if world_size is None:
            world_size = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        self.dataset = dataset
        self.world_size = world_size
        self.rank = rank
        assert len(self.dataset) >= self.world_size, f'{len(self.dataset)} vs {self.world_size}'
        sub_num = int(math.ceil(len(self.dataset) * 1.0 / self.world_size))
        # Clamp the start as well: with ceil-sized stretches a late rank can
        # begin past the end, which used to give a negative length.
        self.beg = min(sub_num * self.rank, len(self.dataset))
        self.end = min(self.beg + sub_num, len(self.dataset))

    def __iter__(self):
        return iter(range(self.beg, self.end))

    def __len__(self):
        return self.end - self.beg
|
|
|
def bcast_value(value):
    """Broadcast a scalar from rank 0 and return it as a Python float.

    Args:
        value: number to send; only rank 0's value matters.

    Returns:
        float: the value held by rank 0.
    """
    # float64 carries integers up to 2**53 exactly (float32 stops at 2**24).
    # NOTE(review): this is a CPU tensor -- confirm the backend accepts it.
    v = torch.tensor([value], dtype=torch.float64)
    # torch.distributed names the source rank ``src``; ``root`` is rejected.
    dist.broadcast(v, src=0)
    return v.item()
|
|
|
def gather_tensors(input_array):
    """Gather differently sized numpy arrays from all ranks on rank 0.

    Every rank pads its flattened array to the largest element count, the
    padded buffers are gathered, and rank 0 cuts them back to their shapes.

    Args:
        input_array (numpy.ndarray): this rank's array.
            NOTE(review): all ranks are assumed to pass arrays with the same
            number of dimensions, since the shapes are gathered into
            equally sized tensors -- confirm.

    Returns:
        list[numpy.ndarray] on rank 0 (one float32 array per rank, in rank
        order), ``None`` on every other rank.
    """
    world_size = dist.get_world_size()
    is_root = dist.get_rank() == 0

    myshape = input_array.shape
    mycount = input_array.size

    # Shapes travel as int64 so that large dimensions stay exact.
    shape_tensor = torch.tensor(myshape, dtype=torch.long)
    # torch.distributed.gather(tensor, gather_list, dst): the receive list
    # is only supplied on the destination rank.
    all_shape = [torch.zeros_like(shape_tensor) for _ in range(world_size)] if is_root else None
    dist.gather(shape_tensor, gather_list=all_shape, dst=0)

    if is_root:
        all_shape = [x.tolist() for x in all_shape]
        all_count = [int(np.prod(shape)) for shape in all_shape]
        max_count = max(all_count)
    else:
        max_count = 0
    # Tell everybody how large the padded buffers have to be.
    count_tensor = torch.tensor([max_count], dtype=torch.long)
    dist.broadcast(count_tensor, src=0)
    max_count = int(count_tensor.item())

    output_tensors = [torch.zeros(max_count) for _ in range(world_size)] if is_root else None
    padded_input_array = np.zeros(max_count)
    padded_input_array[:mycount] = input_array.reshape(-1)
    input_tensor = torch.Tensor(padded_input_array)  # float32, like the receive buffers
    dist.gather(input_tensor, gather_list=output_tensors, dst=0)

    if is_root:
        padded_output = [x.numpy() for x in output_tensors]
        output = [x[:all_count[i]].reshape(all_shape[i]) for i, x in enumerate(padded_output)]
    else:
        output = None

    return output
|
|
|
|
|
|
|
def allgatherv(input_tensor, flatten=False):
    """All-gather tensors whose sizes differ between ranks.

    Every rank pads its flattened tensor to the largest element count, the
    padded buffers are all-gathered, and the padding is cut off again.

    Args:
        input_tensor (torch.Tensor): this rank's tensor.
            NOTE(review): all ranks are assumed to pass tensors with the
            same number of dimensions, since the shapes are gathered into
            equally sized tensors -- confirm.
        flatten (bool): if True, return one 1-D tensor with all ranks'
            elements concatenated in rank order.

    Returns:
        torch.Tensor if ``flatten`` else list[torch.Tensor] (one tensor per
        rank, restored to its original shape).
    """
    world_size = dist.get_world_size()

    # Built directly as int64 and on the input's device, so the shape
    # exchange uses the same kind of tensor as the data exchange.
    myshape = torch.tensor(list(input_tensor.size()), dtype=torch.long,
                           device=input_tensor.device)
    mycount = input_tensor.numel()
    all_shape = [torch.zeros_like(myshape) for _ in range(world_size)]
    dist.all_gather(all_shape, myshape)

    all_count = [int(x.prod()) for x in all_shape]
    all_shape = [x.tolist() for x in all_shape]
    max_count = max(all_count)

    output_tensors = [torch.zeros(max_count).to(input_tensor) for _ in range(world_size)]
    padded_input = torch.zeros(max_count).to(input_tensor)
    # reshape (not view) also accepts non-contiguous inputs.
    padded_input[:mycount] = input_tensor.reshape(-1)
    dist.all_gather(output_tensors, padded_input)

    if flatten:
        output = torch.cat([x[:all_count[i]] for i, x in enumerate(output_tensors)])
    else:
        output = [x[:all_count[i]].reshape(all_shape[i]) for i, x in enumerate(output_tensors)]

    return output
|
|
|
def simple_group_split(world_size, rank, num_groups):
    """Split all ranks into ``num_groups`` equal groups and return this rank's.

    Every process creates all groups. ``np.split`` requires ``world_size``
    to be divisible by ``num_groups``.

    Returns:
        The process group that contains ``rank``.
    """
    members = [list(map(int, part)) for part in np.split(np.arange(world_size), num_groups)]
    groups = [dist.new_group(ranks=members[g]) for g in range(num_groups)]
    ranks_per_group = world_size // num_groups
    return groups[rank // ranks_per_group]
|
|
|
def specific_group_split(world_size, rank, group_spec):
    """Split the ranks into consecutive groups with the given relative sizes.

    ``group_spec`` is a list of ints whose sum has to divide ``world_size``;
    every entry is scaled by ``world_size / sum(group_spec)`` to get the
    number of ranks in that group. Every process creates all groups.

    Returns:
        edict with ``task_roots`` (first rank of every group) and
        ``num_groups``; if ``rank`` belongs to a group, also ``group``,
        ``task_size``, ``task_id``, ``task_sub_id`` and ``task_root``.
    """
    assert type(group_spec) is list
    assert all(type(x) is int for x in group_spec)
    num_groups = len(group_spec)
    splits = np.sum(group_spec)
    assert world_size % splits == 0
    unit = int(world_size / splits)

    scaled_spec = [x * unit for x in group_spec]
    group_info = edict()
    groups = []
    roots = []
    start = 0  # first rank of the group being built
    for task_id, size in enumerate(scaled_spec):
        members = [int(r) for r in np.arange(start, start + size)]
        groups.append(dist.new_group(ranks=members))
        roots.append(start)
        if rank in members:
            group_info.group = groups[-1]
            group_info.task_size = size
            group_info.task_id = task_id
            group_info.task_sub_id = rank - start
            group_info.task_root = start
        start += size

    group_info.task_roots = roots
    group_info.num_groups = num_groups
    return group_info
|
|
|
def vreduce(x, tensor, group=None):
    """All-reduce a copy of ``tensor`` and pass the scalar result to ``x.update``.

    Args:
        x: object with an ``update(value)`` method.
        tensor (torch.Tensor): single-element tensor; left unmodified.
        group (optional): process group to reduce over; the default group
            is used when it is None.
    """
    reduced = tensor.clone()
    extra = {} if group is None else {'group': group}
    dist.all_reduce(reduced, **extra)
    x.update(reduced.item())
|
|
|
def vgather(x_list, x):
    """All-gather the scalar ``x`` (as a one-element CUDA tensor) into ``x_list``."""
    local_value = torch.Tensor([x]).cuda()
    dist.all_gather(x_list, local_value)
|
|
|
|
|
|
|
def reduce_dict(input_dict, task_size, task_rank, group=None, average=True):
    """Reduce the values of a dict over all processes of a group.

    The values are stacked in sorted key order, summed with one all-reduce
    and, if requested, divided by ``task_size``.

    Args:
        input_dict (dict): all the values will be reduced; the values have
            to be stackable tensors.
        task_size (int): number of processes taking part.
        task_rank (int): accepted but unused.
        group (optional): process group passed to ``dist.all_reduce``.
        average (bool): whether to do average or sum.

    Returns:
        dict: same keys as ``input_dict`` with the reduced values.
        ``input_dict`` itself is returned, without any communication, when
        fewer than two processes take part or the dict is empty (nothing
        could be stacked).
    """
    world_size = task_size
    if world_size < 2 or not input_dict:
        return input_dict
    with torch.no_grad():
        # Sorted keys give every process the same stacking order.
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[k] for k in names], dim=0)
        dist.all_reduce(values, group=group)
        if average:
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict
|
|