|
|
|
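"""Stitchable-network (SN-Net) components: least-squares (pseudo-inverse)
initialization of stitching layers, enumeration of stitching configurations
between two anchors, LoRA-style linear layers, and the SNNet / SNNetv2 modules
that route a forward pass across the anchors."""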
import json |
|
import math |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
import numpy as np |
|
|
|
from collections import defaultdict |
|
from utils import get_root_logger |
|
import torch.nn.functional as F |
|
|
|
def rearrange_activations(activations): |
|
n_channels = activations.shape[-1] |
|
activations = activations.reshape(-1, n_channels) |
|
return activations |
|
|
|
|
|
def ps_inv(x1, x2): |
|
    '''Least-squares solver given feature maps from two anchors.

    Solves for the weight `w` and bias `b` of the linear map that best carries
    activations `x1` (from the front anchor) onto activations `x2` (from the
    end anchor), using the Moore-Penrose pseudo-inverse.
    '''
|
x1 = rearrange_activations(x1) |
|
x2 = rearrange_activations(x2) |
|
|
|
    if x1.shape[0] != x2.shape[0]:
        raise ValueError('Spatial size of compared neurons must match when '
                         'calculating pseudo inverse matrix.')
|
|
|
|
|
shape = list(x1.shape) |
|
shape[-1] += 1 |
|
|
|
|
|
    # Append a column of ones so the bias is solved jointly with the weight.
    x1_ones = torch.ones(shape, device=x1.device, dtype=x1.dtype)
    x1_ones[:, :-1] = x1
    A_ones = torch.matmul(torch.linalg.pinv(x1_ones), x2.to(x1_ones.device)).T
|
|
|
|
|
w = A_ones[..., :-1] |
|
b = A_ones[..., -1] |
|
|
|
return w, b |
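
# A minimal sketch of how `ps_inv` is used in this module (the shapes below are
# hypothetical, e.g. ViT-S and ViT-B token activations):
#
#     x1 = torch.randn(8, 197, 384)   # front-anchor block output
#     x2 = torch.randn(8, 197, 768)   # end-anchor block output
#     w, b = ps_inv(x1, x2)           # w: (768, 384), b: (768,)
#
# `w` and `b` are then copied into a StitchingLayer via init_stitch_weights_bias.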
|
|
|
|
|
def reset_out_indices(front_depth=12, end_depth=24, out_indices=(9, 14, 19, 23)): |
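    """Map `out_indices` defined on a deeper anchor (`end_depth` blocks) onto a
    shallower anchor (`front_depth` blocks) via nearest-neighbour alignment of
    block indices. With the defaults, indices (9, 14, 19, 23) on a 24-block
    anchor map to [4, 7, 9, 11] on a 12-block anchor.
    """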
|
block_ids = torch.tensor(list(range(front_depth))) |
|
block_ids = block_ids[None, None, :].float() |
|
end_mapping_ids = torch.nn.functional.interpolate(block_ids, end_depth) |
|
end_mapping_ids = end_mapping_ids.squeeze().long().tolist() |
|
|
|
small_out_indices = [] |
|
for i, idx in enumerate(end_mapping_ids): |
|
if i in out_indices: |
|
small_out_indices.append(idx) |
|
|
|
return small_out_indices |
|
|
|
|
|
def get_stitch_configs_general_unequal(depths): |
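    """Enumerate stitching configurations from the shallower anchor to the
    deeper one (one stitch per shallow block), plus the second (deeper) anchor
    on its own. Returns the configurations and the number of stitching layers,
    which equals the shallow anchor's depth.
    """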
|
depths = sorted(depths) |
|
|
|
total_configs = [] |
|
|
|
|
|
total_configs.append({'comb_id': [1], }) |
|
num_stitches = depths[0] |
|
    for i in range(num_stitches):
|
total_configs.append({ |
|
'comb_id': (0, 1), |
|
'stitch_cfgs': (i, (i + 1) * (depths[1] // depths[0])) |
|
}) |
|
return total_configs, num_stitches |
|
|
|
def get_stitch_configs_bidirection(depths): |
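    """Enumerate bidirectional stitching configurations between two anchors.

    Produces the two anchors on their own, single-stitch small-to-large (sl)
    and large-to-small (ls) routes, and three-segment large-small-large (lsl)
    and small-large-small (sls) routes, together with index lists that group
    the configurations by type.
    """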
|
depths = sorted(depths) |
|
|
|
total_configs = [] |
|
|
|
|
|
total_configs.append({'comb_id': [0], }) |
|
total_configs.append({'comb_id': [1], }) |
|
|
|
num_stitches = depths[0] |
|
|
|
|
|
    # small-to-large (sl): run the shallow anchor up to block i, stitch, then
    # continue in the deep anchor from the aligned block onwards.
    sl_configs = []
    for i in range(num_stitches):
        sl_configs.append({
            'comb_id': [0, 1],
            'stitch_cfgs': [
                [i, (i + 1) * (depths[1] // depths[0])]
            ],
            'stitch_layer_ids': [i]
        })
|
|
|
    ls_configs = []
    lsl_configs = []
    block_ids = torch.tensor(list(range(depths[0])))
    block_ids = block_ids[None, None, :].float()
    end_mapping_ids = torch.nn.functional.interpolate(block_ids, depths[1])
    end_mapping_ids = end_mapping_ids.squeeze().long().tolist()

    # large-to-small (ls): run the deep anchor up to block i, stitch, then
    # continue in the shallow anchor from the aligned block onwards. The last
    # deep block is never an exit point, and when the depths differ only every
    # second deep block is used.
    for i in range(depths[1]):
        if i >= depths[1] - 1:
            continue
        if depths[1] != depths[0] and i % 2 == 0:
            continue
        ls_configs.append({
            'comb_id': [1, 0],
            'stitch_cfgs': [[i, end_mapping_ids[i] + 1]],
            'stitch_layer_ids': [i // (depths[1] // depths[0])]
        })
|
|
|
|
|
|
|
    # large-small-large (lsl): an ls route followed by an sl route, provided the
    # second stitch does not start before the first one ends.
    for ls_cfg in ls_configs:
        for sl_cfg in sl_configs:
            if sl_cfg['stitch_layer_ids'][0] == depths[0] - 1:
                continue
            if sl_cfg['stitch_cfgs'][0][0] >= ls_cfg['stitch_cfgs'][0][1]:
                lsl_configs.append({
                    'comb_id': [1, 0, 1],
                    'stitch_cfgs': [ls_cfg['stitch_cfgs'][0], sl_cfg['stitch_cfgs'][0]],
                    'stitch_layer_ids': ls_cfg['stitch_layer_ids'] + sl_cfg['stitch_layer_ids']
                })
|
|
|
|
|
    # small-large-small (sls): an sl route followed by an ls route, under the
    # same non-overlap constraint.
    sls_configs = []
    for sl_cfg in sl_configs:
        for ls_cfg in ls_configs:
            if ls_cfg['stitch_cfgs'][0][0] >= sl_cfg['stitch_cfgs'][0][1]:
                sls_configs.append({
                    'comb_id': [0, 1, 0],
                    'stitch_cfgs': [sl_cfg['stitch_cfgs'][0], ls_cfg['stitch_cfgs'][0]],
                    'stitch_layer_ids': sl_cfg['stitch_layer_ids'] + ls_cfg['stitch_layer_ids']
                })

    total_configs += sl_configs + ls_configs + lsl_configs + sls_configs
|
|
|
anchor_ids = [] |
|
sl_ids = [] |
|
ls_ids = [] |
|
lsl_ids = [] |
|
sls_ids = [] |
|
|
|
for i, cfg in enumerate(total_configs): |
|
comb_id = cfg['comb_id'] |
|
|
|
if len(comb_id) == 1: |
|
anchor_ids.append(i) |
|
continue |
|
|
|
if len(comb_id) == 2: |
|
route = [] |
|
front, end = cfg['stitch_cfgs'][0] |
|
route.append([0, front]) |
|
route.append([end, depths[comb_id[-1]]]) |
|
cfg['route'] = route |
|
|
|
            # skip the degenerate stitch at the last shallow block
            if comb_id == [0, 1] and front != depths[0] - 1:
|
sl_ids.append(i) |
|
elif comb_id == [1, 0]: |
|
ls_ids.append(i) |
|
|
|
if len(comb_id) == 3: |
|
route = [] |
|
front_1, end_1 = cfg['stitch_cfgs'][0] |
|
front_2, end_2 = cfg['stitch_cfgs'][1] |
|
route.append([0, front_1]) |
|
route.append([end_1, front_2]) |
|
route.append([end_2, depths[comb_id[-1]]]) |
|
cfg['route'] = route |
|
|
|
if comb_id == [1, 0, 1]: |
|
lsl_ids.append(i) |
|
elif comb_id == [0, 1, 0]: |
|
sls_ids.append(i) |
|
|
|
        # -1 indexes the nn.Identity appended to every stitch-layer list, so the
        # final segment of the route is passed through unchanged.
        cfg['stitch_layer_ids'].append(-1)
|
|
|
model_combos = [(0, 1), (1, 0)] |
|
return total_configs, model_combos, [len(sl_configs), len(ls_configs)], anchor_ids, sl_ids, ls_ids, lsl_ids, sls_ids |
|
|
|
|
|
def format_out_features(outs, with_cls_token, hw_shape): |
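    """Reshape each token sequence in `outs` from (B, N, C) into a feature map
    of shape (B, C, H, W) given `hw_shape`, dropping the class token if present.
    """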
|
B, _, C = outs[0].shape |
|
for i in range(len(outs)): |
|
if with_cls_token: |
|
|
|
outs[i] = outs[i][:, 1:].reshape(B, hw_shape[0], hw_shape[1], |
|
C).permute(0, 3, 1, 2).contiguous() |
|
else: |
|
outs[i] = outs[i].reshape(B, hw_shape[0], hw_shape[1], |
|
C).permute(0, 3, 1, 2).contiguous() |
|
return outs |
|
|
|
|
|
class LoRALayer:
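    """Base mixin for LoRA (low-rank adaptation) modules.

    Stores the shared hyper-parameters: rank `r`, scaling factor `lora_alpha`,
    dropout on the low-rank branch, and whether the low-rank update should be
    merged into the frozen weight at inference time.
    """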
|
def __init__( |
|
self, |
|
r: int, |
|
lora_alpha: int, |
|
lora_dropout: float, |
|
merge_weights: bool, |
|
): |
|
self.r = r |
|
self.lora_alpha = lora_alpha |
|
|
|
if lora_dropout > 0.: |
|
self.lora_dropout = nn.Dropout(p=lora_dropout) |
|
else: |
|
self.lora_dropout = lambda x: x |
|
|
|
self.merged = False |
|
self.merge_weights = merge_weights |
|
|
|
class Linear(nn.Linear, LoRALayer): |
|
|
|
def __init__( |
|
self, |
|
in_features: int, |
|
out_features: int, |
|
r: int = 0, |
|
lora_alpha: int = 1, |
|
lora_dropout: float = 0., |
|
fan_in_fan_out: bool = False, |
|
merge_weights: bool = True, |
|
**kwargs |
|
): |
|
nn.Linear.__init__(self, in_features, out_features, **kwargs) |
|
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, |
|
merge_weights=merge_weights) |
|
|
|
self.fan_in_fan_out = fan_in_fan_out |
|
|
|
if r > 0: |
|
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) |
|
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) |
|
self.scaling = self.lora_alpha / self.r |
|
|
|
self.weight.requires_grad = False |
|
self.reset_parameters() |
|
if fan_in_fan_out: |
|
self.weight.data = self.weight.data.transpose(0, 1) |
|
|
|
def reset_parameters(self): |
|
nn.Linear.reset_parameters(self) |
|
if hasattr(self, 'lora_A'): |
|
|
|
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) |
|
nn.init.zeros_(self.lora_B) |
|
|
|
def train(self, mode: bool = True): |
|
def T(w): |
|
return w.transpose(0, 1) if self.fan_in_fan_out else w |
|
nn.Linear.train(self, mode) |
|
if mode: |
|
if self.merge_weights and self.merged: |
|
|
|
if self.r > 0: |
|
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling |
|
self.merged = False |
|
else: |
|
if self.merge_weights and not self.merged: |
|
|
|
if self.r > 0: |
|
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling |
|
self.merged = True |
|
|
|
def forward(self, x: torch.Tensor): |
|
def T(w): |
|
return w.transpose(0, 1) if self.fan_in_fan_out else w |
|
        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)
            # Low-rank update: (lora_alpha / r) * dropout(x) @ A^T @ B^T
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)
|
|
|
|
|
class StitchingLayer(nn.Module): |
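    """A stitching layer: a linear map (optionally LoRA-parameterized when
    r > 0) that translates activations from one anchor's embedding space into
    the other's.
    """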
|
def __init__(self, in_features=None, out_features=None, r=0): |
|
super().__init__() |
|
self.transform = Linear(in_features, out_features, r=r) |
|
|
|
def init_stitch_weights_bias(self, weight, bias): |
|
self.transform.weight.data.copy_(weight) |
|
self.transform.bias.data.copy_(bias) |
|
|
|
def forward(self, x): |
|
out = self.transform(x) |
|
return out |
|
|
|
|
|
class SNNet(nn.Module): |
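    """Stitchable network (SN-Net) over two anchors of possibly unequal depth.

    Supports the deeper anchor on its own plus small-to-large routes: the
    shallow anchor runs up to a stitch point, a StitchingLayer maps its
    features into the deep anchor's space, and the deep anchor finishes the
    forward pass.
    """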
|
|
|
def __init__(self, anchors=None): |
|
super(SNNet, self).__init__() |
|
self.anchors = nn.ModuleList(anchors) |
|
|
|
self.depths = [len(anc.blocks) for anc in self.anchors] |
|
|
|
total_configs, num_stitches = get_stitch_configs_general_unequal(self.depths) |
|
self.stitch_layers = nn.ModuleList( |
|
[StitchingLayer(self.anchors[0].embed_dim, self.anchors[1].embed_dim) for _ in range(num_stitches)]) |
|
|
|
self.stitch_configs = {i: cfg for i, cfg in enumerate(total_configs)} |
|
self.all_cfgs = list(self.stitch_configs.keys()) |
|
self.num_configs = len(self.all_cfgs) |
|
self.stitch_config_id = 0 |
|
self.is_ranking = False |
|
|
|
def reset_stitch_id(self, stitch_config_id): |
|
self.stitch_config_id = stitch_config_id |
|
|
|
def initialize_stitching_weights(self, x): |
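        """Initialize every stitching layer with the least-squares solution
        computed from paired block activations of the two anchors on batch x.
        """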
|
logger = get_root_logger() |
|
front, end = 0, 1 |
|
with torch.no_grad(): |
|
front_features = self.anchors[front].extract_block_features(x) |
|
end_features = self.anchors[end].extract_block_features(x) |
|
|
|
for i, blk_id in enumerate(range(self.depths[0])): |
|
front_id, end_id = i, (i + 1) * (self.depths[1] // self.depths[0]) |
|
front_blk_feat = front_features[front_id] |
|
end_blk_feat = end_features[end_id - 1] |
|
w, b = ps_inv(front_blk_feat, end_blk_feat) |
|
self.stitch_layers[i].init_stitch_weights_bias(w, b) |
|
logger.info(f'Initialized Stitching Model {front} to Model {end}, Layer {i}') |
|
|
|
def init_weights(self): |
|
for anc in self.anchors: |
|
anc.init_weights() |
|
|
|
def sampling_stitch_config(self): |
|
self.stitch_config_id = np.random.choice(self.all_cfgs) |
|
|
|
def forward(self, x): |
|
|
|
stitch_cfg_id = self.stitch_config_id |
|
comb_id = self.stitch_configs[stitch_cfg_id]['comb_id'] |
|
|
|
if len(comb_id) == 1: |
|
return self.anchors[comb_id[0]](x) |
|
|
|
cfg = self.stitch_configs[stitch_cfg_id]['stitch_cfgs'] |
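        # Run the front anchor up to the stitch point, translate the features
        # with the matching stitching layer, then finish in the end anchor.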
|
|
|
x = self.anchors[comb_id[0]].forward_until(x, blk_id=cfg[0]) |
|
x = self.stitch_layers[cfg[0]](x) |
|
x = self.anchors[comb_id[1]].forward_from(x, blk_id=cfg[1]) |
|
|
|
return x |
|
|
|
|
|
class SNNetv2(nn.Module): |
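    """Bidirectional stitchable network (SN-Netv2) over two anchors.

    Supports anchors-only, sl, ls, lsl and sls routes (selectable through the
    include_* flags) and optional LoRA-parameterized stitching layers (lora_r).
    """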
|
|
|
def __init__(self, anchors=None, include_sl=True, include_ls=True, include_lsl=True, include_sls=True, lora_r=0): |
|
super(SNNetv2, self).__init__() |
|
self.anchors = nn.ModuleList(anchors) |
|
|
|
self.lora_r = lora_r |
|
|
|
self.depths = [len(anc.blocks) for anc in self.anchors] |
|
|
|
total_configs, model_combos, num_stitches, anchor_ids, sl_ids, ls_ids, lsl_ids, sls_ids = get_stitch_configs_bidirection(self.depths) |
|
|
|
self.stitch_layers = nn.ModuleList() |
|
self.stitching_map_id = {} |
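        # One ModuleList of stitching layers per stitching direction
        # (small-to-large and large-to-small); the trailing nn.Identity is
        # addressed with index -1 by route segments that need no transformation.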
|
|
|
for i, (comb, num_sth) in enumerate(zip(model_combos, num_stitches)): |
|
front, end = comb |
|
temp = nn.ModuleList( |
|
[StitchingLayer(self.anchors[front].embed_dim, self.anchors[end].embed_dim, r=lora_r) for _ in range(num_sth)]) |
|
temp.append(nn.Identity()) |
|
self.stitch_layers.append(temp) |
|
|
|
self.stitch_configs = {i: cfg for i, cfg in enumerate(total_configs)} |
|
self.stitch_init_configs = {i: cfg for i, cfg in enumerate(total_configs) if len(cfg['comb_id']) == 2} |
|
|
|
|
|
self.all_cfgs = list(self.stitch_configs.keys()) |
|
logger = get_root_logger() |
|
logger.info(str(self.all_cfgs)) |
|
|
|
|
|
        self.all_cfgs = list(anchor_ids)  # copy so the += below does not mutate the returned list
|
|
|
if include_sl: |
|
self.all_cfgs += sl_ids |
|
|
|
if include_ls: |
|
self.all_cfgs += ls_ids |
|
|
|
if include_lsl: |
|
self.all_cfgs += lsl_ids |
|
|
|
if include_sls: |
|
self.all_cfgs += sls_ids |
|
|
|
        self.num_configs = len(self.stitch_configs)
        self.stitch_config_id = 0
        self.is_ranking = False
|
|
|
def reset_stitch_id(self, stitch_config_id): |
|
self.stitch_config_id = stitch_config_id |
|
|
|
def set_ranking_mode(self, ranking_mode): |
|
self.is_ranking = ranking_mode |
|
|
|
def initialize_stitching_weights(self, x): |
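        """Initialize every two-anchor (sl / ls) stitching layer with the
        least-squares solution computed from paired block activations on batch x.
        """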
|
logger = get_root_logger() |
|
anchor_features = [] |
|
for anchor in self.anchors: |
|
with torch.no_grad(): |
|
temp = anchor.extract_block_features(x) |
|
anchor_features.append(temp) |
|
|
|
for idx, cfg in self.stitch_init_configs.items(): |
|
comb_id = cfg['comb_id'] |
|
if len(comb_id) == 2: |
|
front_id, end_id = cfg['stitch_cfgs'][0] |
|
stitch_layer_id = cfg['stitch_layer_ids'][0] |
|
front_blk_feat = anchor_features[comb_id[0]][front_id] |
|
end_blk_feat = anchor_features[comb_id[1]][end_id - 1] |
|
w, b = ps_inv(front_blk_feat, end_blk_feat) |
|
self.stitch_layers[comb_id[0]][stitch_layer_id].init_stitch_weights_bias(w, b) |
|
logger.info(f'Initialized Stitching Layer {cfg}') |
|
|
|
def init_weights(self): |
|
for anc in self.anchors: |
|
anc.init_weights() |
|
|
|
|
|
def sampling_stitch_config(self): |
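        # FLOPs-aware sampling: `self.flops_grouped_cfgs` (presumably
        # configuration ids grouped by FLOPs) and `self.flops_sampling_probs`
        # are not defined in this module and are expected to be attached
        # externally before training.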
|
flops_id = np.random.choice(len(self.flops_grouped_cfgs), p=self.flops_sampling_probs) |
|
stitch_config_id = np.random.choice(self.flops_grouped_cfgs[flops_id]) |
|
return stitch_config_id |
|
|
|
def forward(self, x): |
|
|
|
if self.training: |
|
stitch_cfg_id = self.sampling_stitch_config() |
|
else: |
|
stitch_cfg_id = self.stitch_config_id |
|
|
|
comb_id = self.stitch_configs[stitch_cfg_id]['comb_id'] |
|
|
|
|
|
if len(comb_id) == 1: |
|
return self.anchors[comb_id[0]](x) |
|
|
|
|
|
route = self.stitch_configs[stitch_cfg_id]['route'] |
|
stitch_layer_ids = self.stitch_configs[stitch_cfg_id]['stitch_layer_ids'] |
|
|
|
|
|
x = self.anchors[comb_id[0]].forward_patch_embed(x) |
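        # Walk the route: each segment runs its block range on the chosen anchor
        # and then applies the matching stitching layer (the last segment passes
        # through the trailing nn.Identity via index -1).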
|
|
|
for i, (model_id, cfg) in enumerate(zip(comb_id, route)): |
|
|
|
x = self.anchors[model_id].selective_forward(x, cfg[0], cfg[1]) |
|
x = self.stitch_layers[model_id][stitch_layer_ids[i]](x) |
|
|
|
x = self.anchors[comb_id[-1]].forward_norm_head(x) |
|
return x |
|
|
|
|
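# ---------------------------------------------------------------------------
# Usage sketch (a rough outline, not part of this module): the anchors passed
# to SNNet / SNNetv2 are expected to be ViT-style backbones exposing `blocks`,
# `embed_dim`, `init_weights`, `extract_block_features` and the partial-forward
# hooks used above (`forward_until`, `forward_from`, `forward_patch_embed`,
# `selective_forward`, `forward_norm_head`). Assuming two such anchors
# `vit_small` and `vit_base` and a sample batch `images`:
#
#     snnet = SNNetv2(anchors=[vit_small, vit_base], lora_r=4)
#     snnet.initialize_stitching_weights(images)   # least-squares init
#     snnet.reset_stitch_id(cfg_id)                # pick a route for evaluation
#     logits = snnet(images)
#
# The anchor API lives outside this file, so the names above are inferred from
# the calls made in this module.
# ---------------------------------------------------------------------------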