OpenJMLA / modeling_maelm.py
sino
Upload modeling_maelm.py
a62a2b1
raw
history blame
22.6 kB
import json
import os
import pdb
from mmcv.cnn.bricks import padding
import torch
from torch import nn, einsum
from typing import Optional, Dict, Tuple
from src.mae_vit import MAEViT
from src.htsat import HTSAT_Swin_Transformer, create_htsat_model
from src.LMdecoder import LMDecoder, LMDecoder_qlora
from src.vision_transformer import VisionTransformer
from einops import rearrange, repeat
from einops_exts import rearrange_many
import inspect
from transformers.modeling_utils import PreTrainedModel
from .configuration_maelm import MAELMConfig
class ArgsHandler:
    """Read and rewrite, by parameter name, the arguments of a hooked call.

    Designed for ``register_forward_pre_hook(..., with_kwargs=True)`` hooks,
    which receive the call's positional ``args`` tuple and ``kwargs`` dict
    and may return a replacement ``(args, kwargs)`` pair.

    The parameter-name list of ``module.<funcname>`` is computed once with
    ``inspect.signature`` and cached on the module as the attribute
    ``f"{funcname}_argnames_list"``. That avoids re-inspecting on every call.
    """

    def __init__(self, module, funcname, fargs, fkargs):
        """
        Args:
            module: object whose method ``funcname`` is being called.
            funcname: name of that method (e.g. ``'forward'``).
            fargs: positional arguments of the call (``self`` excluded).
            fkargs: keyword arguments of the call.
                Mutated in place by ``set_arg``.
        """
        self.fargs = list(fargs)
        self.fkargs = fkargs
        self._func = getattr(module, funcname)
        cache_attr = f"{funcname}_argnames_list"
        argnames = getattr(module, cache_attr, None)
        if argnames is None:
            self.func_sig = inspect.signature(self._func)
            argnames = list(self.func_sig.parameters.keys())
            setattr(module, cache_attr, argnames)
        self.argnames_list = argnames

    def _position(self, arg_name):
        """Return the positional index of ``arg_name``, or None if it was not passed positionally.

        Raises:
            ValueError: ``arg_name`` is not a parameter of the wrapped function.
        """
        idx = self.argnames_list.index(arg_name)
        return idx if idx < len(self.fargs) else None

    def get_arg(self, arg_name):
        """Return the value the call will receive for ``arg_name``.

        Keyword arguments take precedence, then positional arguments. If the
        caller supplied neither, the parameter's declared default is returned.

        Raises:
            ValueError: ``arg_name`` is not a parameter of the wrapped function.
            KeyError: the argument was not supplied and has no default.
        """
        if arg_name in self.fkargs:
            return self.fkargs[arg_name]
        idx = self._position(arg_name)
        if idx is not None:
            return self.fargs[idx]
        # Not supplied by the caller; fall back to the declared default
        # (previously this raised IndexError).
        param = inspect.signature(self._func).parameters[arg_name]
        if param.default is inspect.Parameter.empty:
            raise KeyError(f"argument {arg_name!r} was not supplied and has no default")
        return param.default

    def set_arg(self, arg_name, arg_value):
        """Replace the value passed for ``arg_name``.

        The value is written where the caller supplied it (kwargs or args).
        An argument the caller left at its default is added to the kwargs.

        Raises:
            ValueError: ``arg_name`` is not a parameter of the wrapped function.
        """
        if arg_name in self.fkargs:
            self.fkargs[arg_name] = arg_value
            return
        idx = self._position(arg_name)
        if idx is not None:
            self.fargs[idx] = arg_value
        else:
            self.fkargs[arg_name] = arg_value

    def return_all_args(self,):
        """Return ``(args_tuple, kwargs_dict)`` in the form a with_kwargs pre-hook must return."""
        return tuple(self.fargs), self.fkargs
class SquaredReLU(nn.Module):
    """Element-wise squared ReLU activation: ``relu(x) ** 2``."""

    def forward(self, x):
        # Negative entries become 0; non-negative entries are squared.
        rectified = torch.relu(x)
        return rectified.pow(2)
