import os |
import torch |
import collections |
from easydict import EasyDict as edict |
import torch.distributed as dist |
from torch.nn import Module |
from torch.utils.data.sampler import Sampler |
import math |
import numpy as np |
import multiprocessing as mp |
import copy |
import random |
from collections import defaultdict |
from core.utils import named_buffers, sync_print, printlog |
import subprocess |
import socket |
from . import comm_ |
import torch.cuda.comm |
class DistModule(torch.nn.Module):
    """Wrap a model for synchronous (multi-task) distributed training.

    On construction the wrapper broadcasts the wrapped module's parameters and
    buffers via ``broadcast_params_multitask`` so that ranks sharing a tensor
    start from identical values. Gradients are not reduced automatically:
    call :meth:`reduce_gradients` on every rank after ``backward()``.

    Each parameter is routed to a process group by boolean attributes that must
    already be set on it: ``task_specific``, ``backbone_specific``,
    ``adapter_specific``, ``neck_specific`` and ``decoder_specific``.

    Args:
        module: the model to wrap; exposed as ``self.module``.
        sync: must be True; only the explicit ``reduce_gradients`` mode is
            supported.
        task_grp: process group for ``task_specific`` parameters.
        share_backbone_group, share_neck_group, share_decoder_group,
        share_adapter_group: process groups for the corresponding parameters.
        ignore_bcast: names skipped by the initial broadcast.
        task_weight, task_size: stored on the instance; not read by this class.
    """

    def __init__(self, module, sync=False, task_grp=None, share_backbone_group=None,
                 share_neck_group=None, share_decoder_group=None, share_adapter_group=None,
                 ignore_bcast=None,
                 task_weight=None, task_size=None):
        super(DistModule, self).__init__()
        self.module = module
        self.sync = sync
        self.task_grp = task_grp
        self.share_backbone_group = share_backbone_group
        self.share_neck_group = share_neck_group
        self.share_decoder_group = share_decoder_group
        self.share_adapter_group = share_adapter_group
        self.task_weight = task_weight
        self.task_size = task_size
        if not hasattr(torch.nn.Module, 'named_buffers'):
            # Fallback for torch builds whose nn.Module lacks ``named_buffers``.
            printlog('registering named_buffers for nn.Module at DistModule')
            torch.nn.Module.named_buffers = named_buffers
        broadcast_params_multitask(self, self.task_grp, self.share_backbone_group,
                                   self.share_neck_group, self.share_decoder_group,
                                   self.share_adapter_group, ignore_bcast)
        assert sync, "Currently, only sync model is supported!"
        if not sync:
            # Hook-based reduction; only reachable when asserts are disabled (python -O).
            self._grad_accs = {}
            self._reduce_hooks = {}
            self._register_hooks()

    def forward(self, *inputs, **kwargs):
        """Delegate directly to the wrapped module."""
        return self.module(*inputs, **kwargs)

    def train(self, mode=True):
        """Set train/eval mode on the wrapper and the wrapped module.

        Returns ``self``, matching ``torch.nn.Module.train``.
        """
        super(DistModule, self).train(mode)
        self.module.train(mode)
        return self

    def _register_hooks(self):
        """Attach a reduction hook to the grad accumulator of each trainable parameter."""
        for i, (name, p) in enumerate(self.named_parameters()):
            if p.requires_grad:
                # expand_as creates a graph node whose first input is the
                # parameter's AccumulateGrad function.
                p_tmp = p.expand_as(p)
                grad_acc = p_tmp.grad_fn.next_functions[0][0]
                self._reduce_hooks[name] = grad_acc.register_hook(self._make_hook(name, p, i))
                self._grad_accs[name] = grad_acc

    def _make_hook(self, name, p, i):
        """Build the gradient hook for parameter ``p``.

        Task-specific parameters reduce within ``self.task_grp``; all others
        reduce over the default group. ``i`` is unused.
        """
        if not p.task_specific:
            def hook(*ignore):
                allreduce_async(name, p.grad.data)
        else:
            printlog('{} register hook as task specific'.format(name))

            def hook(*ignore):
                allreduce(p.grad.data, group_idx=self.task_grp)
        return hook

    def reduce_gradients(self, task_specific=False):
        """Sum gradients across ranks, each parameter within its sharing group.

        Must be called on every rank after ``backward()``. It does nothing when
        ``self.sync`` is False.

        Args:
            task_specific: if True, every parameter that has a gradient is
                reduced within ``self.task_grp`` only.
        """
        if not self.sync:
            return
        if task_specific:
            for param in self.parameters():
                if param.requires_grad and param.grad is not None:
                    # ``dist.all_reduce`` has no ``group_idx`` keyword; the module
                    # helper translates it (and maps 0 to the default group).
                    allreduce(param.grad.data, group_idx=self.task_grp)
            return
        # (flag attribute, process group) in routing priority order.
        routes = (
            ('task_specific', self.task_grp),
            ('backbone_specific', self.share_backbone_group),
            ('adapter_specific', self.share_adapter_group),
            ('neck_specific', self.share_neck_group),
            ('decoder_specific', self.share_decoder_group),
        )
        if all(group is None for _, group in routes):
            for param in self.parameters():
                if param.requires_grad and param.grad is not None:
                    dist.all_reduce(param.grad.data)
            return
        for param in self.parameters():
            if param.grad is None:
                # Give every parameter a gradient so all ranks issue the same
                # sequence of all-reduce calls.
                param.grad = param.data * 0
            if not param.requires_grad:
                continue
            group = None  # default (world) group for unflagged parameters
            for attr, candidate in routes:
                if getattr(param, attr):
                    group = candidate
                    break
            allreduce(param.grad.data, group_idx=group)
class DistModule_Hulk(torch.nn.Module):
    """Wrap a multi-modality model for synchronous distributed training.

    On construction the wrapper broadcasts the wrapped module's tensors via
    ``broadcast_params_unihcpv2``. Gradients are reduced by an explicit call
    to :meth:`reduce_gradients` on every rank after ``backward()``.

    Parameters are routed to a process group by boolean attributes that must
    already be set on them: ``task_specific``, ``modality_share``,
    ``backbone_specific``, ``rgb_specific``, ``dense_labeling_specific``,
    ``sparse_labeling_specific``, ``text_specific``, ``video_specific`` and
    ``decoder_specific``.

    Args:
        module: the model to wrap; exposed as ``self.module``.
        sync: must be True; only the explicit ``reduce_gradients`` mode is
            supported.
        task_grp, share_*_group: process groups for the corresponding
            parameters.
        ignore_bcast: names skipped by the initial broadcast.
        task_weight, task_size: stored on the instance; not read by this class.
    """
    modality_names = ['rgb', 'dense_labeling', 'text']

    def __init__(self, module, sync=False, task_grp=None, share_backbone_group=None,
                 share_decoder_group=None, share_rgb_group=None, share_dense_labeling_group=None,
                 share_sparse_labeling_group=None, share_text_group=None, share_video_group=None,
                 share_modality_group=None,
                 ignore_bcast=None, task_weight=None, task_size=None, ):
        super(DistModule_Hulk, self).__init__()
        self.module = module
        self.sync = sync
        self.task_grp = task_grp
        self.share_modality_group = share_modality_group
        self.share_backbone_group = share_backbone_group
        self.share_decoder_group = share_decoder_group
        self.share_rgb_group = share_rgb_group
        self.share_dense_labeling_group = share_dense_labeling_group
        self.share_sparse_labeling_group = share_sparse_labeling_group
        self.share_text_group = share_text_group
        self.share_video_group = share_video_group
        self.task_weight = task_weight
        self.task_size = task_size
        if not hasattr(torch.nn.Module, 'named_buffers'):
            # Fallback for torch builds whose nn.Module lacks ``named_buffers``.
            printlog('registering named_buffers for nn.Module at DistModule')
            torch.nn.Module.named_buffers = named_buffers
        broadcast_params_unihcpv2(self, self.task_grp, self.share_backbone_group,
                                  self.share_decoder_group, self.share_rgb_group,
                                  self.share_dense_labeling_group, self.share_sparse_labeling_group,
                                  self.share_text_group, self.share_video_group,
                                  ignore_bcast,
                                  share_modality_group=self.share_modality_group,
                                  )
        assert sync, "Currently, only sync model is supported!"
        if not sync:
            # Hook-based reduction; only reachable when asserts are disabled (python -O).
            self._grad_accs = {}
            self._reduce_hooks = {}
            self._register_hooks()

    def forward(self, *inputs, **kwargs):
        """Delegate directly to the wrapped module."""
        return self.module(*inputs, **kwargs)

    def train(self, mode=True):
        """Set train/eval mode on the wrapper and the wrapped module.

        Returns ``self``, matching ``torch.nn.Module.train``.
        """
        super(DistModule_Hulk, self).train(mode)
        self.module.train(mode)
        return self

    def _register_hooks(self):
        """Attach a reduction hook to the grad accumulator of each trainable parameter."""
        for i, (name, p) in enumerate(self.named_parameters()):
            if p.requires_grad:
                # expand_as creates a graph node whose first input is the
                # parameter's AccumulateGrad function.
                p_tmp = p.expand_as(p)
                grad_acc = p_tmp.grad_fn.next_functions[0][0]
                self._reduce_hooks[name] = grad_acc.register_hook(self._make_hook(name, p, i))
                self._grad_accs[name] = grad_acc

    def _make_hook(self, name, p, i):
        """Build the gradient hook for parameter ``p``.

        Task-specific parameters reduce within ``self.task_grp``; all others
        reduce over the default group. ``i`` is unused.
        """
        if not p.task_specific:
            def hook(*ignore):
                allreduce_async(name, p.grad.data)
        else:
            printlog('{} register hook as task specific'.format(name))

            def hook(*ignore):
                allreduce(p.grad.data, group_idx=self.task_grp)
        return hook

    def reduce_gradients(self, task_specific=False):
        """Sum gradients across ranks, each parameter within its sharing group.

        Must be called on every rank after ``backward()``. It does nothing when
        ``self.sync`` is False.

        Args:
            task_specific: if True, every parameter that has a gradient is
                reduced within ``self.task_grp`` only.
        """
        if not self.sync:
            return
        if task_specific:
            for param in self.parameters():
                if param.requires_grad and param.grad is not None:
                    # ``dist.all_reduce`` has no ``group_idx`` keyword; the module
                    # helper translates it (and maps 0 to the default group).
                    allreduce(param.grad.data, group_idx=self.task_grp)
            return
        # (flag attribute, process group) in routing priority order.
        routes = (
            ('task_specific', self.task_grp),
            ('modality_share', self.share_modality_group),
            ('backbone_specific', self.share_backbone_group),
            ('rgb_specific', self.share_rgb_group),
            ('dense_labeling_specific', self.share_dense_labeling_group),
            ('sparse_labeling_specific', self.share_sparse_labeling_group),
            ('text_specific', self.share_text_group),
            ('video_specific', self.share_video_group),
            ('decoder_specific', self.share_decoder_group),
        )
        if all(group is None for _, group in routes):
            for param in self.parameters():
                if param.requires_grad and param.grad is not None:
                    dist.all_reduce(param.grad.data)
            return
        for param in self.parameters():
            if param.grad is None:
                # Give every parameter a gradient so all ranks issue the same
                # sequence of all-reduce calls.
                param.grad = param.data * 0
            if not param.requires_grad:
                continue
            group = None  # default (world) group for unflagged parameters
            for attr, candidate in routes:
                if getattr(param, attr):
                    group = candidate
                    break
            allreduce(param.grad.data, group_idx=group)
