# Copyright (c) OpenMMLab. All rights reserved.
import json
import math
import torch
import torch.nn as nn
import numpy as np
from collections import defaultdict
from utils import get_root_logger
import torch.nn.functional as F
def rearrange_activations(activations):
    """Collapse every leading dimension so each row is one feature vector.

    Args:
        activations (torch.Tensor): Tensor whose last dimension holds the
            channels / features.

    Returns:
        torch.Tensor: 2-D tensor of shape ``(prod(leading_dims), C)``.
    """
    return activations.reshape(-1, activations.shape[-1])
def ps_inv(x1, x2):
    '''Least-squares solver given feature maps from two anchors.

    Fits the affine map ``x2 ~= x1 @ w.T + b`` by multiplying ``x2`` with the
    Moore-Penrose pseudo-inverse of ``x1`` augmented with a column of ones
    (the ones column yields the bias term).

    Args:
        x1 (torch.Tensor): Features of the front anchor. All leading dims are
            flattened; the last dim (``C_in``) is kept.
        x2 (torch.Tensor): Features of the end anchor, flattened the same way
            (last dim ``C_out``).

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``w`` of shape ``(C_out, C_in)`` and
        ``b`` of shape ``(C_out,)``, i.e. the ``nn.Linear`` weight/bias layout.

    Raises:
        ValueError: If the flattened row counts of ``x1`` and ``x2`` differ.
    '''
    x1 = rearrange_activations(x1)
    x2 = rearrange_activations(x2)
    if x1.shape[0] != x2.shape[0]:
        raise ValueError('Spatial size of compared neurons must match when '
                         'calculating pseudo inverse matrix.')
    # Solve in one floating dtype of at least float32. The ones buffer used to
    # be default-dtype (float32) regardless of the inputs, so a float64 or
    # float16 x2 crashed in the final matmul with a dtype mismatch, and a
    # float64 x1 was silently truncated to float32.
    dtype = torch.promote_types(torch.promote_types(x1.dtype, x2.dtype), torch.float32)
    # Augmented design matrix [x1 | 1]. It is allocated on the default torch
    # device (as before), and x2 is moved to that same device for the solve.
    x1_ones = torch.ones(x1.shape[0], x1.shape[1] + 1, dtype=dtype)
    x1_ones[:, :-1] = x1
    A_ones = torch.matmul(torch.linalg.pinv(x1_ones),
                          x2.to(device=x1_ones.device, dtype=dtype)).T
    # Rows of A_ones are output channels: the first C_in columns form the
    # weight, the last column is the bias.
    w = A_ones[..., :-1]
    b = A_ones[..., -1]
    return w, b
def reset_out_indices(front_depth=12, end_depth=24, out_indices=(9, 14, 19, 23)):
    """Translate output-block indices of a deeper model onto a shallower one.

    The block ids ``0 .. front_depth - 1`` are stretched to ``end_depth``
    entries with ``interpolate``'s default (nearest-neighbour) mode. The
    stretched entries whose positions appear in ``out_indices`` are returned
    in ascending position order.

    Args:
        front_depth (int): Number of blocks of the shallower model.
        end_depth (int): Number of blocks of the deeper model.
        out_indices (Sequence[int]): Output block indices in the deeper model.

    Returns:
        list[int]: Corresponding block indices in the shallower model.
    """
    source_ids = torch.tensor(list(range(front_depth))).float()[None, None, :]
    stretched = F.interpolate(source_ids, end_depth).squeeze().long().tolist()
    return [block for position, block in enumerate(stretched) if position in out_indices]
