import numpy as np
import shutil
import torch
import os
import io
import copy
import math
import logging
from collections import defaultdict

import torch.distributed as dist
from torch.nn import BatchNorm2d
from torch.utils.checkpoint import checkpoint
import cv2
import subprocess
from PIL import Image

import core.fp16 as fp16

from typing import Optional, List
from torch import Tensor

import torch._utils

try:
    torch._utils._rebuild_tensor_v2
except AttributeError:
    # Backward-compatibility shim for checkpoints serialized with newer PyTorch versions.
    def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
        tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
        tensor.requires_grad = requires_grad
        tensor._backward_hooks = backward_hooks
        return tensor
    torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2

import torch.nn as nn

cv2.ocl.setUseOpenCL(False)


class AverageMeter(object):
    """Computes and stores the average and current value over a sliding window."""

    def __init__(self, length):
        self.length = length
        self.reset()

    def reset(self):
        self.history = []
        self.val = 0
        self.avg = 0

    def empty(self):
        return len(self.history) == 0

    def update(self, val):
        self.history.append(val)
        if self.length > 0 and len(self.history) > self.length:
            del self.history[0]
        self.val = val
        self.avg = np.mean(self.history)


class AverageMinMaxMeter(object):
    """Computes and stores the average, current, min and max value over a sliding window."""

    def __init__(self, length):
        self.length = length
        self.reset()

    def reset(self):
        self.history = []
        self.val = 0
        self.min = 10000
        self.max = 0
        self.avg = 0

    def empty(self):
        return len(self.history) == 0

    def update(self, val):
        self.history.append(val)
        if self.length > 0 and len(self.history) > self.length:
            del self.history[0]
        self.val = val
        self.avg = np.mean(self.history)
        self.min = min(self.min, val)
        self.max = max(self.max, val)


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def accuracy_multi(output, target):
    pred = (output > 0).float()
    tf = (pred == target).float()
    acc = tf.sum() / output.size(0) / output.size(1) * 100
    return acc


def save_state(state, path, step):
    path, filename = os.path.split(path)
    assert path != ''
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
    print('saving to {}/{}_iter_{}.pth.tar'.format(path, filename, step))
    try:
        torch.save(state, '{}/{}_iter_{}.pth.tar'.format(path, filename, step))
    except TypeError as e:
        print(f"Full key list: state['state_dict'].keys(): {state['state_dict'].keys()}")
        raise e


def load_last_iter(path):
    if os.path.isfile(path):
        checkpoint = torch.load(path, map_location='cpu')
        dist.barrier()
        print("=> loaded last_iter={} from {}".format(checkpoint['step'], path))
        dist.barrier()
        return checkpoint['step']
    else:
        raise RuntimeError("=> no checkpoint found at {}".format(path))


def remove_prefix_string(string, prefix):
    assert string.startswith(prefix), "cannot remove prefix."
    return string[len(prefix):]
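
# --- Usage sketch (illustrative only): AverageMeter with accuracy() ---
# A minimal example of how the meters above are typically driven inside a
# training/eval loop. The tensors here are synthetic placeholders, not part
# of the original module.
def _example_average_meter_and_accuracy():
    top1_meter = AverageMeter(length=10)           # sliding window of 10 updates
    logits = torch.randn(4, 5)                     # fake batch of 4 samples, 5 classes
    labels = torch.randint(0, 5, (4,))             # fake ground-truth labels
    top1, = accuracy(logits, labels, topk=(1,))    # precision@1 in percent
    top1_meter.update(top1.item())
    print('current: {:.2f}, running avg: {:.2f}'.format(top1_meter.val, top1_meter.avg))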
def remove_prefix_from_state_dict(state_dict, prefix):
    for old_key in list(state_dict.keys()):
        if old_key.startswith(prefix):
            new_key = remove_prefix_string(old_key, prefix)
            state_dict[new_key] = state_dict.pop(old_key)


def load_state(path, model, ignore=[], optimizer=None, cuda=False,
               recover=False, remove_prefix=None, strict=False):

    def map_func_cuda(storage, location):
        return storage.cuda()

    def map_func_cpu(storage, location):
        return storage.cpu()

    if cuda:
        map_func = map_func_cuda
    else:
        map_func = map_func_cpu

    if os.path.isfile(path):
        print("=> loading checkpoint '{}'".format(path))
        checkpoint = torch.load(path, map_location=map_func)

        if 'state_dict' in checkpoint.keys():
            pretrained_state_dict_new = checkpoint['state_dict']
        else:
            pretrained_state_dict_new = checkpoint

        # strip the '_orig_mod.' prefix inserted by torch.compile
        pretrained_state_dict = dict()
        for k in list(pretrained_state_dict_new.keys()):
            if '_orig_mod.' in k:
                k_new = k.split('_orig_mod.')[1]
                pretrained_state_dict[k_new] = pretrained_state_dict_new[k]
            else:
                pretrained_state_dict[k] = pretrained_state_dict_new[k]

        if len(ignore) > 0:
            assert optimizer is None
            for k in list(pretrained_state_dict.keys()):
                flag = False
                for prefix in ignore:
                    if k.startswith(prefix):
                        flag = True
                        the_prefix = prefix
                        break
                if flag:
                    print('ignoring {} (prefix: {})'.format(k, the_prefix))
                    del pretrained_state_dict[k]

        if remove_prefix:
            remove_prefix_from_state_dict(pretrained_state_dict, remove_prefix)

        model.load_state_dict(pretrained_state_dict, strict=strict)

        dist.barrier()
        if dist.get_rank() == 0:
            keys1 = set(pretrained_state_dict.keys())
            keys2 = set([k for k, _ in model.named_parameters()])
            not_loaded = keys2 - keys1
            for k in not_loaded:
                print('caution: {} not loaded'.format(k))
        dist.barrier()

        if optimizer is not None:
            assert len(ignore) == 0  # TODO currently a workaround for gpu memory leak
            optimizer.load_state_dict(checkpoint['optimizer'])
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
                    else:
                        state[k] = v
                        print("k: {} do not move to cuda".format(k))
            print("=> loaded checkpoint '{}' (step {})".format(path, checkpoint['step']))
            return checkpoint['step']
        if recover:
            return checkpoint['step']
    else:
        assert False, "=> no checkpoint found at '{}'".format(path)


def load_state_model(model, state, ginfo):
    if ginfo.task_rank == 0:
        printlog(f'======= loading model state for task {ginfo.task_id} ... =======')

    msg = model.load_state_dict(state, strict=False)

    state_keys = set(state.keys())
    model_keys = set(model.state_dict().keys())
    missing_keys = model_keys - state_keys

    if ginfo.task_rank == 0:
        for k in missing_keys:
            printlog(f'missing key: {k}')
        printlog(f'load msg: {msg}')
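
# --- Usage sketch (illustrative only): stripping a 'module.' prefix ---
# Checkpoints saved from a DistributedDataParallel-wrapped model often carry a
# 'module.' prefix on every key; the helpers above can strip it before calling
# model.load_state_dict(). The checkpoint path below is a hypothetical placeholder.
def _example_strip_module_prefix():
    ckpt = torch.load('/path/to/checkpoint.pth.tar', map_location='cpu')  # hypothetical path
    state_dict = ckpt.get('state_dict', ckpt)
    remove_prefix_from_state_dict(state_dict, 'module.')
    return state_dict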
def load_state_optimizer(optimizer, state, ginfo):
    if ginfo.task_rank == 0:
        printlog(f'======= loading optimizer state for task {ginfo.task_id} ... =======')
    optimizer.load_state_dict(state)


def create_logger(name, log_file, level=logging.INFO):
    l = logging.getLogger(name)
    formatter = logging.Formatter('[%(asctime)s][%(filename)20s][line:%(lineno)4d][%(levelname)8s] %(message)s')
    fh = logging.FileHandler(log_file)
    fh.setFormatter(formatter)
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    l.setLevel(level)
    l.addHandler(fh)
    l.addHandler(sh)
    return l


class IterLRScheduler(object):
    def __init__(self, optimizer, milestones, lr_mults, last_iter=-1):
        assert len(milestones) == len(lr_mults), "{} vs {}".format(len(milestones), len(lr_mults))
        self.milestones = milestones
        self.lr_mults = lr_mults
        if not isinstance(optimizer, torch.optim.Optimizer) and not isinstance(optimizer, fp16.FP16_Optimizer):
            raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__))
        self.optimizer = optimizer
        for i, group in enumerate(optimizer.param_groups):
            if 'lr' not in group:
                raise KeyError("param 'lr' is not specified "
                               "in param_groups[{}] when resuming an optimizer".format(i))
        self.last_iter = last_iter

    def _get_lr(self):
        try:
            pos = self.milestones.index(self.last_iter)
        except ValueError:
            # not at a milestone: keep the current learning rates
            return list(map(lambda group: group['lr'], self.optimizer.param_groups))
        return list(map(lambda group: group['lr'] * self.lr_mults[pos], self.optimizer.param_groups))

    def get_lr(self):
        return list(map(lambda group: group['lr'], self.optimizer.param_groups))

    def step(self, this_iter=None):
        if this_iter is None:
            this_iter = self.last_iter + 1
        self.last_iter = this_iter
        for param_group, lr in zip(self.optimizer.param_groups, self._get_lr()):
            param_group['lr'] = lr


def reset_bn(module):
    if isinstance(module, BatchNorm2d) or isinstance(module, torch.nn.SyncBatchNorm):
        module.running_mean = torch.zeros_like(module.running_mean)
        module.running_var = torch.ones_like(module.running_var)


def pil_loader(img_str):
    buff = io.BytesIO(img_str)
    with Image.open(buff) as img:
        img = img.convert('RGB')
    return img


def cv2_loader(img_str):
    img_array = np.frombuffer(img_str, dtype=np.uint8)
    return cv2.imdecode(img_array, cv2.IMREAD_COLOR)


def param_groups(model):
    bn_group = []
    fc_group = []
    feature_group = []
    normal_group = []
    bn_names = set()
    for name, m in model.named_modules():
        if isinstance(m, BatchNorm2d) or isinstance(m, torch.nn.SyncBatchNorm):
            if m.weight is not None:
                bn_group.append(m.weight)
                bn_names.add(name + '.weight')
            if m.bias is not None:
                bn_group.append(m.bias)
                bn_names.add(name + '.bias')

    for name, param in model.named_parameters():
        if name in bn_names:
            continue
        elif name.startswith('module.base.fc'):
            feature_group.append(param)
        elif name.startswith('module.logits'):
            fc_group.append(param)
        else:
            normal_group.append(param)
    return bn_group, feature_group, fc_group, normal_group


def clip_grad_value(parameters, clip_value):
    clip_value = float(clip_value)
    for p in filter(lambda p: p.grad is not None, parameters):
        p.grad.data.clamp_(min=-clip_value, max=clip_value)


def compute_grad_norm(parameters):
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm ** 2
    total_norm = total_norm ** 0.5
    return total_norm


class SIMSELoss(nn.Module):
    """Scale-invariant MSE loss: mean squared error minus the squared mean error."""

    def __init__(self):
        super(SIMSELoss, self).__init__()

    def forward(self, pred, real):
        diffs = real - pred
        n = torch.numel(diffs.data)
        mse = torch.sum(diffs.pow(2)) / n
        simse = torch.sum(diffs).pow(2) / (n ** 2)
        return mse - simse
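
# --- Usage sketch (illustrative only): step-wise LR decay with IterLRScheduler ---
# A minimal sketch of driving IterLRScheduler from a training loop; the model,
# milestones and multipliers are made-up values, not taken from any config in
# this repository.
def _example_iter_lr_scheduler():
    model = nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # multiply the lr by 0.1 when last_iter reaches 100, and again at 200
    scheduler = IterLRScheduler(optimizer, milestones=[100, 200], lr_mults=[0.1, 0.1])
    for _ in range(300):
        scheduler.step()   # advances last_iter and updates param_group['lr']
    print('final lr:', scheduler.get_lr())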
class GradRejust(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by `grad_scale` in the backward pass."""

    @staticmethod
    def forward(ctx, x, grad_scale):
        ctx.grad_scale = grad_scale
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return ctx.grad_scale * grad_output, None


def grad_rejust(x, grad_scale=1.0):
    return GradRejust.apply(x, grad_scale)


def count_parameters_num(model):
    count = 0
    count_fc = 0
    param_dict = {name: param for name, param in model.named_parameters()}
    param_keys = param_dict.keys()
    for m_name, m in model.named_modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, torch.nn.SyncBatchNorm):
            weight_name = m_name + '.weight'
            bias_name = m_name + '.bias'
            if weight_name in param_keys:
                temp_params = param_dict[weight_name]
                count += temp_params.data.nelement()
            if bias_name in param_keys:
                temp_params = param_dict[bias_name]
                count += temp_params.data.nelement()
        elif isinstance(m, nn.Linear):
            weight_name = m_name + '.weight'
            bias_name = m_name + '.bias'
            if weight_name in param_keys:
                temp_params = param_dict[weight_name]
                count_fc += temp_params.data.nelement()
            if bias_name in param_keys:
                temp_params = param_dict[bias_name]
                count_fc += temp_params.data.nelement()
    sync_print('Number of conv/bn params: %.2fM' % (count / 1e6))
    sync_print('Number of linear params: %.2fM' % (count_fc / 1e6))


def get_gpu_memory_map():
    """Get the current gpu usage.

    Returns
    -------
    usage: dict
        Keys are device ids as integers.
        Values are memory usage as integers in MB.
    """
    result = subprocess.check_output(
        [
            'nvidia-smi', '--query-gpu=memory.used',
            '--format=csv,nounits,noheader'
        ], encoding='utf-8')
    # Convert lines into a dictionary
    gpu_memory = [int(x) for x in result.strip().split('\n')]
    gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory))
    return gpu_memory_map


def param_group_no_wd(model):
    pgroup_no_wd = []
    names_no_wd = []
    pgroup_normal = []

    type2num = defaultdict(lambda: 0)
    for name, m in model.named_modules():
        if isinstance(m, torch.nn.Conv2d):
            if m.bias is not None:
                pgroup_no_wd.append(m.bias)
                names_no_wd.append(name + '.bias')
                type2num[m.__class__.__name__ + '.bias'] += 1
        elif isinstance(m, torch.nn.Linear):
            if m.bias is not None:
                pgroup_no_wd.append(m.bias)
                names_no_wd.append(name + '.bias')
                type2num[m.__class__.__name__ + '.bias'] += 1
        elif isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d) or isinstance(m, torch.nn.SyncBatchNorm):
            if m.weight is not None:
                pgroup_no_wd.append(m.weight)
                names_no_wd.append(name + '.weight')
                type2num[m.__class__.__name__ + '.weight'] += 1
            if m.bias is not None:
                pgroup_no_wd.append(m.bias)
                names_no_wd.append(name + '.bias')
                type2num[m.__class__.__name__ + '.bias'] += 1

    for name, p in model.named_parameters():
        if name not in names_no_wd:
            pgroup_normal.append(p)

    return [{'params': pgroup_normal},
            {'params': pgroup_no_wd, 'weight_decay': 0.0}], type2num


def freeze_bn(model):
    names = []
    for name, m in model.named_modules():
        if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.SyncBatchNorm):
            m.eval()
            names.append(name)
    return names
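
# --- Usage sketch (illustrative only): optimizer without weight decay on bias/BN ---
# Shows how the groups returned by param_group_no_wd() are typically fed to an
# optimizer so that biases and normalization affines skip weight decay. The
# model and hyper-parameters below are placeholders.
def _example_param_group_no_wd():
    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.Linear(16, 4))
    groups, type2num = param_group_no_wd(model)
    optimizer = torch.optim.SGD(groups, lr=0.01, momentum=0.9, weight_decay=1e-4)
    print('no-weight-decay parameter types:', dict(type2num))
    return optimizer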
def named_buffers(self, memo=None, prefix=''):
    # backport of nn.Module.named_buffers for older PyTorch versions
    if memo is None:
        memo = set()
    for name, b in self._buffers.items():
        if b is not None and b not in memo:
            memo.add(b)
            yield prefix + ('.' if prefix else '') + name, b
    for mname, module in self.named_children():
        submodule_prefix = prefix + ('.' if prefix else '') + mname
        for name, b in module.named_buffers(memo, submodule_prefix):
            yield name, b


# Parameter/buffer routing flags preserved across tensor conversions. The names
# match the attributes assigned by the add_*_specific helpers below.
_SPECIFIC_FLAGS = (
    'task_specific', 'modality_share', 'backbone_specific', 'adapter_specific',
    'neck_specific', 'decoder_specific', 'rgb_specific',
    'dense_labeling_specific', 'sparse_labeling_specific',
    'text_specific', 'video_specific',
)


def change_tensor_half():
    sync_print('override tensor.half() to preserve task_specific flag')
    # change .half of Tensor
    ori_tensor_half = torch.Tensor.half
    torch.Tensor.ori_half = ori_tensor_half

    def new_half(self, *args, **kwargs):
        half_t = self.ori_half(*args, **kwargs)
        for flag in _SPECIFIC_FLAGS:
            if hasattr(self, flag):
                print('preserving {} in .half'.format(flag))
                setattr(half_t, flag, getattr(self, flag))
        return half_t

    torch.Tensor.half = new_half
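
# --- Usage sketch (illustrative only): flag preservation through .half() ---
# Demonstrates that, once change_tensor_half() has been called, routing flags
# attached to a tensor survive the fp16 conversion. Purely illustrative.
def _example_change_tensor_half():
    change_tensor_half()
    t = torch.randn(2, 2)
    t.task_specific = True
    t_half = t.half()
    print('task_specific preserved:', getattr(t_half, 'task_specific', None))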
def change_tensor_cuda():
    sync_print('override tensor.cuda() to preserve task_specific flag')
    # change .cuda of Tensor
    ori_tensor_cuda = torch.Tensor.cuda
    torch.Tensor.ori_cuda = ori_tensor_cuda

    def new_cuda(self, *args, **kwargs):
        # could be written as a decorator, I guess...
        cuda_t = self.ori_cuda(*args, **kwargs)
        for flag in _SPECIFIC_FLAGS:
            if hasattr(self, flag):
                setattr(cuda_t, flag, getattr(self, flag))
        return cuda_t

    torch.Tensor.cuda = new_cuda


def add_task_specific(m, task_specific):
    for name, param in m.named_parameters():
        param.task_specific = task_specific
        param.backbone_specific = False
        param.neck_specific = False
        param.decoder_specific = False
        if task_specific:
            printlog('add param {} as task_specific'.format(name))

    if not hasattr(torch.nn.Module, 'named_buffers'):
        printlog('registering named_buffers for nn.Module at add_task_specific')
        torch.nn.Module.named_buffers = named_buffers

    # m.cuda()
    # necessary for broadcast in DistModule,
    # since buffers are tensors which will be changed after .cuda()
    for name, buffer in m.named_buffers():
        buffer.task_specific = task_specific
        buffer.backbone_specific = False
        buffer.neck_specific = False
        buffer.decoder_specific = False
        if task_specific:
            printlog('add buffer {} as task_specific'.format(name))


def add_backbone_specific(m, backbone_specific):
    for name, param in m.named_parameters():
        param.task_specific = False
        param.backbone_specific = backbone_specific
        param.neck_specific = False
        param.decoder_specific = False
        if backbone_specific:
            printlog('add param {} as backbone_specific'.format(name))

    if not hasattr(torch.nn.Module, 'named_buffers'):
        printlog('registering named_buffers for nn.Module at add_backbone_specific')
        torch.nn.Module.named_buffers = named_buffers

    # m.cuda()  # necessary for broadcast in DistModule, since buffers are tensors which will be changed after .cuda()
    for name, buffer in m.named_buffers():
        buffer.task_specific = False
        buffer.backbone_specific = backbone_specific
        buffer.neck_specific = False
        buffer.decoder_specific = False
        if backbone_specific:
            printlog('add buffer {} as backbone_specific'.format(name))


def add_neck_specific(m, neck_specific):
    for name, param in m.named_parameters():
        param.task_specific = False
        param.backbone_specific = False
        param.neck_specific = neck_specific
        param.decoder_specific = False
        if neck_specific:
            printlog('add param {} as neck_specific'.format(name))

    if not hasattr(torch.nn.Module, 'named_buffers'):
        printlog('registering named_buffers for nn.Module at add_neck_specific')
        torch.nn.Module.named_buffers = named_buffers

    # m.cuda()  # necessary for broadcast in DistModule, since buffers are tensors which will be changed after .cuda()
    for name, buffer in m.named_buffers():
        buffer.task_specific = False
        buffer.backbone_specific = False
        buffer.neck_specific = neck_specific
        buffer.decoder_specific = False
        if neck_specific:
            printlog('add buffer {} as neck_specific'.format(name))
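
# --- Usage sketch (illustrative only): tagging sub-modules for gradient routing ---
# Shows the intended pattern of marking a backbone as backbone_specific and a
# task head as task_specific before wrapping the model for distributed
# training. The toy modules here are placeholders.
def _example_add_specific_flags():
    backbone = nn.Conv2d(3, 8, 3)
    head = nn.Linear(8, 2)
    add_backbone_specific(backbone, backbone_specific=True)
    add_task_specific(head, task_specific=True)
    # every parameter/buffer now carries boolean routing flags, e.g.:
    print(next(head.parameters()).task_specific)  # -> True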
def add_decoder_specific(m, decoder_specific):
    for name, param in m.named_parameters():
        param.task_specific = False
        param.backbone_specific = False
        param.neck_specific = False
        param.decoder_specific = decoder_specific
        if decoder_specific:
            printlog('add param {} as decoder_specific'.format(name))

    if not hasattr(torch.nn.Module, 'named_buffers'):
        printlog('registering named_buffers for nn.Module at add_decoder_specific')
        torch.nn.Module.named_buffers = named_buffers

    # m.cuda()  # necessary for broadcast in DistModule, since buffers are tensors which will be changed after .cuda()
    for name, buffer in m.named_buffers():
        buffer.task_specific = False
        buffer.backbone_specific = False
        buffer.neck_specific = False
        buffer.decoder_specific = decoder_specific
        if decoder_specific:
            printlog('add buffer {} as decoder_specific'.format(name))


def add_aiov2_decoder_specific(m, decoder_specific, task_sp_list=(), neck_sp_list=(), modality_share_list=()):
    for name, param in m.named_parameters():
        _task_sp_flag = any(name.startswith(sp_name) or name.endswith(sp_name) for sp_name in task_sp_list)
        _neck_sp_flag = any(name.startswith(sp_name) or name.endswith(sp_name) for sp_name in neck_sp_list)
        _modality_share_flag = any(name.startswith(share_name) or name.endswith(share_name) for share_name in modality_share_list)

        param.task_specific = _task_sp_flag
        param.modality_share = False if _task_sp_flag or _neck_sp_flag else _modality_share_flag
        param.backbone_specific = False
        param.rgb_specific = False
        param.dense_labeling_specific = False
        param.text_specific = False
        param.video_specific = False
        param.sparse_labeling_specific = False
        param.decoder_specific = False if _task_sp_flag or _neck_sp_flag or _modality_share_flag else decoder_specific

        if _task_sp_flag:
            printlog('add param {} as task_specific'.format(name))
        elif _neck_sp_flag:
            printlog('add param {} as neck_specific'.format(name))
        elif _modality_share_flag:
            printlog('add param {} as modality_share'.format(name))
        elif decoder_specific:
            printlog('add param {} as decoder_specific'.format(name))

    if not hasattr(torch.nn.Module, 'named_buffers'):
        printlog('registering named_buffers for nn.Module at add_decoder_specific')
        torch.nn.Module.named_buffers = named_buffers

    # m.cuda()  # necessary for broadcast in DistModule, since buffers are tensors which will be changed after .cuda()
    for name, buffer in m.named_buffers():
        _task_sp_flag = any(name.startswith(sp_name) or name.endswith(sp_name) for sp_name in task_sp_list)
        _neck_sp_flag = any(name.startswith(sp_name) or name.endswith(sp_name) for sp_name in neck_sp_list)
        _modality_share_flag = any(name.startswith(share_name) or name.endswith(share_name) for share_name in modality_share_list)

        buffer.task_specific = _task_sp_flag
        buffer.modality_share = False if _task_sp_flag or _neck_sp_flag else _modality_share_flag
        buffer.backbone_specific = False
        buffer.rgb_specific = False
        buffer.dense_labeling_specific = False
        buffer.text_specific = False
        buffer.video_specific = False
        buffer.sparse_labeling_specific = False
        buffer.decoder_specific = False if _task_sp_flag or _neck_sp_flag or _modality_share_flag else decoder_specific

        if _task_sp_flag:
            printlog('add buffer {} as task_specific'.format(name))
        elif _neck_sp_flag:
            printlog('add buffer {} as neck_specific'.format(name))
        elif _modality_share_flag:
            printlog('add buffer {} as modality_share'.format(name))
        elif decoder_specific:
            printlog('add buffer {} as decoder_specific'.format(name))
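
# --- Usage sketch (illustrative only): name-based overrides for decoder flags ---
# Parameters whose names start or end with an entry of task_sp_list are tagged
# task_specific instead of decoder_specific. The module and name patterns below
# are placeholders.
def _example_add_aiov2_decoder_specific():
    decoder = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
    add_aiov2_decoder_specific(decoder, decoder_specific=True, task_sp_list=('1.',))
    # '0.*' parameters become decoder_specific, '1.*' parameters become task_specific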
def add_aiov2_backbone_specific(m, backbone_specific, task_sp_list=(), neck_sp_list=(), modality_share_list=()):
    for name, param in m.named_parameters():
        _task_sp_flag = any(name.startswith(sp_name) or name.endswith(sp_name) for sp_name in task_sp_list)
        _neck_sp_flag = any(name.startswith(sp_name) or name.endswith(sp_name) for sp_name in neck_sp_list)

        param.task_specific = _task_sp_flag
        param.modality_share = False
        param.backbone_specific = False if _task_sp_flag or _neck_sp_flag else backbone_specific
        param.rgb_specific = False
        param.dense_labeling_specific = False
        param.text_specific = False
        param.video_specific = False
        param.sparse_labeling_specific = False
        param.decoder_specific = False

        if _task_sp_flag:
            printlog('add param {} as task_specific'.format(name))
        elif _neck_sp_flag:
            printlog('add param {} as neck_specific'.format(name))
        elif backbone_specific:
            printlog('add param {} as backbone_specific'.format(name))

    if not hasattr(torch.nn.Module, 'named_buffers'):
        printlog('registering named_buffers for nn.Module at add_backbone_specific')
        torch.nn.Module.named_buffers = named_buffers

    # m.cuda()  # necessary for broadcast in DistModule, since buffers are tensors which will be changed after .cuda()
    for name, buffer in m.named_buffers():
        _task_sp_flag = any(name.startswith(sp_name) or name.endswith(sp_name) for sp_name in task_sp_list)
        _neck_sp_flag = any(name.startswith(sp_name) or name.endswith(sp_name) for sp_name in neck_sp_list)

        buffer.task_specific = _task_sp_flag
        buffer.modality_share = False
        buffer.backbone_specific = False if _task_sp_flag or _neck_sp_flag else backbone_specific
        buffer.rgb_specific = False
        buffer.dense_labeling_specific = False
        buffer.text_specific = False
        buffer.video_specific = False
        buffer.sparse_labeling_specific = False
        buffer.decoder_specific = False

        if _task_sp_flag:
            printlog('add buffer {} as task_specific'.format(name))
        elif _neck_sp_flag:
            printlog('add buffer {} as neck_specific'.format(name))
        elif backbone_specific:
            printlog('add buffer {} as backbone_specific'.format(name))


def add_aiov2_task_specific(m, task_specific=True):
    for name, param in m.named_parameters():
        param.task_specific = task_specific
        param.modality_share = False
        param.backbone_specific = False
        param.rgb_specific = False
        param.dense_labeling_specific = False
        param.text_specific = False
        param.video_specific = False
        param.sparse_labeling_specific = False
        param.decoder_specific = False
        printlog('add param {} as task_specific'.format(name))

    if not hasattr(torch.nn.Module, 'named_buffers'):
        printlog('registering named_buffers for nn.Module at add_task_specific')
        torch.nn.Module.named_buffers = named_buffers

    # m.cuda()  # necessary for broadcast in DistModule, since buffers are tensors which will be changed after .cuda()
    for name, buffer in m.named_buffers():
        buffer.task_specific = task_specific
        buffer.modality_share = False
        buffer.backbone_specific = False
        buffer.rgb_specific = False
        buffer.dense_labeling_specific = False
        buffer.text_specific = False
        buffer.video_specific = False
        buffer.sparse_labeling_specific = False
        buffer.decoder_specific = False
        printlog('add buffer {} as task_specific'.format(name))
def param_specific_setting_with_modality(param, modality, _task_sp_flag, _modality_share_flag, modality_specific):
    param.rgb_specific = False
    param.dense_labeling_specific = False
    param.text_specific = False
    param.video_specific = False
    param.sparse_labeling_specific = False

    if modality == 'rgb':
        param.rgb_specific = False if _task_sp_flag or _modality_share_flag else modality_specific
    elif modality == 'dense_labeling':
        param.dense_labeling_specific = False if _task_sp_flag or _modality_share_flag else modality_specific
    elif modality == 'sparse_labeling':
        param.sparse_labeling_specific = False if _task_sp_flag or _modality_share_flag else modality_specific
    elif modality == 'video':
        param.video_specific = False if _task_sp_flag or _modality_share_flag else modality_specific
    elif modality == 'text':
        param.text_specific = False if _task_sp_flag or _modality_share_flag else modality_specific
    return param


def add_aiov2_modality_specific(m, modality, modality_specific, task_sp_list=(), modality_share_list=()):
    for name, param in m.named_parameters():
        _task_sp_flag = any(name.startswith(sp_name) or name.endswith(sp_name) for sp_name in task_sp_list)
        _modality_share_flag = any(name.startswith(share_name) or name.endswith(share_name) for share_name in modality_share_list)

        param.task_specific = _task_sp_flag
        param.modality_share = False if _task_sp_flag else _modality_share_flag
        param.backbone_specific = False
        param.decoder_specific = False
        param = param_specific_setting_with_modality(param, modality, _task_sp_flag, _modality_share_flag, modality_specific)

        if _task_sp_flag:
            printlog('add param {} as task_specific'.format(name))
        elif _modality_share_flag:
            printlog('add param {} as modality_share'.format(name))
        elif modality_specific:
            printlog('add param {} as {}_specific'.format(name, modality))

    if not hasattr(torch.nn.Module, 'named_buffers'):
        printlog('registering named_buffers for nn.Module at add_adapter_specific')
        torch.nn.Module.named_buffers = named_buffers

    # m.cuda()  # necessary for broadcast in DistModule, since buffers are tensors which will be changed after .cuda()
    for name, buffer in m.named_buffers():
        _task_sp_flag = any(name.startswith(sp_name) or name.endswith(sp_name) for sp_name in task_sp_list)
        _modality_share_flag = any(name.startswith(share_name) or name.endswith(share_name) for share_name in modality_share_list)

        buffer.task_specific = _task_sp_flag
        buffer.modality_share = False if _task_sp_flag else _modality_share_flag
        buffer.backbone_specific = False
        buffer.decoder_specific = False
        buffer = param_specific_setting_with_modality(buffer, modality, _task_sp_flag, _modality_share_flag, modality_specific)

        if _task_sp_flag:
            printlog('add buffer {} as task_specific'.format(name))
        elif _modality_share_flag:
            printlog('add buffer {} as modality_share'.format(name))
        elif modality_specific:
            printlog('add buffer {} as {}_specific'.format(name, modality))


def copy_state_dict_cpu(state_dict):
    new_state = {}
    for k, v in state_dict.items():
        new_state[k] = v.cpu()
    return new_state


def copy_optim_state_dict_cpu(state_dict):
    new_state = {}
    new_state['param_groups'] = copy.deepcopy(state_dict['param_groups'])
    new_state['state'] = {}
    for k, v in state_dict['state'].items():
        new_state['state'][k] = {}
        for name, x in v.items():
            if isinstance(x, torch.Tensor):
                new_state['state'][k][name] = x.cpu()
            else:
                new_state['state'][k][name] = copy.deepcopy(x)
    return new_state


def copy_optim_state_dict_cpu_fp16(state_dict):
    new_state = {}
    new_state['optimizer_state_dict'] = copy_optim_state_dict_cpu(state_dict['optimizer_state_dict'])
    for k in state_dict.keys():
        if k != 'optimizer_state_dict':
            new_state[k] = copy.deepcopy(state_dict[k])
    return new_state


def sync_print(*args, **kwargs):
    if not dist.is_initialized():
        print(*args, **kwargs)
    else:
        rank = dist.get_rank()
        # link.barrier()
        print('sync_print: rank {}, '.format(rank) + ' '.join(str(a) for a in args), **kwargs)
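
# --- Usage sketch (illustrative only): CPU snapshot of model/optimizer state ---
# copy_state_dict_cpu()/copy_optim_state_dict_cpu() are handy for building a
# checkpoint dict without keeping a second GPU copy of the state. The model and
# optimizer below are placeholders.
def _example_cpu_state_snapshot():
    model = nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    state = {
        'step': 0,
        'state_dict': copy_state_dict_cpu(model.state_dict()),
        'optimizer': copy_optim_state_dict_cpu(optimizer.state_dict()),
    }
    return state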
def fully_checkpoint_sequential(functions, segments, input, **kwargs):
    r"""Modified version of torch.utils.checkpoint.checkpoint_sequential for memory efficiency.

    It is assumed that at least one of the inputs has requires_grad=True, so we can
    checkpoint all of the segments at ease.
    Please refer to
    https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint_sequential
    for more details.

    segments:
        -1 -> sqrt chunk checkpointing
         0 -> no checkpointing
        others -> checkpoint with the given number of segments
    """
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    # no checkpoint
    if segments == 0:
        return run_function(0, len(functions) - 1, functions)(input)

    # auto-determine the chunk size (sqrt of the number of functions)
    if segments < 0:
        segments = int(math.ceil(math.sqrt(len(functions))))

    segments = min(segments, len(functions))
    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(run_function(start, end, functions), input)
        # preserve_rng_state=preserve)
    return checkpoint(run_function(end + 1, len(functions) - 1, functions), input)
    # preserve_rng_state=preserve)


def printlog(*args, **kwargs):
    if not dist.is_initialized():
        print(*args, **kwargs)
    else:
        print(f"[rank {dist.get_rank()}]", *args, **kwargs)


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor  # noqa
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def cuda(self):
        return self.to('cuda')

    def __repr__(self):
        return str(self.tensors)


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False  # 0: content, 1: pad
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)
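
# --- Usage sketch (illustrative only): batching variable-sized images ---
# Two fake images of different spatial sizes are padded to a common shape and
# packed into a NestedTensor whose mask marks the padded region (True = pad).
def _example_nested_tensor():
    imgs = [torch.randn(3, 480, 640), torch.randn(3, 600, 512)]
    nt = nested_tensor_from_tensor_list(imgs)
    padded, mask = nt.decompose()
    print(padded.shape)  # torch.Size([2, 3, 600, 640])
    print(mask.shape)    # torch.Size([2, 600, 640])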
def get_num_layer_for_vit(var_name, config):
    if (var_name == "module.backbone_module" or var_name.endswith("prompt_embed_kv")) and config.get('lpe_lr', False):
        # for PE.
        return config.num_layers - 1
    if var_name in ("module.backbone_module", "module.backbone_module.cls_token", "module.backbone_module.mask_token"):
        return 0
    elif var_name.startswith("module.backbone_module.patch_embed"):
        return 0
    elif var_name.startswith("module.backbone_module") and not (var_name.startswith("module.backbone_module.norm") or
                                                                 var_name.startswith("module.backbone_module.ln_pre")):
        layer_id = int(var_name.split('.')[3])
        return layer_id + 1
    else:
        return config.num_layers - 1


def get_num_layer_for_vit_with_adapter(var_name, var_param_name, config):
    if (var_name.startswith("module.adapter_") and var_param_name == 'pos_embed'
            and len(var_name.split('.')) == 2) and config.get('lpe_lr', False):
        # for PE.
        return config.num_layers - 1
    elif var_name.startswith("module.adapter_"):
        return 0


def nested_tensor_from_tensor_list_fix_shape(tensor_list: List[Tensor], max=1333, short=800, idx=None):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        # TODO make it support different-sized images
        # for coco, resize to (1333, 800)
        _, _h, _w = tensor_list[0].shape
        if _w > _h:
            max_size = [3, short, max]
        else:
            max_size = [3, max, short]
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)
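
# --- Usage sketch (illustrative only): layer-wise LR decay groups for a ViT ---
# get_num_layer_for_vit() maps a parameter name to a depth index; one common way
# to use it (an assumption, not the repository's exact recipe) is to scale each
# parameter's learning rate by decay_rate ** (num_layers - 1 - depth). The
# config object, base_lr and decay_rate below are placeholders.
def _example_layerwise_lr_groups(model, config, base_lr=1e-4, decay_rate=0.75):
    groups = []
    for name, param in model.named_parameters():
        depth = get_num_layer_for_vit(name, config)
        scale = decay_rate ** (config.num_layers - 1 - depth)
        groups.append({'params': [param], 'lr': base_lr * scale})
    return groups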