def FeedForward(dim, out_dim, mult=4, act='gelu'):
    """
    Build a pre-norm MLP block (lucidrains' design, with a selectable activation).

    Layout: LayerNorm(dim) -> Linear(dim, int(dim * mult)) -> activation
    -> Linear(int(dim * mult), out_dim). Both linear layers have no bias.

    Args:
        dim: input feature size.
        out_dim: output feature size.
        mult: hidden-width multiplier relative to ``dim``.
        act: one of ``'gelu'``, ``'sqrelu'`` or ``'relu'``.
    Returns:
        nn.Sequential implementing the block.
    """
    activations = {
        'gelu': nn.GELU,
        'sqrelu': SquaredReLU,
        'relu': nn.ReLU,
    }
    assert act in activations, f"act. can only be one of {activations.keys()}"
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        activations[act](),
        nn.Linear(hidden, out_dim, bias=False),
    ]
    return nn.Sequential(*layers)
class PerceiverAttentionLayer(nn.Module):
    """Multi-head cross-attention from learnt latent queries to input features.

    Queries come from the (layer-normalised) latents. Keys and values come only
    from the (layer-normalised) features; the latents are not appended to
    the key/value set.

    Args (keyword-only):
        feat_dim: feature size of the ``features`` input.
        latent_dim: feature size of the ``latents`` input (and of the output).
        dim_head: per-head projection size.
        heads: number of attention heads.
    """

    def __init__(
        self,
        *,
        feat_dim,
        latent_dim,
        dim_head=64,
        heads=8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads
        # trainable components of PerceiverAttentionLayer
        self.norm_media = nn.LayerNorm(feat_dim)
        self.norm_latents = nn.LayerNorm(latent_dim)
        self.to_q = nn.Linear(latent_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(feat_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(feat_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, latent_dim, bias=False)

    def _split_heads(self, t):
        """(batch, n, heads * dim_head) -> (batch, heads, n, dim_head)."""
        n_batch, n_tokens, _ = t.shape
        return t.reshape(n_batch, n_tokens, self.heads, self.dim_head).transpose(1, 2)

    def forward(self, features, latents):
        """
        Let the latent vectors cross-attend to ``features``.

        :param features: Tensor (n_batch, n_features, feat_dim)
        :param latents: Tensor (n_batch, n_latents, latent_dim); the learnt
            latent vectors from which the queries are computed.
        :return: Tensor (n_batch, n_latents, latent_dim)
        """
        assert features.ndim == 3
        assert latents.ndim == 3
        assert features.shape[0] == latents.shape[0]
        n_batch, n_features, _ = features.shape
        n_queries = latents.shape[1]

        x = self.norm_media(features)
        latents = self.norm_latents(latents)

        # Project and split into heads: (b, h, tokens, dim_head).
        q = self._split_heads(self.to_q(latents))
        assert q.shape == torch.Size([n_batch, self.heads, n_queries, self.dim_head])
        k = self._split_heads(self.to_k(x))
        v = self._split_heads(self.to_v(x))
        assert v.shape == torch.Size([n_batch, self.heads, n_features, self.dim_head])

        q = q * self.scale
        sim = einsum('b h q d, b h f d -> b h q f', q, k)
        # Subtracting the per-row max keeps exp() from overflowing.
        # Softmax is invariant to a constant per-row shift, so the result is unchanged.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        alphas = sim.softmax(dim=-1)
        out = einsum('b h q f, b h f v -> b h q v', alphas, v)
        # Merge heads back: (b, h, q, d) -> (b, q, h * d).
        out = out.transpose(1, 2).reshape(n_batch, n_queries, self.heads * self.dim_head)
        return self.to_out(out)
class MAEForCausalLM(PreTrainedModel):
    """Audio encoder + Perceiver resampler + hooked causal LM.

    Pipeline:

    * ``self.backbone`` encodes a spectrogram.
    * ``self.samplers`` resample the encoder output into
      ``resampler_latent_len`` latent vectors (see ``forward_resampler``).
    * Forward-pre-hooks on the LM's decoder layers (see ``hook_adapter``)
      write those latents into the placeholder positions before layer 0.
    * The same hooks run a bottleneck ``Adapter`` on every odd-indexed
      decoder layer.

    Args:
        config (MAELMConfig): must provide ``backbone`` and ``neck`` dicts.
            Each has a ``'name'`` key; the backbone may also carry an
            optional ``'ckpt'`` path. The config must also provide
            ``resume_from_checkpoint`` and ``resume_from_pth``.
    """
    config_class = MAELMConfig

    def __init__(self, config: MAELMConfig) -> None:
        super().__init__(config)
        self._build_backbone(config.backbone)
        self._build_neck(config.neck)
        # NOTE: replaces the wrapper config with the LM's own config (TODO upstream).
        self.config = self.neck.LMconfig
        self.neck.lm.model.gradient_checkpointing = False
        self.register_buffer('ones', torch.ones((1, 4096), dtype=torch.long), persistent=False)
        self.graft_adapter()
        self.init_weights()
        # float32 --> bfloat16 (parameters only; buffers keep their dtype)
        for p in self.parameters():
            p.data = p.data.to(torch.bfloat16)
        self._maybe_resume(config)
        self.first_run = True

    def _build_backbone(self, backbone_cfg):
        """Instantiate ``self.backbone`` (and ``self.bk_name``) from the backbone config dict."""
        import copy
        assert backbone_cfg is not None
        # Work on a copy so the caller's config keeps its 'name'/'ckpt' keys.
        # That lets the same config build a second model.
        backbone = copy.copy(backbone_cfg)
        bk_name = backbone.pop('name')
        self.bk_name = bk_name
        ckpt_path = backbone.pop('ckpt', None)
        if bk_name == 'MAEViT':
            self.backbone = MAEViT(**backbone)
        elif bk_name == 'HTSAT':
            self.backbone = create_htsat_model(backbone)
        else:
            # 'qformer' and any other name are not implemented.
            raise NotImplementedError(f"Unsupported backbone: {bk_name!r}")
        if ckpt_path is not None:
            ckpt = torch.load(ckpt_path, map_location='cpu')
            self.backbone.load_state_dict(ckpt['state_dict'])

    def _build_neck(self, neck_cfg):
        """Instantiate ``self.neck`` (the language-model decoder) from the neck config dict."""
        import copy
        assert neck_cfg is not None
        neck = copy.copy(neck_cfg)
        nk_name = neck.pop('name')
        if nk_name == 'LMDecoder':
            self.neck = LMDecoder(**neck)
        elif nk_name == 'LMDecoder_qlora':
            self.neck = LMDecoder_qlora(**neck)
        else:
            raise NotImplementedError(f"Unsupported neck: {nk_name!r}")

    def _maybe_resume(self, config):
        """Optionally load weights named by ``config.resume_from_pth``."""
        if config.resume_from_checkpoint is not None:
            # No accelerator object exists inside the model, so fail loudly.
            # This replaces an unconditional NameError.
            raise NotImplementedError(
                "resume_from_checkpoint must be restored by the training script's "
                "accelerator, not inside the model constructor")
        if config.resume_from_pth is not None:
            print(f'###########loading##########{config.resume_from_pth}###########loading##########')
            ckpt = torch.load(config.resume_from_pth, map_location='cpu')
            # Strip a DistributedDataParallel-style "module." prefix only where present.
            prefix = 'module.'
            ckpt_copy = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in ckpt.items()}
            self.load_state_dict(ckpt_copy, strict=False)
            print(f'###########loaded##########{config.resume_from_pth}###########loaded##########')

    def graft_adapter(self):
        """Create the learnable latents, the per-layer adapters and the Perceiver resampler."""
        lm_dim = self.config.hidden_size
        self.adapter_latent_len = 32
        self.adapter_latent = nn.Parameter(
            torch.rand((1, self.adapter_latent_len, lm_dim), dtype=torch.float))
        self.resampler_latent_len = 32
        self.resampler_latent = nn.Parameter(
            torch.rand((1, self.resampler_latent_len, lm_dim), dtype=torch.float))

        if self.bk_name == 'HTSAT':
            feat_dim = 1024
            # NOTE(review): this count is not tied to the number of hooked LM
            # layers in hook_adapter; confirm they match for HTSAT.
            depth = len(self.backbone.layers[2].blocks)
        else:
            feat_dim = 768
            # One adapter per odd-indexed decoder layer hooked in hook_adapter.
            depth = int(len(self.neck.lm.model.layers) / 2)

        self.adapter = nn.ModuleList([
            nn.ModuleList([Adapter(input_size=lm_dim)]) for _ in range(depth)
        ])
        heads = 8
        self.samplers = nn.ModuleList([
            nn.ModuleList([
                PerceiverAttentionLayer(feat_dim=feat_dim, latent_dim=lm_dim, dim_head=64, heads=heads),
                FeedForward(dim=lm_dim, out_dim=lm_dim, mult=4),
            ])
            for _ in range(3)
        ])
        self.norm = nn.LayerNorm(lm_dim)

    def init_weights(self):
        """Run the base-class init (best effort), then re-init both latent tensors with N(0, 0.02)."""
        try:
            super().init_weights()
        except Exception as err:
            import warnings
            warnings.warn(f"PreTrainedModel.init_weights failed, continuing: {err!r}")
        if getattr(self, 'adapter_latent', None) is not None:
            self.adapter_latent.data.normal_(mean=0.0, std=0.02)
        if getattr(self, 'resampler_latent', None) is not None:
            # Previously re-initialised adapter_latent a second time by mistake.
            self.resampler_latent.data.normal_(mean=0.0, std=0.02)

    def forward_resampler(self, x):
        """Resample backbone features into ``resampler_latent_len`` latents.

        Args:
            x: backbone output; PerceiverAttentionLayer requires it to be 3-D,
                presumably (batch, n_features, feat_dim).
        Returns:
            Layer-normalised tensor (batch, resampler_latent_len, hidden_size).
        """
        latents = repeat(self.resampler_latent, 'b n d -> (bs b) n d', bs=x.shape[0])
        for attn, ff in self.samplers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        return self.norm(latents)

    def hook_adapter(self, audio_embedding, lm, v2t_feats):
        """Install forward-pre-hooks on ``lm`` that inject the audio information.

        Hooks left by a previous call are removed first. The new hooks are:

        * On ``lm``: capture the call's arguments, so ``input_ids`` can be read.
        * On decoder layer 0: replace the hidden states at positions whose
          ``input_ids`` are negative (the placeholders built by
          ``prepare_ids``) with ``v2t_feats``.
        * On every odd-indexed decoder layer: apply the next ``Adapter``.
        * On the last decoder layer: reset the adapter counter, so the next
          forward pass starts again from adapter 0.

        Args:
            audio_embedding: unused; kept for interface compatibility.
            lm: causal LM whose ``model.layers`` are hooked.
            v2t_feats: resampled audio latents (output of ``forward_resampler``).
        """
        adapters = self.adapter

        class PHooker:
            """Holds the hook state (adapter counter, captured LM args) and the hook handles."""

            def __init__(self):
                self.adapter = adapters
                self.y = v2t_feats
                self.handles_list = []
                self.cnter = 0
                self.lm_ahl = None

            def lm_prehook(self, m, margs, mkargs):
                self.lm_ahl = ArgsHandler(m, 'forward', margs, mkargs)
                return None

            def first_layer_prehook(self, m, margs, mkargs):
                ahl = ArgsHandler(m, 'forward', margs, mkargs)
                hs = ahl.get_arg("hidden_states")
                # Audio placeholder positions carry negative ids (see prepare_ids).
                hs_msk = self.lm_ahl.get_arg("input_ids") < 0
                # Overwrite the placeholder hidden states with the resampled latents.
                neo_hs = hs.masked_scatter(hs_msk.unsqueeze(-1), self.y)
                ahl.set_arg("hidden_states", neo_hs)
                return ahl.return_all_args()

            def layer_prehook(self, m, margs, mkargs):
                ahl = ArgsHandler(m, 'forward', margs, mkargs)
                adapt = self.adapter[self.cnter][0]
                hs = ahl.get_arg("hidden_states")
                # The adapter adds its output onto the unmodified hidden states.
                neo_hs = adapt(hs, hs)
                self.cnter += 1
                ahl.set_arg("hidden_states", neo_hs)
                return ahl.return_all_args()

            def last_layer_hook(self, m, margs, mkargs):
                self.cnter = 0

        previous = getattr(lm, 'phooker', None)
        if previous:
            for handle in previous.handles_list:
                handle.remove()
        lm.phooker = None

        phooker = PHooker()
        layers = lm.model.layers
        phooker.handles_list.append(lm.register_forward_pre_hook(phooker.lm_prehook, with_kwargs=True))
        # Layer 0: splice the audio latents into the placeholder positions.
        phooker.handles_list.append(layers[0].register_forward_pre_hook(phooker.first_layer_prehook, with_kwargs=True))
        for layer in layers[1::2]:
            phooker.handles_list.append(layer.register_forward_pre_hook(phooker.layer_prehook, with_kwargs=True))
        # Registered last, so on the final layer it runs after that layer's adapter hook.
        phooker.handles_list.append(layers[-1].register_forward_pre_hook(phooker.last_layer_hook, with_kwargs=True))
        lm.phooker = phooker
        return None

    def prepare_ids(self, batch, audio_ids):
        """Prepend audio placeholder ids to every sample and build the mask and labels.

        Placeholders are encoded as ``-audio_ids - 1`` (all negative).

        Label positions covering the audio span plus the first ``ans_crds[i]``
        tokens are set to -100. In ``merged_ids``, any -100 is replaced by the
        tokenizer's pad id.

        Args:
            batch: dict with ``input_ids``, ``attention_mask``, ``audio_crds``
                and ``ans_crds``.
            audio_ids: (bsz, n_audio) integer ids.
        Returns:
            ``(merged_ids, merged_msk, merged_labels)``. These are rows stacked
            with ``torch.stack``, so every merged row must have equal length.
        """
        toker = self.neck.tokenizer
        n_audio = audio_ids.shape[1]
        with torch.no_grad():
            input_ids = batch['input_ids']
            att_msk = batch['attention_mask']
            au_crds = batch['audio_crds']
            ans_crds = batch['ans_crds']
            bsz = input_ids.shape[0]
            merged_ids, merged_msk, label_ids = [], [], []
            for i in range(bsz):
                # NOTE(review): tokens before au_crds[i] are dropped here, yet the
                # label cut-off below uses ans_crds[i] unshifted; that is only
                # consistent when au_crds[i] == 0 -- confirm with the data pipeline.
                cur_ids = torch.cat([-1 * audio_ids[i] - 1, input_ids[i, au_crds[i]:]])
                cur_msk = torch.cat([
                    torch.ones(n_audio, device=audio_ids.device),
                    att_msk[i, au_crds[i]:],
                ])
                cur_labels = cur_ids.clone()
                cur_labels[:n_audio + ans_crds[i]] = -100
                merged_ids.append(cur_ids)
                merged_msk.append(cur_msk)
                label_ids.append(cur_labels)
            merged_ids = torch.stack(merged_ids, dim=0)
            merged_msk = torch.stack(merged_msk, dim=0)
            merged_labels = torch.stack(label_ids, dim=0)
            assert merged_ids.shape[0] == bsz
            assert merged_ids.shape == merged_msk.shape
            assert merged_msk[:, -1].max() == 1
            merged_ids[merged_ids.eq(-100)] = toker.pad_token_id
        return merged_ids, merged_msk, merged_labels

    def _encode_and_merge(self, batch):
        """Encode the spectrogram, install the LM hooks and build the merged token ids.

        Returns:
            ``(audio_ids, merged_ids, merged_msk, merged_labels)``
        """
        bsz = len(batch['input_ids'])
        device = batch['input_ids'].device
        float_type = next(self.parameters()).dtype
        spectrogram = batch['spectrogram'].type(float_type)
        # Detached: no gradient flows back into the backbone from this path.
        audio_embedding = self.backbone(spectrogram).detach()
        resampler_feats = self.forward_resampler(audio_embedding)
        self.hook_adapter(audio_embedding, self.neck.lm, resampler_feats)
        audio_ids = torch.arange(self.adapter_latent.shape[1]).unsqueeze(0).repeat((bsz, 1)).long().to(device)
        assert audio_ids.max() < 100
        merged_ids, merged_msk, merged_labels = self.prepare_ids(batch, audio_ids)
        return audio_ids, merged_ids, merged_msk, merged_labels

    def forward(self, batch, **kwargs):
        """Training forward pass.

        Args:
            batch (dict): ``input_ids``, ``attention_mask``, ``audio_crds``,
                ``ans_crds`` and ``spectrogram`` entries.
            kwargs: ignored.
        Returns:
            The output of ``self.neck`` for the merged inputs, called with
            ``labels``. When the ``doing_eval`` environment variable is truthy
            ("1", "true", "yes" or "on", case-insensitive), ``merged_ids`` and
            ``merged_labels`` are attached on CPU.
        """
        _, merged_ids, merged_msk, merged_labels = self._encode_and_merge(batch)
        assert merged_ids.shape == merged_labels.shape
        # Exceptions now propagate to the caller. Previously they dropped into
        # remote_pdb and then failed on an unbound ``outs``.
        outs = self.neck(input_ids=merged_ids.contiguous().long(),
                         flatten_embs=self.adapter_latent.flatten(0, 1),
                         attention_mask=merged_msk.contiguous().long(),
                         labels=merged_labels.contiguous().long(),
                         use_cache=False)
        # Parse the flag instead of eval()-ing an environment variable.
        doing_eval = os.environ.get("doing_eval", 'False').strip().lower() in ('1', 'true', 'yes', 'on')
        if doing_eval:
            outs.merged_ids = merged_ids.cpu()
            outs.merged_labels = merged_labels.cpu()
        return outs

    def forward_test(self, batch, **kwargs):
        """Generation pass.

        Keeps each sample's merged prompt up to the answer start, left-pads
        all prompts to a common width and calls ``self.neck.generate``.

        Args:
            batch (dict): same entries as for ``forward``.
            kwargs: ignored.
        Returns:
            Whatever ``self.neck.generate`` returns.
        """
        audio_ids, merged_ids, _, _ = self._encode_and_merge(batch)
        au_crds = batch['audio_crds']
        ans_crds = batch['ans_crds']
        aid_len = audio_ids.shape[-1]
        bsz = merged_ids.shape[0]
        toker = self.neck.tokenizer
        with torch.no_grad():
            # NOTE(review): if ``encode`` prepends a BOS token, this picks BOS
            # rather than EOS; ``toker.eos_token_id`` may be what is intended.
            pad_token = toker.encode(toker.eos_token)[0]
            max_ans = int(max(ans_crds))
            # torch.full removes the old 4096-column cap imposed by slicing self.ones.
            padded_merged_ids = torch.full((bsz, aid_len + max_ans), pad_token,
                                           dtype=torch.long, device=self.ones.device)
            for i in range(bsz):
                assert au_crds[i] <= ans_crds[i]
                cur_ids = merged_ids[i][:aid_len + ans_crds[i]]
                # Left-pad so every prompt ends at the last column.
                padded_merged_ids[i][max_ans - ans_crds[i]:] = cur_ids
        outs = self.neck.generate(padded_merged_ids, self.adapter_latent.flatten(0, 1))
        return outs
import torch
from torch import nn
from transformers.activations import ACT2FN
class Adapter(nn.Module):
    """
    Sequential bottleneck adapter.

    Computes ``up(silu(down(layernorm(x)))) * scaling + residual_input``.

    Args:
        input_size: width of the input and the output.
        down_sample: bottleneck width; defaults to ``input_size // 2``.
    """

    def __init__(
        self,
        input_size,
        down_sample=None,
    ):
        super().__init__()
        self.input_size = input_size
        # Without an explicit bottleneck width, use half the input width.
        self.down_sample = self.input_size // 2 if down_sample is None else down_sample
        self.adapter_norm_before = nn.LayerNorm(self.input_size)
        self.adapter_down = nn.Linear(self.input_size, self.down_sample)
        self.non_linearity = ACT2FN["silu"]
        # Projection back up to the input width.
        self.adapter_up = nn.Linear(self.down_sample, self.input_size)
        # Learnable output scale (He et al., 2021).
        self.scaling = nn.Parameter(torch.ones(1))
        for projection in (self.adapter_down, self.adapter_up):
            projection.apply(self._init_weights)

    def forward(self, x, residual_input):
        """Apply the bottleneck to ``x``, scale it and add ``residual_input``."""
        hidden = self.adapter_norm_before(x)
        hidden = self.non_linearity(self.adapter_down(hidden))
        scaled = self.adapter_up(hidden) * self.scaling
        return scaled + residual_input

    @staticmethod
    def _init_weights(module):
        """Initialise one submodule.

        * LayerNorm: weight set to 1, bias set to 0.
        * Linear / Embedding: weight drawn from N(0, 0.02).
        * Linear bias, when present: set to 0.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # std defaults to 0.02, this might need to be changed
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()