def allreduce(x, group_idx=None):
    """Sum-all-reduce tensor ``x`` in place over a process group.

    Args:
        x: tensor reduced in place.
        group_idx: process group to reduce over; None or 0 selects the
            default (world) group.

    Returns:
        Whatever ``torch.distributed.all_reduce`` returns.
    """
    group = None if group_idx == 0 else group_idx
    return dist.all_reduce(x, group=group)
def allreduce_async(name, x, group_idx=None):
    """Sum-all-reduce ``x`` in place; called from the per-parameter gradient hooks.

    Despite its name this issues a blocking ``dist.all_reduce`` (no
    ``async_op``). ``name`` is accepted for the hook's call signature and is
    not used. A ``group_idx`` of None or 0 selects the default group.
    """
    del name  # unused
    target_group = group_idx
    if target_group == 0:
        target_group = None
    return dist.all_reduce(x, group=target_group)
def broadcast_params(model, task_grp, ignore):
    """Broadcast a model's parameters and buffers so all ranks hold identical values.

    Parameters are processed before buffers, each in module order, so every
    rank issues the same sequence of collectives.

    Args:
        model: module whose tensors are broadcast in place.
        task_grp: process group for tensors whose ``task_specific`` attribute
            is truthy. When None, every tensor is broadcast over the default
            group from rank 0 and ``task_specific`` is not read.
        ignore: optional collection of parameter/buffer names to skip.

    Raises:
        RuntimeError: ``task_grp`` is given and a tensor has no
            ``task_specific`` attribute.
    """
    for kind, named_tensors in (('param', model.named_parameters()),
                                ('buffer', model.named_buffers())):
        for name, tensor in named_tensors:
            if ignore and name in ignore:
                printlog('{} {} ignored in broadcast'.format(kind, name))
                continue
            if task_grp is None:
                broadcast(tensor, 0)
                continue
            # Only the attribute lookup is guarded; communication errors propagate unchanged.
            try:
                is_task_specific = tensor.task_specific
            except AttributeError as err:
                raise RuntimeError('{} {} does not have task_specific'.format(kind, name)) from err
            if is_task_specific:
                printlog('broadcasting task-specific {} {}'.format(kind, name))
                broadcast(tensor, 0, group_idx=task_grp)
            else:
                broadcast(tensor, 0)
def broadcast(x, root, group_idx=None):
    """Broadcast tensor ``x`` in place and wait for completion.

    Args:
        x: tensor to broadcast; overwritten on non-source ranks.
        root: global source rank. It is only used for the default group.
        group_idx: process group, or None / 0 for the default (world) group.
            With an explicit group the source is that group's local rank 0,
            and ``root`` is ignored.

    Returns:
        The ``dist.broadcast`` result for the default group, otherwise the
        (already completed) work handle from ``ProcessGroup.broadcast``.
    """
    if group_idx == 0:
        group_idx = None
    if group_idx is None:
        return dist.broadcast(x, root, group_idx)
    work = group_idx.broadcast(x, 0)
    # ProcessGroup.broadcast is asynchronous; wait so callers can use ``x``
    # immediately, matching the blocking default-group path above.
    work.wait()
    return work
def broadcast_params_multitask(model, task_grp, share_backbone_group, share_neck_group, share_decoder_group,
                               share_adapter_group, ignore):
    """Broadcast multi-task model parameters and buffers within their sharing groups.

    When at least one group is given, each tensor must carry the attributes
    ``task_specific``, ``adapter_specific``, ``backbone_specific``,
    ``neck_specific`` and ``decoder_specific``, with at most one of them set.
    The first set flag, in that order, selects the process group. The tensor
    is then broadcast inside that group from the group's local rank 0.
    Unflagged tensors are broadcast over the default group from rank 0. When
    every group is None, all tensors are broadcast over the default group and
    the attributes are not read.

    Parameters are processed before buffers, each in module order, so every
    rank issues the same sequence of collectives.

    Args:
        model: module whose tensors are broadcast in place.
        task_grp, share_backbone_group, share_neck_group, share_decoder_group,
        share_adapter_group: process groups (or None) for each flag.
        ignore: optional collection of parameter/buffer names to skip.

    Raises:
        RuntimeError: a tensor lacks one of the attributes, or is flagged for
            a group that was passed as None.
        AssertionError: a tensor has more than one flag set.
    """
    # (flag attribute, log label, process group) in routing priority order.
    routes = (
        ('task_specific', 'task-specific', task_grp),
        ('adapter_specific', 'adapter-specific', share_adapter_group),
        ('backbone_specific', 'backbone-specific', share_backbone_group),
        ('neck_specific', 'neck-specific', share_neck_group),
        ('decoder_specific', 'decoder-specific', share_decoder_group),
    )
    use_groups = any(group is not None for _, _, group in routes)
    for kind, named_tensors in (('param', model.named_parameters()),
                                ('buffer', model.named_buffers())):
        for name, tensor in named_tensors:
            if ignore and name in ignore:
                printlog('{} {} ignored in broadcast'.format(kind, name))
                continue
            if not use_groups:
                dist.broadcast(tensor, 0)
                continue
            try:
                flags = [bool(getattr(tensor, attr)) for attr, _, _ in routes]
            except AttributeError as err:
                raise RuntimeError('{} {} is missing one of the attributes: {}'.format(
                    kind, name, ', '.join(attr for attr, _, _ in routes))) from err
            assert sum(flags) <= 1, \
                "{} {} could not be task_specific, adapter_specific, backbone_specific, " \
                "neck_specific, decoder_specific at same time".format(kind, name)
            for flag, (_, label, group) in zip(flags, routes):
                if not flag:
                    continue
                if group is None:
                    raise RuntimeError('{} {} is {} but no matching process group was given'.format(
                        kind, name, label))
                # Global rank of the group's local rank 0 (``broadcast`` uses the
                # group-local root 0 when a group is given).
                start_rank = -group.rank() + dist.get_rank()
                printlog(f'broadcasting {label} {kind} {name}\tgroup_idx={group}')
                broadcast(tensor, start_rank, group_idx=group)
                break
            else:
                printlog(f'broadcasting non-specific {kind} {name}')
                broadcast(tensor, 0)
