# Copyright (c) OpenMMLab. All rights reserved.
import json
import math
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import get_root_logger


def rearrange_activations(activations):
    n_channels = activations.shape[-1]
    activations = activations.reshape(-1, n_channels)
    return activations


def ps_inv(x1, x2):
    '''Least-squares solver given feature maps from two anchors.'''
    x1 = rearrange_activations(x1)
    x2 = rearrange_activations(x2)

    if not x1.shape[0] == x2.shape[0]:
        raise ValueError('Spatial size of compared neurons must match when '
                         'calculating pseudo inverse matrix.')

    # Get transformation matrix shape
    shape = list(x1.shape)
    shape[-1] += 1

    # Calculate pseudo inverse
    x1_ones = torch.ones(shape)
    x1_ones[:, :-1] = x1
    A_ones = torch.matmul(torch.linalg.pinv(x1_ones), x2.to(x1_ones.device)).T

    # Get weights and bias
    w = A_ones[..., :-1]
    b = A_ones[..., -1]

    return w, b


def reset_out_indices(front_depth=12, end_depth=24, out_indices=(9, 14, 19, 23)):
    block_ids = torch.tensor(list(range(front_depth)))
    block_ids = block_ids[None, None, :].float()
    end_mapping_ids = torch.nn.functional.interpolate(block_ids, end_depth)
    end_mapping_ids = end_mapping_ids.squeeze().long().tolist()

    small_out_indices = []
    for i, idx in enumerate(end_mapping_ids):
        if i in out_indices:
            small_out_indices.append(idx)

    return small_out_indices


def get_stitch_configs_general_unequal(depths):
    depths = sorted(depths)

    total_configs = []

    # anchor configurations
    total_configs.append({'comb_id': [1], })

    num_stitches = depths[0]
    for i, blk_id in enumerate(range(num_stitches)):
        total_configs.append({
            'comb_id': (0, 1),
            'stitch_cfgs': (i, (i + 1) * (depths[1] // depths[0]))
        })
    return total_configs, num_stitches


def get_stitch_configs_bidirection(depths):
    depths = sorted(depths)

    total_configs = []

    # anchor configurations
    total_configs.append({'comb_id': [0], })
    total_configs.append({'comb_id': [1], })

    num_stitches = depths[0]

    # small --> large
    sl_configs = []
    for i, blk_id in enumerate(range(num_stitches)):
        sl_configs.append({
            'comb_id': [0, 1],
            'stitch_cfgs': [
                [i, (i + 1) * (depths[1] // depths[0])]
            ],
            'stitch_layer_ids': [i]
        })

    ls_configs = []
    lsl_configs = []

    block_ids = torch.tensor(list(range(depths[0])))
    block_ids = block_ids[None, None, :].float()
    end_mapping_ids = torch.nn.functional.interpolate(block_ids, depths[1])
    end_mapping_ids = end_mapping_ids.squeeze().long().tolist()

    # large --> small
    for i in range(depths[1]):
        if depths[1] != depths[0]:
            if i % 2 == 1 and i < (depths[1] - 1):
                ls_configs.append({
                    'comb_id': [1, 0],
                    'stitch_cfgs': [[i, end_mapping_ids[i] + 1]],
                    'stitch_layer_ids': [i // (depths[1] // depths[0])]
                })
        else:
            if i < (depths[1] - 1):
                ls_configs.append({
                    'comb_id': [1, 0],
                    'stitch_cfgs': [[i, end_mapping_ids[i] + 1]],
                    'stitch_layer_ids': [i // (depths[1] // depths[0])]
                })

    # large --> small --> large
    for ls_cfg in ls_configs:
        for sl_cfg in sl_configs:
            if sl_cfg['stitch_layer_ids'][0] == depths[0] - 1:
                continue
            if sl_cfg['stitch_cfgs'][0][0] >= ls_cfg['stitch_cfgs'][0][1]:
                lsl_configs.append({
                    'comb_id': [1, 0, 1],
                    'stitch_cfgs': [ls_cfg['stitch_cfgs'][0], sl_cfg['stitch_cfgs'][0]],
                    'stitch_layer_ids': ls_cfg['stitch_layer_ids'] + sl_cfg['stitch_layer_ids']
                })

    # small --> large --> small
    sls_configs = []
    for sl_cfg in sl_configs:
        for ls_cfg in ls_configs:
            if ls_cfg['stitch_cfgs'][0][0] >= sl_cfg['stitch_cfgs'][0][1]:
                sls_configs.append({
                    'comb_id': [0, 1, 0],
                    'stitch_cfgs': [sl_cfg['stitch_cfgs'][0], ls_cfg['stitch_cfgs'][0]],
                    'stitch_layer_ids': sl_cfg['stitch_layer_ids'] + ls_cfg['stitch_layer_ids']
                })

    total_configs += sl_configs + ls_configs + lsl_configs + sls_configs

    anchor_ids = []
    sl_ids = []
    ls_ids = []
    lsl_ids = []
    sls_ids = []

    for i, cfg in enumerate(total_configs):
        comb_id = cfg['comb_id']

        if len(comb_id) == 1:
            anchor_ids.append(i)
            continue

        if len(comb_id) == 2:
            route = []
            front, end = cfg['stitch_cfgs'][0]
            route.append([0, front])
            route.append([end, depths[comb_id[-1]]])
            cfg['route'] = route
            # skip the final small --> large stitching position
            # (hard-coded: assumes a 12-block small anchor)
            if comb_id == [0, 1] and front != 11:
                sl_ids.append(i)
            elif comb_id == [1, 0]:
                ls_ids.append(i)

        if len(comb_id) == 3:
            route = []
            front_1, end_1 = cfg['stitch_cfgs'][0]
            front_2, end_2 = cfg['stitch_cfgs'][1]
            route.append([0, front_1])
            route.append([end_1, front_2])
            route.append([end_2, depths[comb_id[-1]]])
            cfg['route'] = route

            if comb_id == [1, 0, 1]:
                lsl_ids.append(i)
            elif comb_id == [0, 1, 0]:
                sls_ids.append(i)

        # -1 indexes the trailing nn.Identity() in each stitching ModuleList,
        # so the last segment of every stitched route applies no transformation.
        cfg['stitch_layer_ids'].append(-1)

    model_combos = [(0, 1), (1, 0)]
    return total_configs, model_combos, [len(sl_configs), len(ls_configs)], \
        anchor_ids, sl_ids, ls_ids, lsl_ids, sls_ids
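
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how ``ps_inv`` is
# typically used to seed a stitching transform between two anchors. The batch
# size, token count and channel widths below are assumptions for the example.
# ---------------------------------------------------------------------------
def _example_ps_inv_usage():
    # Fake activations from a "small" (384-d) and a "large" (768-d) anchor,
    # taken at the stitching position; only the channel dimension differs.
    feats_small = torch.randn(2, 197, 384)
    feats_large = torch.randn(2, 197, 768)

    # Least-squares transform mapping small features onto large features.
    w, b = ps_inv(feats_small, feats_large)
    assert w.shape == (768, 384) and b.shape == (768,)

    # These tensors match what StitchingLayer(384, 768) expects in
    # init_stitch_weights_bias(w, b).
    return w, b
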
def format_out_features(outs, with_cls_token, hw_shape):
    B, _, C = outs[0].shape
    for i in range(len(outs)):
        if with_cls_token:
            # Remove class token and reshape token for decoder head
            outs[i] = outs[i][:, 1:].reshape(B, hw_shape[0], hw_shape[1],
                                             C).permute(0, 3, 1, 2).contiguous()
        else:
            outs[i] = outs[i].reshape(B, hw_shape[0], hw_shape[1],
                                      C).permute(0, 3, 1, 2).contiguous()
    return outs


class LoRALayer():
    def __init__(
            self,
            r: int,
            lora_alpha: int,
            lora_dropout: float,
            merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights


class Linear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
            self,
            in_features: int,
            out_features: int,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.,
            # Set fan_in_fan_out to True if the layer to replace stores
            # weight like (fan_in, fan_out)
            fan_in_fan_out: bool = False,
            merge_weights: bool = True,
            **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w

        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w

        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)
            if self.r > 0:
                result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1)
                           @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)
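
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): behaviour of the LoRA
# ``Linear`` layer above. The layer sizes and rank are assumptions chosen for
# the example; the point is that switching to eval() merges the low-rank
# update into the frozen weight without changing the layer's output.
# ---------------------------------------------------------------------------
def _example_lora_linear():
    layer = Linear(16, 32, r=4, lora_alpha=8)
    # lora_B is initialised to zero, so give it non-zero values to make the
    # low-rank path visible in this demonstration.
    nn.init.normal_(layer.lora_B)

    x = torch.randn(2, 16)
    layer.train()
    y_unmerged = layer(x)   # base weight plus (x @ A^T @ B^T) * scaling
    layer.eval()            # merges lora_B @ lora_A into the frozen weight
    y_merged = layer(x)

    assert torch.allclose(y_unmerged, y_merged, atol=1e-5)
    return y_merged
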
class StitchingLayer(nn.Module):
    def __init__(self, in_features=None, out_features=None, r=0):
        super().__init__()
        self.transform = Linear(in_features, out_features, r=r)

    def init_stitch_weights_bias(self, weight, bias):
        self.transform.weight.data.copy_(weight)
        self.transform.bias.data.copy_(bias)

    def forward(self, x):
        out = self.transform(x)
        return out


class SNNet(nn.Module):
    def __init__(self, anchors=None):
        super(SNNet, self).__init__()

        self.anchors = nn.ModuleList(anchors)
        self.depths = [len(anc.blocks) for anc in self.anchors]

        total_configs, num_stitches = get_stitch_configs_general_unequal(self.depths)

        self.stitch_layers = nn.ModuleList([
            StitchingLayer(self.anchors[0].embed_dim, self.anchors[1].embed_dim)
            for _ in range(num_stitches)
        ])

        self.stitch_configs = {i: cfg for i, cfg in enumerate(total_configs)}
        self.all_cfgs = list(self.stitch_configs.keys())
        self.num_configs = len(self.all_cfgs)
        self.stitch_config_id = 0
        self.is_ranking = False

    def reset_stitch_id(self, stitch_config_id):
        self.stitch_config_id = stitch_config_id

    def initialize_stitching_weights(self, x):
        logger = get_root_logger()
        front, end = 0, 1

        with torch.no_grad():
            front_features = self.anchors[front].extract_block_features(x)
            end_features = self.anchors[end].extract_block_features(x)

        for i, blk_id in enumerate(range(self.depths[0])):
            front_id, end_id = i, (i + 1) * (self.depths[1] // self.depths[0])
            front_blk_feat = front_features[front_id]
            end_blk_feat = end_features[end_id - 1]
            w, b = ps_inv(front_blk_feat, end_blk_feat)
            self.stitch_layers[i].init_stitch_weights_bias(w, b)
            logger.info(f'Initialized Stitching Model {front} to Model {end}, Layer {i}')

    def init_weights(self):
        for anc in self.anchors:
            anc.init_weights()

    def sampling_stitch_config(self):
        self.stitch_config_id = np.random.choice(self.all_cfgs)

    def forward(self, x):
        stitch_cfg_id = self.stitch_config_id

        comb_id = self.stitch_configs[stitch_cfg_id]['comb_id']
        if len(comb_id) == 1:
            return self.anchors[comb_id[0]](x)

        cfg = self.stitch_configs[stitch_cfg_id]['stitch_cfgs']
        x = self.anchors[comb_id[0]].forward_until(x, blk_id=cfg[0])
        x = self.stitch_layers[cfg[0]](x)
        x = self.anchors[comb_id[1]].forward_from(x, blk_id=cfg[1])
        return x
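
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal mock anchor
# and a stitched forward pass through SNNet. The mock class, its block design
# and the method bodies are assumptions about the anchor interface SNNet
# relies on (``blocks``, ``embed_dim``, ``extract_block_features``,
# ``forward_until``, ``forward_from``); the real anchors are ViT backbones
# defined elsewhere in the repo.
# ---------------------------------------------------------------------------
class _MockAnchor(nn.Module):
    def __init__(self, depth, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # Stand-in "transformer blocks": plain per-token linear layers.
        self.blocks = nn.ModuleList([nn.Linear(embed_dim, embed_dim) for _ in range(depth)])

    def extract_block_features(self, x):
        feats = []
        for blk in self.blocks:
            x = blk(x)
            feats.append(x)
        return feats

    def forward_until(self, x, blk_id):
        # Run blocks 0..blk_id (inclusive).
        for blk in self.blocks[:blk_id + 1]:
            x = blk(x)
        return x

    def forward_from(self, x, blk_id):
        # Resume from block blk_id to the end.
        for blk in self.blocks[blk_id:]:
            x = blk(x)
        return x

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x


def _example_snnet_forward():
    small = _MockAnchor(depth=4, embed_dim=8)
    large = _MockAnchor(depth=8, embed_dim=16)
    net = SNNet(anchors=[small, large])

    x = torch.randn(2, 3, 8)   # [batch, tokens, channels] for the mock blocks
    net.reset_stitch_id(1)     # config 0 runs the large anchor alone; config 1
                               # leaves the small anchor after block 0
    return net(x)              # -> [2, 3, 16]
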
class SNNetv2(nn.Module):
    def __init__(self, anchors=None, include_sl=True, include_ls=True,
                 include_lsl=True, include_sls=True, lora_r=0):
        super(SNNetv2, self).__init__()

        self.anchors = nn.ModuleList(anchors)
        self.lora_r = lora_r
        self.depths = [len(anc.blocks) for anc in self.anchors]

        (total_configs, model_combos, num_stitches, anchor_ids,
         sl_ids, ls_ids, lsl_ids, sls_ids) = get_stitch_configs_bidirection(self.depths)

        self.stitch_layers = nn.ModuleList()
        self.stitching_map_id = {}

        for i, (comb, num_sth) in enumerate(zip(model_combos, num_stitches)):
            front, end = comb
            temp = nn.ModuleList([
                StitchingLayer(self.anchors[front].embed_dim, self.anchors[end].embed_dim, r=lora_r)
                for _ in range(num_sth)
            ])
            # The trailing Identity is addressed by stitch_layer_id -1, used for
            # the final segment of a route where no transformation is needed.
            temp.append(nn.Identity())
            self.stitch_layers.append(temp)

        self.stitch_configs = {i: cfg for i, cfg in enumerate(total_configs)}
        self.stitch_init_configs = {
            i: cfg for i, cfg in enumerate(total_configs) if len(cfg['comb_id']) == 2
        }

        self.all_cfgs = list(self.stitch_configs.keys())
        logger = get_root_logger()
        logger.info(str(self.all_cfgs))

        # Restrict the sampling space to the requested stitching directions.
        self.all_cfgs = anchor_ids
        if include_sl:
            self.all_cfgs += sl_ids
        if include_ls:
            self.all_cfgs += ls_ids
        if include_lsl:
            self.all_cfgs += lsl_ids
        if include_sls:
            self.all_cfgs += sls_ids

        self.num_configs = len(self.stitch_configs)
        self.stitch_config_id = 0

    def reset_stitch_id(self, stitch_config_id):
        self.stitch_config_id = stitch_config_id

    def set_ranking_mode(self, ranking_mode):
        self.is_ranking = ranking_mode

    def initialize_stitching_weights(self, x):
        logger = get_root_logger()

        anchor_features = []
        for anchor in self.anchors:
            with torch.no_grad():
                temp = anchor.extract_block_features(x)
                anchor_features.append(temp)

        for idx, cfg in self.stitch_init_configs.items():
            comb_id = cfg['comb_id']
            if len(comb_id) == 2:
                front_id, end_id = cfg['stitch_cfgs'][0]
                stitch_layer_id = cfg['stitch_layer_ids'][0]
                front_blk_feat = anchor_features[comb_id[0]][front_id]
                end_blk_feat = anchor_features[comb_id[1]][end_id - 1]
                w, b = ps_inv(front_blk_feat, end_blk_feat)
                self.stitch_layers[comb_id[0]][stitch_layer_id].init_stitch_weights_bias(w, b)
                logger.info(f'Initialized Stitching Layer {cfg}')

    def init_weights(self):
        for anc in self.anchors:
            anc.init_weights()

    def sampling_stitch_config(self):
        # NOTE: flops_grouped_cfgs and flops_sampling_probs are not defined in
        # this file; they are expected to be attached to the model externally
        # (the names suggest configurations grouped by FLOPs, each group with
        # its own sampling probability).
        flops_id = np.random.choice(len(self.flops_grouped_cfgs), p=self.flops_sampling_probs)
        stitch_config_id = np.random.choice(self.flops_grouped_cfgs[flops_id])
        return stitch_config_id

    def forward(self, x):
        if self.training:
            stitch_cfg_id = self.sampling_stitch_config()
        else:
            stitch_cfg_id = self.stitch_config_id

        comb_id = self.stitch_configs[stitch_cfg_id]['comb_id']

        # forward by a single anchor
        if len(comb_id) == 1:
            return self.anchors[comb_id[0]](x)

        # forward among anchors
        route = self.stitch_configs[stitch_cfg_id]['route']
        stitch_layer_ids = self.stitch_configs[stitch_cfg_id]['stitch_layer_ids']

        # patch embedding
        x = self.anchors[comb_id[0]].forward_patch_embed(x)

        for i, (model_id, cfg) in enumerate(zip(comb_id, route)):
            x = self.anchors[model_id].selective_forward(x, cfg[0], cfg[1])
            x = self.stitch_layers[model_id][stitch_layer_ids[i]](x)

        x = self.anchors[comb_id[-1]].forward_norm_head(x)

        return x
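
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module). SNNetv2's
# ``sampling_stitch_config`` reads ``flops_grouped_cfgs`` and
# ``flops_sampling_probs``, which this file never assigns. The helper below is
# an assumption for demonstration only: it buckets the allowed config ids into
# evenly sized groups and samples each group with equal probability, whereas
# the intended grouping is presumably by measured FLOPs.
# ---------------------------------------------------------------------------
def _example_attach_flops_groups(snnet, num_groups=4):
    cfg_ids = list(snnet.all_cfgs)
    group_size = max(1, len(cfg_ids) // num_groups)

    snnet.flops_grouped_cfgs = [cfg_ids[i:i + group_size]
                                for i in range(0, len(cfg_ids), group_size)]
    snnet.flops_sampling_probs = [1.0 / len(snnet.flops_grouped_cfgs)] * len(snnet.flops_grouped_cfgs)
    return snnet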