def get_stitch_configs_general_unequal(depths):
    """Enumerate one-way stitching configs between anchors of unequal depth.

    Args:
        depths (Sequence[int]): Block counts of the anchors. They are sorted,
            so index 0 refers to the shallower one.

    Returns:
        tuple[list[dict], int]: The config list and the number of stitches.
        The first config runs anchor 1 alone (``{'comb_id': [1]}``). Config
        ``i + 1`` stitches block ``i`` of the shallower anchor into block
        ``(i + 1) * (deep // shallow)`` of the deeper one.
    """
    ordered = sorted(depths)
    num_stitches = ordered[0]
    stitched = [
        {
            'comb_id': (0, 1),
            'stitch_cfgs': (idx, (idx + 1) * (ordered[1] // ordered[0])),
        }
        for idx in range(num_stitches)
    ]
    return [{'comb_id': [1]}] + stitched, num_stitches
def get_stitch_configs_bidirection(depths):
    """Enumerate bidirectional stitching configurations between two anchors.

    The depths are sorted, so anchor 0 is treated as the shallower ("small")
    model and anchor 1 as the deeper ("large") one. Besides the two
    single-anchor configs, four families of stitched routes are generated:

    * sl  (small -> large):          comb_id ``[0, 1]``
    * ls  (large -> small):          comb_id ``[1, 0]``
    * lsl (large -> small -> large): comb_id ``[1, 0, 1]``
    * sls (small -> large -> small): comb_id ``[0, 1, 0]``

    Every stitched config receives a ``route`` (one ``[start, end]`` block
    range per visited anchor). Its ``stitch_layer_ids`` list is terminated
    with ``-1`` so there is one id per route segment; ``SNNetv2`` maps index
    ``-1`` to a trailing ``nn.Identity``.

    Args:
        depths (Sequence[int]): Block counts of the two anchors.

    Returns:
        tuple: ``(total_configs, model_combos, num_stitches, anchor_ids,
        sl_ids, ls_ids, lsl_ids, sls_ids)``.

        * ``num_stitches`` holds the number of sl and ls stitching layers.
        * The ``*_ids`` lists index into ``total_configs``.
        * The sl config that stitches after the last block of the small
          anchor stays in ``total_configs`` (so its stitching layer can
          still be initialised) but is excluded from ``sl_ids``.
    """
    depths = sorted(depths)
    # Anchor-only configurations.
    total_configs = [{'comb_id': [0]}, {'comb_id': [1]}]
    num_stitches = depths[0]

    # small --> large: leave the small anchor after block i and enter the
    # large anchor at block (i + 1) * (large // small).
    sl_configs = []
    for i in range(num_stitches):
        sl_configs.append({
            'comb_id': [0, 1],
            'stitch_cfgs': [[i, (i + 1) * (depths[1] // depths[0])]],
            'stitch_layer_ids': [i],
        })

    # Nearest-neighbour map from each large-anchor block to a small-anchor block.
    block_ids = torch.tensor(list(range(depths[0])))
    block_ids = block_ids[None, None, :].float()
    end_mapping_ids = torch.nn.functional.interpolate(block_ids, depths[1])
    end_mapping_ids = end_mapping_ids.squeeze().long().tolist()

    # large --> small: every block except the last is a stitch point. With
    # unequal depths only the odd-indexed large blocks are used.
    ls_configs = []
    for i in range(depths[1] - 1):
        if depths[1] != depths[0] and i % 2 == 0:
            continue
        ls_configs.append({
            'comb_id': [1, 0],
            'stitch_cfgs': [[i, end_mapping_ids[i] + 1]],
            'stitch_layer_ids': [i // (depths[1] // depths[0])],
        })

    # large --> small --> large
    lsl_configs = []
    for ls_cfg in ls_configs:
        for sl_cfg in sl_configs:
            # The sl stitch after the last small block cannot close an lsl
            # route (with an integer depth ratio no large blocks remain).
            if sl_cfg['stitch_layer_ids'][0] == depths[0] - 1:
                continue
            if sl_cfg['stitch_cfgs'][0][0] >= ls_cfg['stitch_cfgs'][0][1]:
                lsl_configs.append({
                    'comb_id': [1, 0, 1],
                    'stitch_cfgs': [ls_cfg['stitch_cfgs'][0], sl_cfg['stitch_cfgs'][0]],
                    'stitch_layer_ids': ls_cfg['stitch_layer_ids'] + sl_cfg['stitch_layer_ids'],
                })

    # small --> large --> small
    sls_configs = []
    for sl_cfg in sl_configs:
        for ls_cfg in ls_configs:
            if ls_cfg['stitch_cfgs'][0][0] >= sl_cfg['stitch_cfgs'][0][1]:
                sls_configs.append({
                    'comb_id': [0, 1, 0],
                    'stitch_cfgs': [sl_cfg['stitch_cfgs'][0], ls_cfg['stitch_cfgs'][0]],
                    'stitch_layer_ids': sl_cfg['stitch_layer_ids'] + ls_cfg['stitch_layer_ids'],
                })

    total_configs += sl_configs + ls_configs + lsl_configs + sls_configs

    last_small_block = depths[0] - 1
    anchor_ids = []
    sl_ids = []
    ls_ids = []
    lsl_ids = []
    sls_ids = []
    for i, cfg in enumerate(total_configs):
        comb_id = cfg['comb_id']
        if len(comb_id) == 1:
            anchor_ids.append(i)
            continue
        if len(comb_id) == 2:
            front, end = cfg['stitch_cfgs'][0]
            cfg['route'] = [[0, front], [end, depths[comb_id[-1]]]]
            # Previously hard-coded as ``front != 11``, which was only right
            # for a 12-block small anchor.
            if comb_id == [0, 1] and front != last_small_block:
                sl_ids.append(i)
            elif comb_id == [1, 0]:
                ls_ids.append(i)
        if len(comb_id) == 3:
            front_1, end_1 = cfg['stitch_cfgs'][0]
            front_2, end_2 = cfg['stitch_cfgs'][1]
            cfg['route'] = [[0, front_1], [end_1, front_2], [end_2, depths[comb_id[-1]]]]
            if comb_id == [1, 0, 1]:
                lsl_ids.append(i)
            elif comb_id == [0, 1, 0]:
                sls_ids.append(i)
        # One stitch id per route segment; the final segment gets -1.
        cfg['stitch_layer_ids'].append(-1)

    model_combos = [(0, 1), (1, 0)]
    return (total_configs, model_combos, [len(sl_configs), len(ls_configs)],
            anchor_ids, sl_ids, ls_ids, lsl_ids, sls_ids)
def format_out_features(outs, with_cls_token, hw_shape):
    """Turn token sequences into channel-first feature maps, in place.

    Args:
        outs (list[torch.Tensor]): Tensors of shape ``(B, N, C)``. Each entry
            is replaced in the list by its reshaped counterpart; ``B`` and
            ``C`` are taken from ``outs[0]``.
        with_cls_token (bool): When true, the first token of every sequence
            is dropped before reshaping.
        hw_shape (Sequence[int]): Spatial grid size ``(H, W)``.

    Returns:
        list[torch.Tensor]: ``outs`` itself, now holding ``(B, C, H, W)`` maps.
    """
    batch, _, channels = outs[0].shape

    def to_feature_map(tokens):
        # Optionally strip the class token, then (B, H, W, C) -> (B, C, H, W).
        if with_cls_token:
            tokens = tokens[:, 1:]
        grid = tokens.reshape(batch, hw_shape[0], hw_shape[1], channels)
        return grid.permute(0, 3, 1, 2).contiguous()

    for position, tokens in enumerate(outs):
        outs[position] = to_feature_map(tokens)
    return outs
class LoRALayer:
    """Mixin storing the LoRA hyper-parameters and input dropout.

    Args:
        r (int): Rank of the low-rank update.
        lora_alpha (int): Scaling numerator (``Linear`` uses ``lora_alpha / r``).
        lora_dropout (float): Dropout probability on the LoRA input path; a
            value ``<= 0`` installs an identity function instead.
        merge_weights (bool): Whether subclasses may fold the low-rank update
            into the base weight.
    """

    def __init__(self, r: int, lora_alpha: int, lora_dropout: float, merge_weights: bool):
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0. else (lambda x: x)
        # Nothing has been folded into the base weight yet.
        self.merged = False
        self.merge_weights = merge_weights
class Linear(nn.Linear, LoRALayer):
    """``nn.Linear`` with an optional LoRA low-rank update.

    When ``r > 0`` the pre-trained ``weight`` is frozen and the output becomes
    ``F.linear(x, W, b) + (dropout(x) @ A^T @ B^T) * scaling``. Here ``A`` has
    shape ``(r, in_features)``, ``B`` has shape ``(out_features, r)`` and
    ``scaling = lora_alpha / r``. ``B`` starts at zero, so the layer initially
    equals the base linear layer.

    With ``merge_weights`` the update is added into ``weight`` when switching
    to eval mode and subtracted again when switching back to train mode.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        """Reset the base layer; re-initialise A (Kaiming) and zero B if present."""
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        """Set train/eval mode, un-merging or merging the LoRA update.

        Returns:
            Linear: ``self``, as ``nn.Module.train`` does. Without this,
            ``layer.train()`` and ``layer.eval()`` returned ``None``.
        """
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the base projection, plus the low-rank path when it is not merged."""
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        result = F.linear(x, T(self.weight), bias=self.bias)
        if self.r > 0 and not self.merged:
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1)
                       @ self.lora_B.transpose(0, 1)) * self.scaling
        return result
class StitchingLayer(nn.Module):
    """Learnable map from one anchor's feature width to another's.

    A thin wrapper around this file's LoRA-capable ``Linear``.

    Args:
        in_features (int): Width of the incoming features.
        out_features (int): Width of the outgoing features.
        r (int): LoRA rank forwarded to ``Linear``; ``0`` gives a plain linear layer.
    """

    def __init__(self, in_features=None, out_features=None, r=0):
        super().__init__()
        self.transform = Linear(in_features, out_features, r=r)

    def init_stitch_weights_bias(self, weight, bias):
        """Copy ``weight`` and ``bias`` into the transform in place."""
        projection = self.transform
        projection.weight.data.copy_(weight)
        projection.bias.data.copy_(bias)

    def forward(self, x):
        return self.transform(x)
class SNNet(nn.Module):
    """One-directional stitchable network over two anchors.

    Configs come from ``get_stitch_configs_general_unequal``. Config 0 runs
    ``anchors[1]`` alone. Every other config runs ``anchors[0]`` up to a
    block, applies a stitching layer, then continues in ``anchors[1]``.

    Args:
        anchors (list[nn.Module]): The two backbones. As used here, each must
            expose ``blocks``, ``embed_dim``, ``extract_block_features``,
            ``forward_until``, ``forward_from`` and ``init_weights``.
            NOTE(review): stitch layers map ``anchors[0].embed_dim`` to
            ``anchors[1].embed_dim`` while the configs come from sorted depths,
            so ``anchors[0]`` presumably must be the shallower model — confirm.
    """

    def __init__(self, anchors=None):
        super(SNNet, self).__init__()
        self.anchors = nn.ModuleList(anchors)
        self.depths = [len(anchor.blocks) for anchor in self.anchors]
        configs, num_stitches = get_stitch_configs_general_unequal(self.depths)
        layers = [StitchingLayer(self.anchors[0].embed_dim, self.anchors[1].embed_dim)
                  for _ in range(num_stitches)]
        self.stitch_layers = nn.ModuleList(layers)
        self.stitch_configs = dict(enumerate(configs))
        self.all_cfgs = list(self.stitch_configs)
        self.num_configs = len(self.all_cfgs)
        self.stitch_config_id = 0
        self.is_ranking = False

    def reset_stitch_id(self, stitch_config_id):
        """Select the config used by ``forward``."""
        self.stitch_config_id = stitch_config_id

    def initialize_stitching_weights(self, x):
        """Least-squares initialise each stitching layer from the block features of ``x``."""
        logger = get_root_logger()
        front, end = 0, 1
        with torch.no_grad():
            front_features = self.anchors[front].extract_block_features(x)
            end_features = self.anchors[end].extract_block_features(x)
            for i in range(self.depths[0]):
                # Small block i is matched with large block (i + 1) * ratio - 1.
                target_id = (i + 1) * (self.depths[1] // self.depths[0])
                w, b = ps_inv(front_features[i], end_features[target_id - 1])
                self.stitch_layers[i].init_stitch_weights_bias(w, b)
                logger.info(f'Initialized Stitching Model {front} to Model {end}, Layer {i}')

    def init_weights(self):
        """Delegate weight initialisation to every anchor."""
        for anchor in self.anchors:
            anchor.init_weights()

    def sampling_stitch_config(self):
        """Pick a config id uniformly at random and make it current."""
        self.stitch_config_id = np.random.choice(self.all_cfgs)

    def forward(self, x):
        current = self.stitch_configs[self.stitch_config_id]
        comb_id = current['comb_id']
        if len(comb_id) == 1:
            # Single-anchor config: plain forward pass.
            return self.anchors[comb_id[0]](x)
        stitch = current['stitch_cfgs']
        x = self.anchors[comb_id[0]].forward_until(x, blk_id=stitch[0])
        x = self.stitch_layers[stitch[0]](x)
        return self.anchors[comb_id[1]].forward_from(x, blk_id=stitch[1])
class SNNetv2(nn.Module):
    """Bidirectional stitchable network over two anchors.

    Configs (anchor-only, sl, ls, lsl, sls) come from
    ``get_stitch_configs_bidirection``. The ``include_*`` flags choose which
    families end up in ``self.all_cfgs``.

    Args:
        anchors (list[nn.Module]): The two backbones. As used here, each must
            expose ``blocks``, ``embed_dim``, ``extract_block_features``,
            ``forward_patch_embed``, ``selective_forward``,
            ``forward_norm_head`` and ``init_weights``.
        include_sl, include_ls, include_lsl, include_sls (bool): Enable the
            corresponding route family in ``all_cfgs``.
        lora_r (int): LoRA rank of every stitching layer (``0`` = plain linear).

    Note:
        In training mode ``forward`` calls ``sampling_stitch_config``. That
        method reads ``self.flops_grouped_cfgs`` and
        ``self.flops_sampling_probs``, which this class never sets, so they
        must be assigned externally before training.
    """

    def __init__(self, anchors=None, include_sl=True, include_ls=True, include_lsl=True, include_sls=True, lora_r=0):
        super(SNNetv2, self).__init__()
        self.anchors = nn.ModuleList(anchors)
        self.lora_r = lora_r
        self.depths = [len(anc.blocks) for anc in self.anchors]
        (total_configs, model_combos, num_stitches, anchor_ids,
         sl_ids, ls_ids, lsl_ids, sls_ids) = get_stitch_configs_bidirection(self.depths)

        # One ModuleList per direction, indexed by the front model id.
        # Each list gets a trailing nn.Identity, so stitch id -1 (which ends
        # every route) is a no-op.
        self.stitch_layers = nn.ModuleList()
        self.stitching_map_id = {}
        for (front, end), num_sth in zip(model_combos, num_stitches):
            temp = nn.ModuleList(
                [StitchingLayer(self.anchors[front].embed_dim, self.anchors[end].embed_dim, r=lora_r)
                 for _ in range(num_sth)])
            temp.append(nn.Identity())
            self.stitch_layers.append(temp)

        self.stitch_configs = {i: cfg for i, cfg in enumerate(total_configs)}
        # Only two-anchor configs have a stitch that can be fitted directly.
        self.stitch_init_configs = {i: cfg for i, cfg in enumerate(total_configs) if len(cfg['comb_id']) == 2}
        self.all_cfgs = list(self.stitch_configs.keys())
        logger = get_root_logger()
        logger.info(str(self.all_cfgs))

        # Restrict to the enabled route families.
        self.all_cfgs = anchor_ids
        if include_sl:
            self.all_cfgs += sl_ids
        if include_ls:
            self.all_cfgs += ls_ids
        if include_lsl:
            self.all_cfgs += lsl_ids
        if include_sls:
            self.all_cfgs += sls_ids
        self.num_configs = len(self.stitch_configs)
        self.stitch_config_id = 0
        # Defined up front (as in SNNet) so reading it before
        # set_ranking_mode() no longer raises AttributeError.
        self.is_ranking = False

    def reset_stitch_id(self, stitch_config_id):
        """Select the config used by ``forward`` in eval mode."""
        self.stitch_config_id = stitch_config_id

    def set_ranking_mode(self, ranking_mode):
        """Set the ``is_ranking`` flag."""
        self.is_ranking = ranking_mode

    def initialize_stitching_weights(self, x):
        """Least-squares initialise every two-anchor stitching layer from ``x``."""
        logger = get_root_logger()
        anchor_features = []
        for anchor in self.anchors:
            with torch.no_grad():
                temp = anchor.extract_block_features(x)
                anchor_features.append(temp)
        for idx, cfg in self.stitch_init_configs.items():
            comb_id = cfg['comb_id']
            if len(comb_id) == 2:
                front_id, end_id = cfg['stitch_cfgs'][0]
                stitch_layer_id = cfg['stitch_layer_ids'][0]
                front_blk_feat = anchor_features[comb_id[0]][front_id]
                end_blk_feat = anchor_features[comb_id[1]][end_id - 1]
                w, b = ps_inv(front_blk_feat, end_blk_feat)
                self.stitch_layers[comb_id[0]][stitch_layer_id].init_stitch_weights_bias(w, b)
                logger.info(f'Initialized Stitching Layer {cfg}')

    def init_weights(self):
        """Delegate weight initialisation to every anchor."""
        for anc in self.anchors:
            anc.init_weights()

    def sampling_stitch_config(self):
        """Sample a FLOPs group by ``flops_sampling_probs``, then a config id uniformly within it."""
        flops_id = np.random.choice(len(self.flops_grouped_cfgs), p=self.flops_sampling_probs)
        stitch_config_id = np.random.choice(self.flops_grouped_cfgs[flops_id])
        return stitch_config_id

    def forward(self, x):
        """Run the sampled (training) or selected (eval) config on ``x``."""
        if self.training:
            stitch_cfg_id = self.sampling_stitch_config()
        else:
            stitch_cfg_id = self.stitch_config_id

        comb_id = self.stitch_configs[stitch_cfg_id]['comb_id']

        # forward by a single anchor
        if len(comb_id) == 1:
            return self.anchors[comb_id[0]](x)

        # forward among anchors: each route segment runs on its anchor and is
        # followed by that direction's stitching layer (Identity for id -1).
        route = self.stitch_configs[stitch_cfg_id]['route']
        stitch_layer_ids = self.stitch_configs[stitch_cfg_id]['stitch_layer_ids']

        # patch embedding
        x = self.anchors[comb_id[0]].forward_patch_embed(x)

        for i, (model_id, cfg) in enumerate(zip(comb_id, route)):
            x = self.anchors[model_id].selective_forward(x, cfg[0], cfg[1])
            x = self.stitch_layers[model_id][stitch_layer_ids[i]](x)

        x = self.anchors[comb_id[-1]].forward_norm_head(x)
        return x