def broadcast_params_unihcpv2(model, task_grp, share_backbone_group, share_decoder_group,
                              share_rgb_group, share_dense_labeling_group, share_sparse_labeling_group,
                              share_text_group, share_video_group, ignore,
                              share_modality_group=None):
    """Broadcast multi-modality model parameters and buffers within their sharing groups.

    When at least one group is given, each tensor must carry the attributes
    ``task_specific``, ``modality_share``, ``backbone_specific``,
    ``decoder_specific``, ``rgb_specific``, ``dense_labeling_specific``,
    ``sparse_labeling_specific``, ``text_specific`` and ``video_specific``,
    with at most one of them set. The first set flag, in that order, selects
    the process group. The tensor is then broadcast inside that group from
    the group's local rank 0. Unflagged tensors are broadcast over the
    default group from rank 0. When every group is None, all tensors are
    broadcast over the default group and the attributes are not read.

    Parameters are processed before buffers, each in module order, so every
    rank issues the same sequence of collectives.

    Args:
        model: module whose tensors are broadcast in place.
        task_grp, share_*_group: process groups (or None) for each flag.
        ignore: optional collection of parameter/buffer names to skip.

    Raises:
        RuntimeError: a tensor lacks one of the attributes, or is flagged for
            a group that was passed as None.
        AssertionError: a tensor has more than one flag set.
    """
    # (flag attribute, log label, process group) in routing priority order.
    routes = (
        ('task_specific', 'task-specific', task_grp),
        ('modality_share', 'modality-share', share_modality_group),
        ('backbone_specific', 'backbone-specific', share_backbone_group),
        ('decoder_specific', 'decoder-specific', share_decoder_group),
        ('rgb_specific', 'rgb-specific', share_rgb_group),
        ('dense_labeling_specific', 'dense_labeling-specific', share_dense_labeling_group),
        ('sparse_labeling_specific', 'sparse_labeling-specific', share_sparse_labeling_group),
        ('text_specific', 'text-specific', share_text_group),
        ('video_specific', 'video-specific', share_video_group),
    )
    use_groups = any(group is not None for _, _, group in routes)
    for kind, named_tensors in (('param', model.named_parameters()),
                                ('buffer', model.named_buffers())):
        for name, tensor in named_tensors:
            if ignore and name in ignore:
                printlog('{} {} ignored in broadcast'.format(kind, name))
                continue
            if not use_groups:
                dist.broadcast(tensor, 0)
                continue
            try:
                flags = [bool(getattr(tensor, attr)) for attr, _, _ in routes]
            except AttributeError as err:
                raise RuntimeError('{} {} is missing one of the attributes: {}'.format(
                    kind, name, ', '.join(attr for attr, _, _ in routes))) from err
            assert sum(flags) <= 1, \
                "{} {} could not be task_specific, backbone_specific, decoder_specific, " \
                "modality_specific at same time".format(kind, name)
            for flag, (_, label, group) in zip(flags, routes):
                if not flag:
                    continue
                if group is None:
                    raise RuntimeError('{} {} is {} but no matching process group was given'.format(
                        kind, name, label))
                # Global rank of the group's local rank 0 (``broadcast`` uses the
                # group-local root 0 when a group is given).
                start_rank = -group.rank() + dist.get_rank()
                printlog(f'broadcasting {label} {kind} {name}\tgroup_idx={group}')
                broadcast(tensor, start_rank, group_idx=group)
                break
            else:
                printlog(f'broadcasting non-specific {kind} {name}')
                broadcast(tensor, 0)
def find_free_port():
    """Return a TCP port number that was unbound at the time of the call.

    The OS picks the port by binding to port 0. The probe socket is closed
    before returning so the caller can bind the port itself. Another process
    may still grab it in between, so treat the result as a best-effort hint.
    """
    with socket.socket() as s:
        s.bind(('', 0))
        return s.getsockname()[1]
def dist_init():
    """Initialise ``torch.distributed`` with the NCCL backend.

    Two launch modes are recognised:

    * ``RANK`` and ``WORLD_SIZE`` (plus ``LOCAL_RANK``) already in the
      environment, e.g. when launched by torchrun.
    * ``SLURM_PROCID`` set. Rank 0 picks a free port and writes it to
      ``dist_url_<SLURM_JOBID>.txt`` in the working directory; the other
      ranks poll for that file. ``MASTER_ADDR`` is the first host of
      ``SLURM_NODELIST``.

    After initialisation, one process group per machine is created, covering
    blocks of ``torch.cuda.device_count()`` consecutive ranks. The group
    containing this rank is stored in ``comm_._LOCAL_PROCESS_GROUP``.

    Returns:
        ``(rank, world_size)``, or None when neither launch mode is detected.
    """
    import socket
    import time

    hostfile = None  # port-exchange file, only used in SLURM mode
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        if int(os.environ["RANK"]) == 0:
            print('this task is not running on cluster!')
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        gpu = int(os.environ['LOCAL_RANK'])
        dist_url = 'env://'
        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
        addr = socket.gethostname()
    elif 'SLURM_PROCID' in os.environ:
        proc_id = int(os.environ['SLURM_PROCID'])
        if proc_id == 0:
            print('Init dist using slurm!')
            print("Job Id is {} on {} ".format(os.environ["SLURM_JOBID"], os.environ['SLURM_NODELIST']))
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        addr = subprocess.getoutput(
            'scontrol show hostname {} | head -n1'.format(node_list))
        jobid = os.environ["SLURM_JOBID"]
        hostfile = "dist_url_" + jobid + ".txt"
        if proc_id == 0:
            tcp_port = str(find_free_port())
            print('write port {} to file: {} '.format(tcp_port, hostfile))
            tmp_file = hostfile + '.tmp'
            with open(tmp_file, "w") as f:
                f.write(tcp_port)
            # Atomic rename, so ranks polling for ``hostfile`` never read a partial write.
            os.replace(tmp_file, hostfile)
        else:
            print('read port from file: {}'.format(hostfile))
            while not os.path.exists(hostfile):
                time.sleep(1)
            time.sleep(2)
            with open(hostfile, "r") as f:
                tcp_port = f.read().strip()
        os.environ['MASTER_PORT'] = str(tcp_port)
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        dist_url = 'env://'
        world_size = ntasks
        rank = proc_id
        gpu = proc_id % num_gpus
    else:
        print('Not using distributed mode')
        return None

    torch.cuda.set_device(gpu)
    dist_backend = 'nccl'
    print('rank: {} addr: {} port: {}'.format(rank, addr, os.environ['MASTER_PORT']))
    torch.distributed.init_process_group(backend=dist_backend, init_method=dist_url,
                                         world_size=world_size, rank=rank)
    torch.distributed.barrier()
    # Every rank has read the port by now (barrier above), so rank 0 can clean up.
    if hostfile is not None and rank == 0 and os.path.isfile(hostfile):
        os.remove(hostfile)

    if world_size >= 1:
        assert comm_._LOCAL_PROCESS_GROUP is None
        num_gpus = torch.cuda.device_count()
        num_machines = world_size // num_gpus
        # new_group is a collective: every rank must create every group, in the
        # same order, so the loop must not exit early on any rank.
        for i in range(num_machines):
            ranks_on_i = list(range(i * num_gpus, (i + 1) * num_gpus))
            print('new_group: {}'.format(ranks_on_i))
            pg = torch.distributed.new_group(ranks_on_i)
            if rank in ranks_on_i:
                comm_._LOCAL_PROCESS_GROUP = pg
    return rank, world_size
class DistributedGivenIterationSampler(Sampler): |
def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1, |
shuffle_strategy=0, random_seed=0, imageNumPerClass=4, ret_save_path=None): |
if world_size is None: |
world_size = dist.get_world_size() |
if rank is None: |
rank = dist.get_rank() |
assert rank < world_size |
sync_print('sampler: rank={}, world_size={}, random_seed={}'.format(rank, world_size, random_seed)) |
self.dataset = dataset |
self.total_iter = total_iter |
self.batch_size = batch_size |
self.world_size = world_size |
self.rank = rank |
self.last_iter = last_iter |
self.shuffle_strategy = shuffle_strategy |
self.random_seed = random_seed |
self.imageNumPerClass = imageNumPerClass |
self.ret_save_path = ret_save_path |
self.task_name = self.dataset.task_name |
self.total_size = self.total_iter*self.batch_size |
self.call = 0 |
if self.ret_save_path is not None: |
self.this_ret_path = os.path.join(self.ret_save_path, '_'.join([self.task_name, str(self.world_size), str(self.rank)]) + ".pth.tar") |
if os.path.exists(self.this_ret_path): |
ret_file = torch.load(self.this_ret_path) |
if ret_file['task_name'] == self.task_name and ret_file['task_size'] == self.world_size and ret_file['task_rank'] == self.rank: |
printlog(" load task sampler from ------> {}".format(self.this_ret_path)) |
self.indices = ret_file['ret_file'] |
self.dataset.received_indices = True |
return |
else: |
printlog("sampler file ({}) is not existed, and will be generated now--->".format(self.this_ret_path)) |
if self.shuffle_strategy in [0,1,3,4,6]: |
self.indices = self.gen_new_list() |
self.dataset.indices = self.indices |
self.dataset.received_indices = True |
elif self.shuffle_strategy == 2: |
self.indices = self.gen_s2() |
elif self.shuffle_strategy == 5: |
self.indices = self.gen_s5() |
else: |
raise Error("Invalid shuffle_strategy!") |
if self.ret_save_path is not None and not os.path.exists(self.ret_save_path): |
self.save() |
def gen_s2(self): |
np.random.seed(self.rank) |
indices = [] |
labels = self.dataset.labels |
printlog('using shuffle strategy 2, initializing class map...') |
class2id = collections.OrderedDict() |
for i,l in enumerate(labels): |
if l in class2id: |
class2id[l].append(i) |
else: |
class2id[l] = [i] |
keys = list(class2id.keys()) |
np.random.shuffle(keys) |
num_class = len(keys) |
printlog('class map done.') |
for i in range((self.last_iter+1)*self.batch_size, self.total_size): |
class_id = np.random.randint(0, num_class) |
this_num = len(class2id[keys[class_id]]) |
inner_id = np.random.randint(0, this_num) |
indices.append(class2id[keys[class_id]][inner_id]) |
return indices |
def gen_s5(self): |
np.random.seed(self.rank) |
indices = [] |
labels = self.dataset.labels |
printlog('using shuffle strategy 5, initializing class map...') |
class2id = collections.OrderedDict() |
for i,l in enumerate(labels): |
if l in class2id: |
class2id[l].append(i) |
else: |
class2id[l] = [i] |
keys = list(class2id.keys()) |
np.random.shuffle(keys) |
num_class = len(keys) |
printlog('class map done.') |
printlog('{} class with {} samples in a batch!'.format(self.batch_size//self.imageNumPerClass, self.imageNumPerClass)) |
for i in range((self.last_iter+1)*self.batch_size, self.total_size): |
if i % self.imageNumPerClass == 0: |
class_id = np.random.randint(0, num_class) |
this_num = len(class2id[keys[class_id]]) |
inner_id = np.random.randint(0, this_num) |
indices.append(class2id[keys[class_id]][inner_id]) |
return indices |
def __iter__(self): |
if self.call == 0: |
self.call = 1 |
return iter(self.indices) |
else: |
raise RuntimeError("this sampler is not designed to be called more than once!!") |
def gen_new_list(self): |
if self.shuffle_strategy == 0: |
np.random.seed(self.rank) |
indices = np.arange(len(self.dataset)) |
indices = indices[:self.total_size] |
num_repeat = (self.total_size-1) // indices.shape[0] + 1 |
indices = np.tile(indices, num_repeat) |
indices = indices[:self.total_size] |
for beg in range(0, self.total_size, len(self.dataset)): |
end = min(beg+len(self.dataset), self.total_size) |
np.random.shuffle(indices[beg:end]) |
elif self.shuffle_strategy == 1: |
np.random.seed(self.random_seed) |
all_size = self.total_size * self.world_size |
indices = np.arange(len(self.dataset)) |
indices = indices[:all_size] |
num_repeat = (all_size-1) // indices.shape[0] + 1 |
indices = np.tile(indices, num_repeat) |
indices = indices[:all_size] |
np.random.shuffle(indices) |
beg = self.total_size * self.rank |
indices = indices[beg:beg+self.total_size] |
elif self.shuffle_strategy == 3: |
np.random.seed(0) |
all_size = self.total_size * self.world_size |
labels = self.dataset.labels |
class2id = collections.OrderedDict() |
for i,l in enumerate(labels): |
if l in class2id: |
class2id[l].append(i) |
else: |
class2id[l] = [i] |
class_count = [len(x) for _,x in class2id.items()] |
mean_num = int(np.mean(class_count)) |
indices = [] |
for _,v in class2id.items(): |
if len(v) < mean_num: |
lack_num = mean_num - len(v) |
indices.extend(np.random.choice(v, lack_num)) |
indices.extend(v) |
indices = np.array(indices) |
indices = indices[:all_size] |
printlog('using strategy 3, mean_num: {}, origin_len: {}, balanced_len: {}'.format(mean_num, len(self.dataset), len(indices))) |
num_repeat = (all_size-1) // indices.shape[0] + 1 |
indices = np.tile(indices, num_repeat) |
indices = indices[:all_size] |
np.random.shuffle(indices) |
beg = self.total_size * self.rank |
indices = indices[beg:beg+self.total_size] |
elif self.shuffle_strategy == 6: |
np.random.seed(self.random_seed) |
labels = self.dataset.labels |
printlog('using shuffle strategy 6, initializing class map...') |
class2id = collections.defaultdict(list) |
for i,l in enumerate(labels): |
class2id[l].append(i) |
mini_indices = [] |
for pid, idxs in class2id.items(): |
if len(idxs) < self.imageNumPerClass: |
idxs = idxs + list(np.random.choice(idxs, size=self.imageNumPerClass-len(idxs), replace=True)) |
elif len(idxs) % self.imageNumPerClass != 0: |
add_num = int(len(idxs) // self.imageNumPerClass + 1) * self.imageNumPerClass - len(idxs) |
idxs = idxs + list(np.random.choice(idxs, size=add_num, replace=True)) |
assert len(idxs) % self.imageNumPerClass == 0 |
mini_indices.extend([idxs[i:i+self.imageNumPerClass] for i in range(0, len(idxs), self.imageNumPerClass)]) |
np.random.shuffle(mini_indices) |
indices = np.array(mini_indices).reshape(-1) |
all_size = self.total_size * self.world_size |
indices = indices[:all_size] |
num_repeat = (all_size-1) // indices.shape[0] + 1 |
indices = np.tile(indices, num_repeat) |
indices = indices[:all_size] |
beg = self.total_size * self.rank |
indices = indices[beg:beg+self.total_size] |
elif self.shuffle_strategy == 7: |
np.random.seed(self.random_seed) |
labels = self.dataset.labels |
printlog('using shuffle strategy 7, initializing class map...') |
class2id = collections.defaultdict(list) |
for i,l in enumerate(labels): |
class2id[l].append(i) |
mini_indices = [] |
for pid, idxs in class2id.items(): |
if len(idxs) < self.imageNumPerClass: |
idxs = idxs + list(np.random.choice(idxs, size=self.imageNumPerClass-len(idxs), replace=True)) |
elif len(idxs) % self.imageNumPerClass != 0: |
add_num = int(len(idxs) // self.imageNumPerClass + 1) * self.imageNumPerClass - len(idxs) |
idxs = idxs + list(np.random.choice(idxs, size=add_num, replace=True)) |
assert len(idxs) % self.imageNumPerClass == 0 |
mini_indices.extend([idxs[i:i+self.imageNumPerClass] for i in range(0, len(idxs), self.imageNumPerClass)]) |
np.random.shuffle(mini_indices) |
indices = np.array(mini_indices).reshape(-1) |
all_size = self.total_size * self.world_size |
indices = indices[:all_size] |
num_repeat = (all_size-1) // indices.shape[0] + 1 |
for repeat in range(num_repeat)-1: |
for pid, idxs in class2id.items(): |
if len(idxs) < self.imageNumPerClass: |
idxs = idxs + list(np.random.choice(idxs, size=self.imageNumPerClass-len(idxs), replace=True)) |
elif len(idxs) % self.imageNumPerClass != 0: |
add_num = int(len(idxs) // self.imageNumPerClass + 1) * self.imageNumPerClass - len(idxs) |
idxs = idxs + list(np.random.choice(idxs, size=add_num, replace=True)) |
assert len(idxs) % self.imageNumPerClass == 0 |
mini_indices.extend([idxs[i:i+self.imageNumPerClass] for i in range(0, len(idxs), self.imageNumPerClass)]) |
np.random.shuffle(mini_indices) |
indices = np.array(mini_indices).reshape(-1) |
all_size = self.total_size * self.world_size |
indices = indices[:all_size] |
beg = self.total_size * self.rank |
indices = indices[beg:beg+self.total_size] |
elif self.shuffle_strategy == 8: |
np.random.seed(self.random_seed) |
labels = self.dataset.labels |
print('using shuffle strategy 8, initializing class map...') |
class2id = collections.defaultdict(list) |
for i,l in enumerate(labels): |
class2id[l].append(i) |
pids = set() |
cls_pids_idxs = collections.defaultdict(list) |
for pid, idxs in class2id.items(): |
if len(idxs) < self.imageNumPerClass: |
idxs = idxs + list(np.random.choice(idxs, size=self.imageNumPerClass-len(idxs), replace=True)) |
elif len(idxs) % self.imageNumPerClass != 0: |
add_num = int(len(idxs) // self.imageNumPerClass + 1) * self.imageNumPerClass - len(idxs) |
idxs = idxs + list(np.random.choice(idxs, size=add_num, replace=True)) |
assert len(idxs) % self.imageNumPerClass == 0 |
np.random.shuffle(idxs) |
idxs = [idxs[i:(i+self.imageNumPerClass)] for i in range(0, len(idxs), self.imageNumPerClass)] |
ptr = 0 |
values = [ptr, idxs] |
cls_pids_idxs[pid] = values |
pids.add(pid) |
indices = [] |
classnum_per_batch = self.batch_size // self.imageNumPerClass |
while len(pids) >= classnum_per_batch: |
batch = [] |
sub_pids = [] |
for i, pid in enumerate(pids): |
if i >= classnum_per_batch: |
break |
sub_pids.append(pid) |
for pid in sub_pids: |
ptr = cls_pids_idxs[pid][0] |
idxs = cls_pids_idxs[pid][1] |
batch.extend(idxs[ptr]) |
if ptr + 1 >= len(idxs): |
pids.remove(pid) |
else: |
cls_pids_idxs[pid][0] += 1 |
indices.append(batch) |
np.random.shuffle(indices) |
indices = np.array(indices).reshape(-1) |
all_size = self.total_size * self.world_size |
indices = indices[:all_size] |
num_repeat = (all_size-1) // indices.shape[0] + 1 |
indices = np.tile(indices, num_repeat) |
indices = indices[:all_size] |
beg = self.total_size * self.rank |
indices = indices[beg:beg+self.total_size] |
else: |
raise RuntimeError('unknow shuffle strategy') |
assert len(indices) == self.total_size |
return indices[(self.last_iter+1)*self.batch_size:] |
def __len__(self):
    # Indices still to be served: the full per-rank budget minus the
    # batches already consumed up to and including `last_iter`.
    consumed = self.batch_size * (self.last_iter + 1)
    return self.total_size - consumed
def save(self):
    # Persist this rank's sampling result with torch.save, then log the path.
    # NOTE(review): `task_name`, `indices` and `this_ret_path` are set
    # elsewhere in this class (not visible here).
    payload = {
        'task_name': self.task_name,
        'task_size': self.world_size,
        'task_rank': self.rank,
        'ret_file': self.indices,
    }
    torch.save(payload, self.this_ret_path)
    printlog("save sampler file ------> {}".format(self.this_ret_path))
class RandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    One long identity-balanced index stream is drawn for all ranks (the
    global ``np.random`` state is re-seeded with ``random_seed`` on every
    ``__iter__``), and rank ``r`` takes batches ``r, r + world_size, ...``.

    Args:
    - dataset: object exposing ``labels``; ``labels[i]`` is the identity of sample ``i``.
    - total_iter (int): number of iterations; each rank serves
      ``total_iter * batch_size`` indices in total.
    - batch_size (int): number of examples in a batch.
    - world_size (int, optional): defaults to ``dist.get_world_size()``.
    - rank (int, optional): defaults to ``dist.get_rank()``.
    - last_iter (int): last finished iteration; ``__iter__`` resumes after it.
    - random_seed (int): seed for the global ``np.random`` state.
    - imageNumPerClass (int): number of instances per identity in a batch.
    - shuffle_strategy, ret_save_path: accepted for interface compatibility; unused.
    """

    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1,
                 shuffle_strategy=0, random_seed=0, imageNumPerClass=4, ret_save_path=None):
        self.batch_size = batch_size
        self.world_size = world_size if world_size is not None else dist.get_world_size()
        self.num_instances = imageNumPerClass
        # N identities per batch, K = num_instances samples each.
        self.num_pids_per_batch = self.batch_size // self.num_instances
        self.random_seed = random_seed
        self.total_iter = total_iter
        self.total_size = self.total_iter * self.batch_size
        self.last_iter = last_iter
        self.dataset = dataset
        labels = self.dataset.labels
        printlog('using RandomIdentitySampler, initializing class map...')
        # identity label -> indices of the samples carrying that label
        self.index_dic = collections.defaultdict(list)
        for i, l in enumerate(labels):
            self.index_dic[l].append(i)
        self.pids = np.array(list(self.index_dic.keys()))
        self.rank = rank if rank is not None else dist.get_rank()

    def __iter__(self):
        np.random.seed(self.random_seed)
        self._seed = int(self.random_seed)
        final_idxs = self.sample_list()
        length = int(math.ceil(len(final_idxs) * 1.0 / self.world_size))
        final_idxs = self.__fetch_current_node_idxs(final_idxs, length)
        # Skip what was consumed before `last_iter` so iteration agrees with
        # __len__ when resuming.
        return iter(final_idxs[(self.last_iter + 1) * self.batch_size:])

    def __fetch_current_node_idxs(self, final_idxs, length):
        """Return this rank's share: batches rank, rank + world_size, ... of ``final_idxs``."""
        total_num = len(final_idxs)
        block_num = length // self.batch_size
        index_target = []
        for i in range(0, block_num * self.world_size, self.world_size):
            start = self.batch_size * (self.rank + i)
            index_target.extend(range(start, min(start + self.batch_size, total_num)))
        # Integer dtype keeps an empty selection a valid (empty) fancy index.
        index_target_npy = np.array(index_target, dtype=np.int64)
        return list(np.array(final_idxs)[index_target_npy])

    def _check_can_sample(self):
        """Raise ValueError when no identity-balanced batch can ever be formed
        (otherwise sample_list / batch_sample_list would never terminate)."""
        if self.num_pids_per_batch < 1:
            raise ValueError(
                'batch_size ({}) must be >= imageNumPerClass ({})'.format(
                    self.batch_size, self.num_instances))
        if len(self.pids) < self.num_pids_per_batch:
            raise ValueError(
                'need at least {} identities per batch, but the dataset has only {}'.format(
                    self.num_pids_per_batch, len(self.pids)))

    def batch_sample_list(self):
        """One identity-balanced pass: a flat index list built from
        ``num_pids_per_batch * num_instances``-sized batches, stopping once
        fewer than ``num_pids_per_batch`` identities are still available."""
        avai_pids = copy.deepcopy(self.pids.tolist())
        batch_idxs_dict = {}
        batch_indices = []
        while len(avai_pids) >= self.num_pids_per_batch:
            selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist()
            for pid in selected_pids:
                # (Re)fill this identity's shuffled pool; identities with fewer
                # than K samples are oversampled with replacement.
                if pid not in batch_idxs_dict or len(batch_idxs_dict[pid]) < self.num_instances:
                    idxs = copy.deepcopy(self.index_dic[pid])
                    if len(idxs) < self.num_instances:
                        idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist()
                    np.random.shuffle(idxs)
                    batch_idxs_dict[pid] = idxs
                avai_idxs = batch_idxs_dict[pid]
                for _ in range(self.num_instances):
                    batch_indices.append(avai_idxs.pop(0))
                # Retire the identity once it cannot fill another K-sized slot.
                if len(avai_idxs) < self.num_instances:
                    avai_pids.remove(pid)
        return batch_indices

    def sample_list(self):
        """Concatenate identity-balanced passes into ``total_size * world_size`` indices.

        Raises:
            ValueError: if no batch can be formed (too few identities, or
                ``batch_size < imageNumPerClass``).
        """
        self._check_can_sample()
        all_size = self.total_size * self.world_size
        all_indices = list()
        while len(all_indices) <= all_size:
            all_indices.extend(self.batch_sample_list())
        return all_indices[:all_size]

    def __len__(self):
        return self.total_size - (self.last_iter + 1) * self.batch_size
class RandomIdentityBatchSampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Identity-balanced passes are drawn one after another (the global
    ``np.random`` state is re-seeded with ``random_seed`` on every
    ``__iter__``) and each pass is split across ranks on its own: rank ``r``
    takes batches ``r, r + world_size, ...`` of that pass. Passes are drawn
    until every rank is assigned at least ``total_size`` indices.

    NOTE(review): when ``batch_size < world_size`` the highest ranks may get
    a truncated share of a pass (the per-pass split rounds up per rank).

    Args:
    - dataset: object exposing ``labels``; ``labels[i]`` is the identity of sample ``i``.
    - total_iter (int): number of iterations; each rank serves
      ``total_iter * batch_size`` indices in total.
    - batch_size (int): number of examples in a batch.
    - world_size (int, optional): defaults to ``dist.get_world_size()``.
    - rank (int, optional): defaults to ``dist.get_rank()``.
    - last_iter (int): last finished iteration; ``__iter__`` resumes after it.
    - random_seed (int): seed for the global ``np.random`` state.
    - imageNumPerClass (int): number of instances per identity in a batch.
    - shuffle_strategy, ret_save_path: accepted for interface compatibility; unused.
    """

    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1,
                 shuffle_strategy=0, random_seed=0, imageNumPerClass=4, ret_save_path=None):
        self.batch_size = batch_size
        self.world_size = world_size if world_size is not None else dist.get_world_size()
        self.num_instances = imageNumPerClass
        # N identities per batch, K = num_instances samples each.
        self.num_pids_per_batch = self.batch_size // self.num_instances
        self.random_seed = random_seed
        self.total_iter = total_iter
        self.total_size = self.total_iter * self.batch_size
        self.last_iter = last_iter
        self.dataset = dataset
        labels = self.dataset.labels
        printlog('using RandomIdentityBatchSampler, initializing class map...')
        # identity label -> indices of the samples carrying that label
        self.index_dic = collections.defaultdict(list)
        for i, l in enumerate(labels):
            self.index_dic[l].append(i)
        self.pids = np.array(list(self.index_dic.keys()))
        self.rank = rank if rank is not None else dist.get_rank()

    def __iter__(self):
        np.random.seed(self.random_seed)
        self._seed = int(self.random_seed)
        final_idxs_batches = self.sample_list()
        final_idxs = self.__fetch_current_node_idxs(final_idxs_batches)
        # Skip what was consumed before `last_iter` so iteration agrees with
        # __len__ when resuming.
        return iter(final_idxs[(self.last_iter + 1) * self.batch_size:])

    def _blocks_per_rank(self, epoch_len):
        """Number of ``batch_size`` blocks each rank is assigned from a pass of ``epoch_len`` indices."""
        length = int(math.ceil(epoch_len * 1.0 / self.world_size))
        return length // self.batch_size

    def __fetch_current_node_idxs(self, final_idxs_batches):
        """Concatenate this rank's share of every pass, truncated to ``total_size``."""
        res = []
        for final_idxs in final_idxs_batches:
            total_num = len(final_idxs)
            block_num = self._blocks_per_rank(total_num)
            index_target = []
            for i in range(0, block_num * self.world_size, self.world_size):
                start = self.batch_size * (self.rank + i)
                index_target.extend(range(start, min(start + self.batch_size, total_num)))
            # Integer dtype keeps an empty selection a valid (empty) fancy index.
            index_target_npy = np.array(index_target, dtype=np.int64)
            res.extend(list(np.array(final_idxs)[index_target_npy]))
        return res[:self.total_size]

    def _check_can_sample(self):
        """Raise ValueError when no identity-balanced batch can ever be formed
        (otherwise sample_list would never terminate)."""
        if self.num_pids_per_batch < 1:
            raise ValueError(
                'batch_size ({}) must be >= imageNumPerClass ({})'.format(
                    self.batch_size, self.num_instances))
        if len(self.pids) < self.num_pids_per_batch:
            raise ValueError(
                'need at least {} identities per batch, but the dataset has only {}'.format(
                    self.num_pids_per_batch, len(self.pids)))

    def batch_sample_list(self):
        """One identity-balanced pass: a flat index list built from
        ``num_pids_per_batch * num_instances``-sized batches, stopping once
        fewer than ``num_pids_per_batch`` identities are still available."""
        avai_pids = copy.deepcopy(self.pids.tolist())
        batch_idxs_dict = {}
        batch_indices = []
        while len(avai_pids) >= self.num_pids_per_batch:
            selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist()
            for pid in selected_pids:
                # (Re)fill this identity's shuffled pool; identities with fewer
                # than K samples are oversampled with replacement.
                if pid not in batch_idxs_dict or len(batch_idxs_dict[pid]) < self.num_instances:
                    idxs = copy.deepcopy(self.index_dic[pid])
                    if len(idxs) < self.num_instances:
                        idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist()
                    np.random.shuffle(idxs)
                    batch_idxs_dict[pid] = idxs
                avai_idxs = batch_idxs_dict[pid]
                for _ in range(self.num_instances):
                    batch_indices.append(avai_idxs.pop(0))
                # Retire the identity once it cannot fill another K-sized slot.
                if len(avai_idxs) < self.num_instances:
                    avai_pids.remove(pid)
        return batch_indices

    def sample_list(self):
        """Draw passes until every rank is assigned at least ``total_size`` indices.

        Each pass is drawn once and kept (previously a second, discarded pass
        was drawn just for counting, and the count ignored per-rank rounding).

        Returns:
            list of passes, each a flat index list from ``batch_sample_list``.
        Raises:
            ValueError: if no batch can be formed, or a pass is too short to
                give every rank a full batch.
        """
        self._check_can_sample()
        all_indices_batch = list()
        assigned = 0
        while assigned < self.total_size:
            epoch = self.batch_sample_list()
            block_num = self._blocks_per_rank(len(epoch))
            if block_num == 0:
                raise ValueError(
                    'one identity-balanced pass has only {} indices, fewer than one batch of {} '
                    'per rank for {} ranks'.format(len(epoch), self.batch_size, self.world_size))
            all_indices_batch.append(epoch)
            assigned += block_num * self.batch_size
        return all_indices_batch

    def __len__(self):
        return self.total_size - (self.last_iter + 1) * self.batch_size
class RandomIdentityEpochBatchSampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Every ``__iter__`` draws exactly one identity-balanced pass (the global
    ``np.random`` state is re-seeded with ``random_seed``) and rank ``r``
    takes batches ``r, r + world_size, ...`` of it. Call ``set_epoch`` between
    passes to change the shuffle.

    NOTE(review): before the first ``__iter__``, ``__len__`` returns an
    estimate computed in ``__init__``; afterwards it is the exact length of
    the last pass served to this rank.

    Args:
    - dataset: object exposing ``labels``; ``labels[i]`` is the identity of sample ``i``.
    - total_iter (int): stored as ``total_iter``/``total_size``; not used for sampling here.
    - batch_size (int): number of examples in a batch.
    - world_size (int, optional): defaults to ``dist.get_world_size()``.
    - rank (int, optional): defaults to ``dist.get_rank()``.
    - random_seed (int): base seed; ``set_epoch(e)`` uses ``random_seed + e``.
    - imageNumPerClass (int): number of instances per identity in a batch.
    - last_iter, shuffle_strategy, ret_save_path: stored/accepted for interface
      compatibility; unused for sampling.
    """

    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1,
                 shuffle_strategy=0, random_seed=0, imageNumPerClass=4, ret_save_path=None):
        self.batch_size = batch_size
        self.world_size = world_size if world_size is not None else dist.get_world_size()
        self.num_instances = imageNumPerClass
        # N identities per batch, K = num_instances samples each.
        self.num_pids_per_batch = self.batch_size // self.num_instances
        self.random_seed = random_seed
        # set_epoch offsets from this base seed rather than from the previous epoch's seed.
        self._base_seed = random_seed
        self.total_iter = total_iter
        self.total_size = self.total_iter * self.batch_size
        self.last_iter = last_iter
        self.dataset = dataset
        labels = self.dataset.labels
        printlog('using RandomIdentityEpochBatchSampler, initializing class map...')
        # identity label -> indices of the samples carrying that label
        self.index_dic = collections.defaultdict(list)
        for i, l in enumerate(labels):
            self.index_dic[l].append(i)
        self.pids = np.array(list(self.index_dic.keys()))
        self.rank = rank if rank is not None else dist.get_rank()
        # Estimated per-rank length: each identity contributes a multiple of K
        # samples (at least K, because small identities are oversampled).
        self.length = 0
        for pid in self.pids:
            num = max(len(self.index_dic[pid]), self.num_instances)
            self.length += num - num % self.num_instances
        self.length //= self.world_size

    def __iter__(self):
        np.random.seed(self.random_seed)
        self._seed = int(self.random_seed)
        final_idxs_batches = self.sample_list()
        length = int(math.ceil(len(final_idxs_batches) * 1.0 / self.world_size))
        final_idxs = self.__fetch_current_node_idxs(final_idxs_batches, length)
        self.length = len(final_idxs)
        return iter(final_idxs)

    def set_epoch(self, epoch):
        """Use seed ``base_seed + epoch`` for the next pass (idempotent per epoch)."""
        base = getattr(self, '_base_seed', None)
        if base is None:
            base = self._base_seed = self.random_seed
        self.random_seed = base + epoch

    def __fetch_current_node_idxs(self, final_idxs, length):
        """Return this rank's share: batches rank, rank + world_size, ... of ``final_idxs``."""
        total_num = len(final_idxs)
        block_num = length // self.batch_size
        index_target = []
        for i in range(0, block_num * self.world_size, self.world_size):
            start = self.batch_size * (self.rank + i)
            index_target.extend(range(start, min(start + self.batch_size, total_num)))
        # Integer dtype keeps an empty selection a valid (empty) fancy index.
        index_target_npy = np.array(index_target, dtype=np.int64)
        return list(np.array(final_idxs)[index_target_npy])

    def _check_can_sample(self):
        """Raise ValueError when no identity-balanced batch can ever be formed
        (with ``num_pids_per_batch == 0`` sample_list would never terminate)."""
        if self.num_pids_per_batch < 1:
            raise ValueError(
                'batch_size ({}) must be >= imageNumPerClass ({})'.format(
                    self.batch_size, self.num_instances))
        if len(self.pids) < self.num_pids_per_batch:
            raise ValueError(
                'need at least {} identities per batch, but the dataset has only {}'.format(
                    self.num_pids_per_batch, len(self.pids)))

    def sample_list(self):
        """One identity-balanced pass: a flat index list built from
        ``num_pids_per_batch * num_instances``-sized batches, stopping once
        fewer than ``num_pids_per_batch`` identities are still available."""
        self._check_can_sample()
        avai_pids = copy.deepcopy(self.pids.tolist())
        batch_idxs_dict = {}
        batch_indices = []
        while len(avai_pids) >= self.num_pids_per_batch:
            selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist()
            for pid in selected_pids:
                # (Re)fill this identity's shuffled pool; identities with fewer
                # than K samples are oversampled with replacement.
                if pid not in batch_idxs_dict or len(batch_idxs_dict[pid]) < self.num_instances:
                    idxs = copy.deepcopy(self.index_dic[pid])
                    if len(idxs) < self.num_instances:
                        idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist()
                    np.random.shuffle(idxs)
                    batch_idxs_dict[pid] = idxs
                avai_idxs = batch_idxs_dict[pid]
                for _ in range(self.num_instances):
                    batch_indices.append(avai_idxs.pop(0))
                # Retire the identity once it cannot fill another K-sized slot.
                if len(avai_idxs) < self.num_instances:
                    avai_pids.remove(pid)
        return batch_indices

    def __len__(self):
        return self.length
class DistributedGivenSizeSampler(Sampler):
    """Shuffled distributed sampler that serves a fixed number of samples per rank.

    The dataset is randomly permuted and cyclically repeated/truncated to
    ``total_size = num_samples * world_size`` indices, where
    ``num_samples = ceil(given_size / world_size)`` (``given_size`` defaults
    to ``len(dataset)``).

    * ``dup_shuffle=False``: the permutation is seeded with the current epoch
      (see ``set_epoch``) and rank ``r`` takes the ``r``-th chunk of
      ``num_samples`` indices.
    * ``dup_shuffle=True``: one persistent ``torch.Generator`` produces a
      permutation, and successive ``__iter__`` calls walk through its
      ``world_size`` chunks in order (independently of ``rank``) before a new
      permutation is drawn.

    Indices are yielded as the 0-d tensors produced by ``torch.randperm``.
    """

    def __init__(self, dataset, given_size=None, dup_shuffle=False, world_size=None, rank=None):
        if world_size is None:
            world_size = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        assert rank < world_size
        self.dataset = dataset
        self.dup_shuffle = dup_shuffle
        self.world_size = world_size
        self.rank = rank
        self.epoch = 0
        if given_size is None:
            self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.world_size))
        else:
            self.num_samples = int(math.ceil(given_size * 1.0 / self.world_size))
        self.total_size = self.num_samples * self.world_size
        if self.dup_shuffle:
            self.offset = 0
            self.g = torch.Generator()
            self.indices = self.gen_new_list(self.g)

    def __iter__(self):
        if self.dup_shuffle:
            # Exhausted every chunk of the current permutation: draw a new one.
            if self.offset == self.world_size:
                self.indices = self.gen_new_list(self.g)
                self.offset = 0
            beg = self.offset * self.num_samples
            indices = self.indices[beg:beg + self.num_samples]
            self.offset += 1
        else:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = self.gen_new_list(g)
            offset = self.num_samples * self.rank
            indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples
        return iter(indices)

    def gen_new_list(self, g):
        """Return a permutation of the dataset drawn with ``g``, cyclically
        padded or truncated to ``total_size`` indices.

        Raises:
            ValueError: if ``total_size > 0`` but the dataset is empty (padding
                could otherwise never make progress).
        """
        origin_indices = list(torch.randperm(len(self.dataset), generator=g))
        if not origin_indices and self.total_size > 0:
            raise ValueError(
                'cannot draw {} samples from an empty dataset'.format(self.total_size))
        indices = origin_indices[:self.total_size]
        while len(indices) < self.total_size:
            indices += origin_indices[:self.total_size - len(indices)]
        assert len(indices) == self.total_size
        return indices

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
class DistributedGroupSampler(Sampler):
    """Distributed sampler whose ``batch_size``-sized chunks each come from a
    single group of ``dataset.flag``.

    ``dataset.flag`` is repeated/truncated to ``total_iter * batch_size``
    entries. Each group is shuffled and padded by repetition to a multiple of
    ``batch_size * num_replicas``; the resulting chunks are shuffled, and every
    rank takes its contiguous ``num_samples`` slice. Returned indices are
    folded back into ``range(len(dataset.flag))``.

    Arguments:
        dataset: Dataset used for sampling; must expose a non-negative
            integer ``flag`` array (one group id per sample).
        batch_size (int): chunk size; every chunk shares one flag value.
        num_replicas (optional): Number of processes participating in
            distributed training; falls back to ``world_size`` and then to
            ``dist.get_world_size()``.
        world_size (optional): see ``num_replicas``.
        rank (optional): Rank of the current process within num_replicas;
            defaults to ``dist.get_rank()``.
        total_iter (int): ``total_iter * batch_size`` flag entries are sampled.
        random_seed (int, optional): seed of the shuffling generator. This
            number should be identical across all processes in the
            distributed group. Default: 0.
        last_iter (int): accepted for interface compatibility; unused.
    """

    def __init__(self,
                 dataset,
                 batch_size=1,
                 num_replicas=None,
                 world_size=None,
                 rank=None,
                 total_iter=-1,
                 random_seed=0,
                 last_iter=-1):
        if num_replicas is None:
            num_replicas = world_size if world_size is not None else dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        self.dataset = dataset
        self.samples_per_gpu = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = random_seed if random_seed is not None else 0
        self.total_iter = total_iter
        self.batch_size = batch_size
        assert hasattr(self.dataset, 'flag')
        dataset_flag = self.dataset.flag
        needed = total_iter * batch_size
        # ceil(needed / len(dataset)) copies are required to cover `needed` entries.
        reps = max(0, (needed - 1) // len(self.dataset) + 1)
        self.flag = np.tile(dataset_flag, reps)[:needed]
        self.group_sizes = np.bincount(self.flag)
        self.num_samples = 0
        for size in self.group_sizes:
            # Each group is padded to whole chunks on every replica.
            self.num_samples += int(
                math.ceil(size * 1.0 / self.samples_per_gpu / self.num_replicas)
            ) * self.samples_per_gpu
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed)
        indices = []
        for i, size in enumerate(self.group_sizes):
            if size > 0:
                indice = np.where(self.flag == i)[0]
                assert len(indice) == size
                indice = indice[list(
                    torch.randperm(int(size), generator=g).numpy())].tolist()
                # Pad the group by repetition to a multiple of samples_per_gpu * num_replicas.
                extra = int(
                    math.ceil(
                        size * 1.0 / self.samples_per_gpu / self.num_replicas)
                ) * self.samples_per_gpu * self.num_replicas - len(indice)
                tmp = indice.copy()
                for _ in range(extra // size):
                    indice.extend(tmp)
                indice.extend(tmp[:extra % size])
                indices.extend(indice)
        assert len(indices) == self.total_size
        # Shuffle whole chunks so every chunk keeps a single flag value.
        indices = [
            indices[j] for i in list(
                torch.randperm(
                    len(indices) // self.samples_per_gpu, generator=g))
            for j in range(i * self.samples_per_gpu, (i + 1) *
                           self.samples_per_gpu)
        ]
        # Positions in the tiled flag map back onto dataset indices.
        indices = [i % len(self.dataset.flag) for i in indices]
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.num_samples
class BatchSampler(Sampler):
    r"""Wraps another sampler and checks that its mini-batches are flag-homogeneous.

    Indices from ``sampler`` are grouped into chunks of ``batch_size``; every
    complete chunk must share a single ``sampler.flag`` value, otherwise a
    ``RuntimeError`` is raised. Despite the name, ``__iter__`` yields the
    *flattened* indices (the checked chunks are concatenated back), while
    ``__len__`` reports the number of chunks.

    NOTE(review): ``__len__`` counts chunks but iteration yields individual
    indices; confirm this is what the caller expects.

    Args:
        sampler (Sampler or Iterable): Base sampler; must expose a ``flag``
            sequence indexable by the indices it yields.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.flag = self.sampler.flag

    def __iter__(self):
        ret = []
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                first_flag = self.flag[batch[0]]
                # A full batch must come from one group; previously this
                # dropped into an interactive IPython shell instead.
                if any(self.flag[i] != first_flag for i in batch):
                    raise RuntimeError(
                        'BatchSampler: batch {} mixes flags {}'.format(
                            list(batch), [self.flag[i] for i in batch]))
                ret.extend(batch)
                batch = []
        if len(batch) > 0 and not self.drop_last:
            ret.extend(batch)
        return iter(ret)

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
class DistributedSequentialSampler(Sampler):
    """Split a dataset into contiguous, in-order shards, one per rank.

    Rank ``r`` gets indices ``[r * k, min((r + 1) * k, len(dataset)))`` with
    ``k = ceil(len(dataset) / world_size)``. Trailing ranks may get a shorter,
    possibly empty shard, but never a negative-length one.
    """

    def __init__(self, dataset, world_size=None, rank=None):
        if world_size is None:
            world_size = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        self.dataset = dataset
        self.world_size = world_size
        self.rank = rank
        assert len(self.dataset) >= self.world_size, f'{len(self.dataset)} vs {self.world_size}'
        sub_num = int(math.ceil(len(self.dataset) * 1.0 / self.world_size))
        # Clamp so late ranks get an empty range instead of beg > end
        # (e.g. 5 samples on 4 ranks: rank 3 would otherwise start at 6).
        self.beg = min(sub_num * self.rank, len(self.dataset))
        self.end = min(self.beg + sub_num, len(self.dataset))

    def __iter__(self):
        indices = list(range(self.beg, self.end))
        return iter(indices)

    def __len__(self):
        return self.end - self.beg
def bcast_value(value):
    """Broadcast a Python scalar from rank 0 and return it on every rank.

    The value travels through a ``torch.Tensor`` (default float dtype), so
    the result is a Python float and large integers may lose precision.
    """
    v = torch.Tensor([value])
    # torch.distributed.broadcast names the source rank `src`; it has no `root` keyword.
    dist.broadcast(v, src=0)
    return v.item()
def gather_tensors(input_array):
    """Gather numpy arrays of (possibly) different shapes onto rank 0.

    Every rank must pass an array with the same number of dimensions (the
    shape vectors are gathered into same-sized buffers). Data travels as
    ``torch.Tensor`` (default float dtype), so integer inputs come back as floats.

    Args:
        input_array (np.ndarray): this rank's array.
    Returns:
        On rank 0, a list with one array per rank, each reshaped to that
        rank's shape. On other ranks, ``None``.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    is_root = rank == 0
    myshape = input_array.shape
    mycount = input_array.size
    # 1) Collect every rank's shape on rank 0 (only the destination may pass gather_list).
    shape_tensor = torch.Tensor(np.array(myshape))
    all_shape = [torch.Tensor(np.array(myshape)) for _ in range(world_size)] if is_root else None
    dist.gather(shape_tensor, gather_list=all_shape, dst=0)
    if is_root:
        all_shape = [x.numpy() for x in all_shape]
        all_count = [int(x.prod()) for x in all_shape]
        all_shape = [list(map(int, x)) for x in all_shape]
        max_count = max(all_count)
    else:
        max_count = 0
    # 2) Share the largest element count so every rank pads to the same size.
    count_tensor = torch.Tensor([max_count])
    dist.broadcast(count_tensor, src=0)
    max_count = int(count_tensor.item())
    # 3) Gather the zero-padded, flattened payloads on rank 0 and trim them.
    output_tensors = [torch.Tensor(max_count) for _ in range(world_size)] if is_root else None
    padded_input_array = np.zeros(max_count)
    padded_input_array[:mycount] = input_array.reshape(-1)
    input_tensor = torch.Tensor(padded_input_array)
    dist.gather(input_tensor, gather_list=output_tensors, dst=0)
    if not is_root:
        return None
    padded_output = [x.numpy() for x in output_tensors]
    return [x[:all_count[i]].reshape(all_shape[i]) for i, x in enumerate(padded_output)]
def allgatherv(input_tensor, flatten=False):
    """All-gather tensors whose shapes may differ between ranks.

    Every rank must pass a tensor with the same number of dimensions (shapes
    are exchanged with a fixed-size ``all_gather``).

    Args:
        input_tensor (torch.Tensor): local tensor; may be non-contiguous.
        flatten (bool): if True, return one 1-D tensor concatenating every
            rank's flattened data in rank order.
    Returns:
        A list with one tensor per rank (each reshaped to that rank's shape),
        or a single 1-D tensor when ``flatten`` is True. Results have
        ``input_tensor``'s dtype and device.
    """
    world_size = dist.get_world_size()
    # Keep the shape on the data's device so backends that only accept device
    # tensors (e.g. NCCL with CUDA inputs) can gather it.
    myshape = torch.tensor(list(input_tensor.size()), dtype=torch.long, device=input_tensor.device)
    mycount = input_tensor.numel()
    all_shape = [torch.zeros_like(myshape) for _ in range(world_size)]
    dist.all_gather(all_shape, myshape)
    all_count = [int(x.prod()) for x in all_shape]
    all_shape = [x.tolist() for x in all_shape]
    max_count = max(all_count)
    # Pad every payload to the largest element count so all_gather sees equal sizes.
    output_tensors = [torch.zeros(max_count).to(input_tensor) for _ in range(world_size)]
    padded_input = torch.zeros(max_count).to(input_tensor)
    # reshape (unlike view) also accepts non-contiguous inputs.
    padded_input[:mycount] = input_tensor.reshape(-1)
    dist.all_gather(output_tensors, padded_input)
    if flatten:
        return torch.cat([x[:all_count[i]] for i, x in enumerate(output_tensors)])
    # view(list) also handles 0-d shapes, where view(*[]) would fail.
    return [x[:all_count[i]].view(all_shape[i]) for i, x in enumerate(output_tensors)]
def simple_group_split(world_size, rank, num_groups):
    """Create ``num_groups`` equal, contiguous process groups and return this rank's.

    ``dist.new_group`` is called for every group, not only the one containing
    ``rank``. ``np.split`` requires ``world_size`` to be divisible by ``num_groups``.
    """
    chunks = np.split(np.arange(world_size), num_groups)
    members = [[int(r) for r in chunk] for chunk in chunks]
    created = [dist.new_group(ranks=ranks) for ranks in members]
    per_group = world_size // num_groups
    return created[rank // per_group]
def specific_group_split(world_size, rank, group_spec):
    """Split ranks into contiguous groups sized proportionally to ``group_spec``.

    ``group_spec`` is a list of ints whose sum must divide ``world_size``;
    group ``i`` receives ``group_spec[i] * (world_size // sum(group_spec))``
    consecutive ranks. ``dist.new_group`` is called for every group.

    Returns:
        An ``edict`` with ``task_roots`` (first rank of each group) and
        ``num_groups``, plus, for the group containing ``rank``: ``group``,
        ``task_size``, ``task_id``, ``task_sub_id`` (rank within the group)
        and ``task_root``.
    """
    assert type(group_spec) is list
    assert all(type(x) is int for x in group_spec)
    num_groups = len(group_spec)
    splits = np.sum(group_spec)
    assert world_size % splits == 0
    unit = int(world_size / splits)
    sizes = [x * unit for x in group_spec]
    created = []
    roots = []
    group_info = edict()
    start = 0
    for task_id, size in enumerate(sizes):
        members = [int(r) for r in np.arange(start, start + size)]
        created.append(dist.new_group(ranks=members))
        roots.append(start)
        if rank in members:
            group_info.group = created[-1]
            group_info.task_size = size
            group_info.task_id = task_id
            group_info.task_sub_id = rank - start
            group_info.task_root = start
        start += size
    group_info.task_roots = roots
    group_info.num_groups = num_groups
    return group_info
def vreduce(x, tensor, group=None):
    """All-reduce a copy of ``tensor`` and pass the scalar result to ``x.update``.

    ``tensor`` itself is left untouched. ``x`` is any object with an
    ``update`` method (presumably a running meter — confirm at call sites).
    When ``group`` is None, ``dist.all_reduce`` is called without a group argument.
    """
    reduced = tensor.clone()
    extra = {} if group is None else {'group': group}
    dist.all_reduce(reduced, **extra)
    x.update(reduced.item())
def vgather(x_list, x):
    """All-gather the scalar ``x`` from every rank into ``x_list`` in place.

    The local value is wrapped in a 1-element ``torch.Tensor`` on the current
    CUDA device. ``x_list`` is presumably one pre-allocated 1-element CUDA
    tensor per rank — confirm at call sites.
    """
    local = torch.Tensor([x]).cuda()
    dist.all_gather(x_list, local)
def reduce_dict(input_dict, task_size, task_rank, group=None, average=True):
    """
    Reduce the values in the dictionary from all processes of ``group`` so
    that every process has the summed (or averaged) results.

    Args:
        input_dict (dict): name -> tensor; all values must share one shape so
            they can be stacked.
        task_size (int): number of processes in ``group``; also the divisor
            when averaging.
        task_rank (int): unused; kept for interface compatibility.
        group: process group passed to ``dist.all_reduce``.
        average (bool): whether to do average or sum.
    Returns:
        A dict with the same keys as ``input_dict``. When ``task_size < 2`` or
        ``input_dict`` is empty, ``input_dict`` itself is returned unchanged
        (NOTE: all ranks must then agree, or the others block in all_reduce).
    """
    world_size = task_size
    # An empty dict has nothing to stack (torch.stack would raise).
    if world_size < 2 or not input_dict:
        return input_dict
    with torch.no_grad():
        # Sorted keys give every rank the same stacking order.
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[k] for k in names], dim=0)
        dist.all_reduce(values, group=group)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict