# ViAVSP-LLM_v1.0 / modelling.py
import torch
import logging
import contextlib
import numpy as np
import torch.nn as nn
from .resnet import ResNetEncoder
from .utils import compute_mask_indices
from .encoder import TransformerEncoder
from .configuration import AVHubertConfig, AVSPLLMConfig
from typing import Optional, Tuple, List, Dict, Any
from peft import get_peft_model, LoraConfig
from fairseq.modules import GradMultiply, LayerNorm
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import (
FeatureExtractionMixin,
PreTrainedModel,
BitsAndBytesConfig,
AutoModelForCausalLM,
GenerationConfig,
)
class AVHubertFeatureExtractor(FeatureExtractionMixin):
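    """Feature-extractor stub; adds no preprocessing beyond ``FeatureExtractionMixin``."""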
def __init__(self, **kwargs):
super().__init__(**kwargs)
class AVSPLLMFeatureExtractor(AVHubertFeatureExtractor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
class AVHubertVideoFeatureEncoder(nn.Module):
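    """ResNet front-end for the video stream: extracts per-frame visual features,
    projects them to ``encoder_embed_dim``, and optionally refines them with a
    small Transformer when ``config.sub_encoder_layers > 0``. Returns features
    shaped [B, encoder_embed_dim, T].
    """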
def __init__(self, config: AVHubertConfig) -> None:
super().__init__()
self.resnet = ResNetEncoder(relu_type=config.resnet_relu_type)
self.proj = nn.Linear(self.resnet.backend_out, config.encoder_embed_dim)
self.encoder = (
TransformerEncoder(config)
if config.sub_encoder_layers > 0
else None
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.resnet(x)
x = self.proj(x.transpose(1, 2))
if self.encoder is not None:
x = self.encoder(x)[0].transpose(1, 2)
else:
x = x.transpose(1, 2)
return x
class AVHubertAudioFeatureEncoder(nn.Module):
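    """Front-end for the acoustic stream: projects the per-frame acoustic features
    (``audio_feat_dim``) to ``encoder_embed_dim``, optionally refined by a small
    Transformer when ``config.sub_encoder_layers > 0``. Expects [B, F, T] and
    returns [B, encoder_embed_dim, T].
    """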
def __init__(self, config: AVHubertConfig) -> None:
super().__init__()
self.proj = nn.Linear(config.audio_feat_dim, config.encoder_embed_dim)
self.encoder = (
TransformerEncoder(config)
if config.sub_encoder_layers > 0
else None
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x.transpose(1, 2))
if self.encoder is not None:
x = self.encoder(x)[0].transpose(1, 2)
else:
x = x.transpose(1, 2)
return x
class AVHubertModel(PreTrainedModel):
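    """Audio-visual HuBERT backbone.

    Encodes the audio and video streams with modality-specific front-ends,
    fuses them (concat or add), and runs a shared Transformer encoder. The
    masking, target-alignment, and logit-computation paths below are only
    exercised during self-supervised pre-training; fine-tuning goes through
    ``extract_finetune``.
    """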
config_class = AVHubertConfig
def __init__(
self,
config: AVHubertConfig = AVHubertConfig(),
dictionaries: List = [None],
) -> None:
super().__init__(config=config)
label_rate = config.label_rate
feature_ds_rate = config.feature_ds_rate
sample_rate = config.sample_rate
        self.feat2tar_ratio = label_rate * feature_ds_rate / sample_rate
self.feature_extractor_video = AVHubertVideoFeatureEncoder(config)
self.feature_extractor_audio = AVHubertAudioFeatureEncoder(config)
if config.modality_fuse == "concat":
self.encoder_embed_dim = config.encoder_embed_dim * 2
elif config.modality_fuse == "add":
self.encoder_embed_dim = config.encoder_embed_dim
self.post_extract_proj = (
nn.Linear(self.encoder_embed_dim, config.encoder_embed_dim)
if self.encoder_embed_dim != config.encoder_embed_dim
else None
)
self.dropout_input = nn.Dropout(config.dropout_input)
self.dropout_features = nn.Dropout(config.dropout_features)
if self.config.final_dim > 0:
final_dim = config.final_dim
else:
final_dim = config.encoder_embed_dim
self.mask_emb = nn.Parameter(
torch.FloatTensor(config.audio_feat_dim).uniform_()
if config.masking_type == "input"
else torch.FloatTensor(config.encoder_embed_dim).uniform_()
)
self.encoder = TransformerEncoder(self.config)
self.layer_norm = LayerNorm(self.encoder_embed_dim)
self.target_glu = None
if config.target_glu:
self.target_glu = nn.Sequential(
nn.Linear(config.final_dim, config.final_dim * 2),
nn.GLU(),
)
if config.untie_final_proj:
self.final_proj = nn.Linear(
config.encoder_embed_dim,
final_dim * len(dictionaries),
)
else:
self.final_proj = nn.Linear(config.encoder_embed_dim, final_dim)
# modules below are not needed during fine-tuning
        if any(d is None for d in dictionaries):
            self.num_classes = config.num_classes
        else:
            self.num_classes = sum(len(d) for d in dictionaries)
self.label_embs_concat = nn.Parameter(
torch.FloatTensor(self.num_classes, final_dim)
)
nn.init.uniform_(self.label_embs_concat)
def apply_input_mask(
self,
x: torch.Tensor,
padding_mask: torch.Tensor,
target_list: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
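        """Mask spans of the raw input in place.

        Audio spans are replaced with the learned ``mask_emb``; video spans are
        zeroed (batch size 1) or filled with frames copied from another
        sequence/position according to ``config.selection_type``. Returns the
        (possibly) masked input and a boolean [B, T] mask of the selected positions.
        """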
B, C, T = x.shape[:3]
        is_audio = len(x.shape) == 3
if is_audio:
mask_prob = self.config.mask_prob_audio
mask_length = self.config.mask_length_audio
else:
mask_prob = self.config.mask_prob_image
mask_length = self.config.mask_length_image
if mask_prob > 0:
mask_indices, starts, ends, batch_indexes = compute_mask_indices(
(B, T),
padding_mask,
mask_prob,
mask_length,
self.config.mask_selection,
self.config.mask_other,
min_masks=2,
no_overlap=self.config.no_mask_overlap,
min_space=self.config.mask_min_space,
)
mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = x.transpose(1, 2).contiguous()  # video: [B, T, C, H, W]; audio: [B, T, C]
if B == 1:
x[mask_indices] = 0
elif is_audio:
x[mask_indices] = self.mask_emb
elif self.config.selection_type == "same_other_seq":
perm = (torch.arange(B) + torch.randint(low=1, high=B, size=(1,))) % B
x_perm = x[perm]
x[mask_indices] = x_perm[mask_indices]
elif self.config.selection_type == "same_seq":
batch_indexes_, other_indexes = [], []
for batch_index, start, end in zip(batch_indexes, starts, ends):
length = end - start
other_start = np.setdiff1d(
np.arange(T), np.arange(max(0, start - length), end)
)
if len(other_start) > 0:
other_start = np.random.choice(other_start, size=1)
else:
other_start = 0
other_end = other_start + length
other_indexes.append(
np.arange(other_start, other_end).clip(max=T - 1)
)
batch_indexes_.append(
np.zeros([length], dtype=np.int64) + batch_index
)
batch_indexes = np.concatenate(batch_indexes_)
other_indexes = np.concatenate(other_indexes)
x[mask_indices] = x[batch_indexes, other_indexes]
x = x.transpose(1, 2).contiguous()
else:
mask_indices = None
if self.config.mask_channel_prob > 0:
logging.info("No mask channel prob for input masking")
return x, mask_indices
def apply_feature_mask(
self,
x: torch.Tensor,
padding_mask: torch.Tensor,
target_list: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
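        """Mask the fused encoder features rather than the raw inputs.

        Time-mask spans are replaced with ``mask_emb``; optional channel masking
        zeroes whole feature dimensions. Returns the masked features and a
        boolean [B, T] time-mask (``None`` when ``mask_prob`` is 0).
        """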
B, T, C = x.shape
assert all((
self.config.mask_prob_audio == self.config.mask_prob_image,
self.config.mask_length_audio == self.config.mask_length_image,
)), "masking prob/length for image/audio be same for feature masking"
mask_prob = self.config.mask_prob_audio
mask_length = self.config.mask_length_image
if mask_prob > 0:
mask_indices, _, _, _ = compute_mask_indices(
(B, T),
padding_mask,
mask_prob,
mask_length,
self.config.mask_selection,
self.config.mask_other,
min_masks=2,
no_overlap=self.config.no_mask_overlap,
min_space=self.config.mask_min_space,
)
mask_indices = torch.from_numpy(mask_indices).to(x.device)
x[mask_indices] = self.mask_emb
else:
mask_indices = None
if self.config.mask_channel_prob > 0:
mask_channel_indices, _, _, _ = compute_mask_indices(
(B, C),
None,
self.config.mask_channel_prob,
self.config.mask_channel_length,
self.config.mask_channel_selection,
self.config.mask_channel_other,
no_overlap=self.config.no_mask_channel_overlap,
min_space=self.config.mask_channel_min_space,
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x[mask_channel_indices] = 0
return x, mask_indices
def forward_features(
self,
source: Dict[str, torch.Tensor],
modality: str,
) -> torch.Tensor:
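        """Run the modality-specific front-end, optionally scaling its gradients
        by ``feature_grad_mult`` (a value of 0 disables gradients entirely)."""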
        extractor = getattr(self, f"feature_extractor_{modality}")
if self.config.feature_grad_mult > 0:
features = extractor(source)
if self.config.feature_grad_mult != 1.0:
features = GradMultiply.apply(features, self.config.feature_grad_mult)
else:
with torch.no_grad():
features = extractor(source)
return features
def forward_targets(
self,
features: torch.Tensor,
mask_indices: torch.Tensor,
target_list: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
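        """Align label sequences to the feature frame rate via ``feat2tar_ratio``."""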
# Trim features to ensure labels exist and then get aligned labels
feat_tsz = features.size(2)
targ_tsz = min([t.size(1) for t in target_list])
if self.feat2tar_ratio * feat_tsz > targ_tsz:
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
features = features[..., :feat_tsz]
if mask_indices is not None:
mask_indices = mask_indices[..., :feat_tsz]
target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
target_list = [t[:, target_inds.long()] for t in target_list]
return features, mask_indices, target_list
def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
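        """Downsample the frame-level padding mask to the feature rate: a feature
        frame counts as padding only if all raw frames it covers are padding."""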
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
padding_mask = padding_mask.all(-1)
return padding_mask
def compute_logits(self, feats: torch.Tensor, emb_mat: torch.Tensor) -> torch.Tensor:
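        """Score features against the label-embedding matrix with dot-product or
        cosine similarity, scaled by ``1 / logit_temp``."""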
# feats: [B, T, F], emb_mat: [V, F]
if self.config.sim_type == "dot":
logits = torch.matmul(feats, emb_mat.transpose(0, 1))
elif self.config.sim_type == "cosine":
batch_size, timesteps, emb_dim = feats.size()
feats_ = feats.view(-1, emb_dim)
# [B*T, V]
nom = (feats_.unsqueeze(dim=1) * emb_mat.unsqueeze(dim=0)).sum(dim=-1)
# [B*T, V]
denom = (
(feats_**2).sum(dim=-1).sqrt().unsqueeze(dim=1)
* (emb_mat**2).sum(dim=-1).sqrt().unsqueeze(dim=0)
)
logits = (nom / denom.clamp(min=1e-6)).view(batch_size, timesteps, -1)
else:
raise NotImplementedError
logits = logits / self.config.logit_temp
return logits
def forward(
self,
source: Dict[str, torch.Tensor],
target_list: Optional[List[torch.Tensor]] = None,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = True,
features_only: bool = False,
output_layer: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
"""output layer is 1-based"""
src_audio, src_video = source["audio"], source["video"]
        if mask and self.config.masking_type == "input":
src_video, mask_indices_video = self.apply_input_mask(
src_video, padding_mask, target_list
)
src_audio, mask_indices_audio = self.apply_input_mask(
src_audio, padding_mask, target_list
)
mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video)
else:
src_audio, src_video, mask_indices = src_audio, src_video, None
# [B, F, T]
features_audio = self.forward_features(src_audio, modality="audio")
features_video = self.forward_features(src_video, modality="video")
if self.training:
modality_drop_prob, audio_drop_prob = np.random.random(), np.random.random()
if modality_drop_prob < self.config.modality_dropout:
if audio_drop_prob < self.config.audio_dropout:
features_audio = 0 * features_audio
else:
features_video = 0 * features_video
if self.config.modality_fuse == "concat":
features = torch.cat([features_audio, features_video], dim=1)
elif self.config.modality_fuse == "add":
features = features_audio + features_video
if target_list is not None:
features, mask_indices, target_list = self.forward_targets(
features, mask_indices, target_list
)
features_pen = features.float().pow(2).mean()
features = features.transpose(1, 2)
features = self.layer_norm(features)
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
features = self.dropout_input(features)
if self.config.masking_type == "feature" and mask:
x, mask_indices = self.apply_feature_mask(
features, padding_mask, target_list
)
else:
x = features
# feature: (B, T, D), float
# target: (B, T), long
# x: (B, T, D), float
# padding_mask: (B, T), bool
# mask_indices: (B, T), bool
x, _ = self.encoder(
x,
padding_mask=padding_mask,
layer=None if output_layer is None else output_layer - 1,
)
if features_only:
return {"x": x, "padding_mask": padding_mask, "features": features}
        # self.num_classes may be a single int (fine-tuning config) or a
        # per-dictionary list (pre-training); normalise to a list here.
        num_classes_list = (
            list(self.num_classes)
            if isinstance(self.num_classes, (list, tuple))
            else [self.num_classes]
        )
        label_embs_list = self.label_embs_concat.split(num_classes_list, 0)
        proj_x = self.final_proj(x)
        if self.config.untie_final_proj:
            proj_x_list = proj_x.chunk(len(num_classes_list), dim=-1)
        else:
            proj_x_list = [proj_x for _ in num_classes_list]
        # [[B*T, V]]
        logit_list = [
            self.compute_logits(proj, emb).view(-1, num_class)
            for proj, emb, num_class in zip(
                proj_x_list, label_embs_list, num_classes_list
            )
        ]
mask = torch.logical_and(mask_indices, ~padding_mask).view(-1)
unmask = torch.logical_and(~mask_indices, ~padding_mask).view(-1) # [B*T]
logit_m_list = [logit[mask] for logit in logit_list]
logit_u_list = [logit[unmask] for logit in logit_list]
target_m_list = [target.view(-1)[mask].long() for target in target_list]
target_u_list = [target.view(-1)[unmask].long() for target in target_list]
return {
"logit_m_list": logit_m_list,
"logit_u_list": logit_u_list,
"target_m_list": target_m_list,
"target_u_list": target_u_list,
"padding_mask": padding_mask,
"features_pen": features_pen,
}
def extract_features(
self,
source: Dict[str, torch.Tensor],
padding_mask: Optional[torch.Tensor] = None,
mask: bool = False,
ret_conv: bool = False,
output_layer: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
res = self.forward(
source,
padding_mask=padding_mask,
mask=mask,
features_only=True,
output_layer=output_layer,
)
feature = res["features"] if ret_conv else res["x"]
return feature, res["padding_mask"]
def extract_units(
self,
source: Dict[str, torch.Tensor],
        padding_mask: Optional[torch.Tensor] = None,
mask: bool = False,
ret_conv: bool = False,
output_layer: Optional[int] = None,
    ) -> torch.Tensor:
res = self.forward(
source,
padding_mask=padding_mask,
mask=mask,
features_only=True,
output_layer=None,
)
feature = res["features"] if ret_conv else res["x"]
proj_x = self.final_proj(feature)
        # units: [B, T] indices of the nearest label embedding (cluster IDs)
        units = torch.matmul(
            proj_x, self.label_embs_concat.transpose(0, 1)
        ).argmax(dim=-1)
return units
def extract_finetune(
self,
source: Dict[str, torch.Tensor],
        padding_mask: Optional[torch.Tensor] = None,
mask: bool = False,
ret_conv: bool = False,
output_layer: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
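        """Encode audio and/or video for fine-tuning, without computing
        pre-training targets. A missing modality is replaced by zeros so the
        fusion step still sees both streams. Returns the Transformer-encoded
        features and the downsampled padding mask."""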
src_audio, src_video = source["audio"], source["video"]
if mask and self.config.masking_type == "input":
src_video, _ = self.apply_input_mask(
src_video, padding_mask, target_list=None
)
src_audio, _ = self.apply_input_mask(
src_audio, padding_mask, target_list=None
)
else:
src_audio, src_video, _ = src_audio, src_video, None
# features: [B, F, T]
if src_audio is not None and src_video is None:
features_audio = self.forward_features(
src_audio, modality="audio"
)
features_video = features_audio.new_zeros(
features_audio.size(0),
self.encoder_embed_dim,
features_audio.size(-1)
)
elif src_audio is None and src_video is not None:
features_video = self.forward_features(src_video, modality="video")
features_audio = features_video.new_zeros(
features_video.size(0),
self.encoder_embed_dim,
features_video.size(-1)
)
elif src_audio is not None and src_video is not None:
features_video = self.forward_features(src_video, modality="video")
features_audio = self.forward_features(
src_audio, modality="audio"
)
if self.config.modality_fuse == "concat":
features = torch.cat([features_audio, features_video], dim=1)
elif self.config.modality_fuse == "add":
features = features_audio + features_video
features = features.transpose(1, 2)
features = self.layer_norm(features)
unmasked_features = features.clone()
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
features = self.dropout_input(features)
unmasked_features = self.dropout_features(unmasked_features)
# feature: (B, T, D), float
# target: (B, T), long
# x: (B, T, D), float
# padding_mask: (B, T), bool
# mask_indices: (B, T), bool
x, _ = self.encoder(
features,
padding_mask=padding_mask,
layer=None if output_layer is None else output_layer - 1,
)
return x, padding_mask
def get_extra_losses(
self,
net_output: Dict[str, torch.Tensor],
) -> Tuple[List[torch.Tensor], List[str]]:
extra_losses = []
names = []
if "features_pen" in net_output:
extra_losses.append(net_output["features_pen"])
names.append("features_pen")
return extra_losses, names
def remove_pretraining_modules(self) -> None:
self.target_glu = None
self.final_proj = None
def compute_nce(
self,
x: torch.Tensor,
pos: torch.Tensor,
negs: torch.Tensor,
) -> torch.Tensor:
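        """Cosine-similarity logits for the NCE objective: the positive is
        prepended to the negatives and duplicated negatives are masked to -inf."""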
neg_is_pos = (pos == negs).all(-1)
pos = pos.unsqueeze(0)
targets = torch.cat([pos, negs], dim=0)
logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
logits /= self.config.logit_temp
if neg_is_pos.any():
logits[1:][neg_is_pos] = float("-inf")
logits = logits.transpose(0, 1) # (num_x, num_cls+1)
return logits
class HubertEncoderWrapper(nn.Module):
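    """Thin wrapper exposing ``AVHubertModel.extract_finetune`` through a
    fairseq-style encoder interface (``encoder_out`` / ``encoder_padding_mask``)."""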
def __init__(
self,
config: AVHubertConfig,
dictionaries: List = [None],
) -> None:
super().__init__()
self.w2v_model = AVHubertModel(config, dictionaries)
def forward(
self,
source: Dict[str, torch.Tensor],
padding_mask: torch.Tensor,
**kwargs,
) -> Dict[str, torch.Tensor]:
w2v_args = {
"source": source,
"padding_mask": padding_mask,
}
x, padding_mask = self.w2v_model.extract_finetune(**w2v_args)
return {
"encoder_out": x, # T x B x C
"encoder_padding_mask": padding_mask, # B x T
"padding_mask": padding_mask,
}
def reorder_encoder_out(
self,
encoder_out: Dict[str, torch.Tensor],
new_order: torch.Tensor,
) -> Dict[str, torch.Tensor]:
if encoder_out["encoder_out"] is not None:
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
1, new_order
)
if encoder_out["encoder_padding_mask"] is not None:
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(0, new_order)
if encoder_out["padding_mask"] is not None:
encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select(
0, new_order
)
return encoder_out
class AVSPLLMModel(PreTrainedModel):
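    """AVSP-LLM: an AV-HuBERT encoder bridged to a causal LLM decoder.

    The encoder output is projected to the LLM embedding size, compressed by
    averaging frames within each cluster run, and placed after the text
    instruction in the decoder's input embeddings. The decoder is loaded in
    4-bit NF4 via bitsandbytes and adapted with LoRA on its attention
    projections (q/k/v).
    """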
config_class = AVSPLLMConfig
def __init__(
self,
config: AVSPLLMConfig = AVSPLLMConfig(),
dictionaries: List = [None],
) -> None:
super().__init__(config=config)
self.encoder = HubertEncoderWrapper(config, dictionaries)
self.encoder.w2v_model.remove_pretraining_modules()
self.avfeat_to_llm = nn.Linear(
config.encoder_embed_dim, config.decoder_embed_dim
)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
decoder_4bit = AutoModelForCausalLM.from_pretrained(
config.llm_ckpt_path,
quantization_config=bnb_config,
)
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj", "k_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
self.decoder = get_peft_model(decoder_4bit, lora_config)
self.decoder.print_trainable_parameters()
def forward(
self,
source: Dict[str, torch.Tensor],
target_list: torch.Tensor,
padding_mask: torch.Tensor,
**kwargs,
) -> CausalLMOutputWithPast:
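        """Training forward pass.

        The encoder runs under ``torch.no_grad`` until ``freeze_finetune_updates``
        steps have passed. The decoder input is assembled as
        [instruction embeddings; cluster-averaged AV features; label embeddings],
        and only the label positions are supervised (everything else, including
        padding labels equal to 0, is set to -100).
        """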
ft = self.config.freeze_finetune_updates <= kwargs.get("num_updates", -1)
with torch.no_grad() if not ft else contextlib.ExitStack():
output = self.encoder(source, padding_mask, **kwargs)
output["encoder_out"] = self.avfeat_to_llm(output["encoder_out"])
        # Average consecutive frames that share the same cluster label,
        # reducing the AV feature sequence to one vector per cluster run.
        cluster_counts = source["cluster_counts"][0]  # 1-D tensor of run lengths
        results_tensor = []
        start_idx = 0
        for cluster_num in cluster_counts:
            end_idx = start_idx + cluster_num
            segment = output["encoder_out"][:, start_idx:end_idx, :]
            mean_tensor = torch.mean(segment, dim=1, keepdim=True)
            results_tensor.append(mean_tensor)
            start_idx = end_idx
assert cluster_counts.sum().item() == output["encoder_out"].size()[1], \
f"{cluster_counts.sum().item()} != {output['encoder_out'].size()[1]}"
reduced_enc_out = torch.cat(results_tensor, dim=1)
B, T, D = reduced_enc_out.size()
instruction = source["text"]
instruction_embedding = self.decoder.model.model.embed_tokens(instruction)
labels = target_list.clone()
labels_embedding = self.decoder.model.model.embed_tokens(labels)
llm_input = torch.cat(
(instruction_embedding, reduced_enc_out, labels_embedding), dim=1
)
llm_labels = labels.clone()
llm_labels[llm_labels == 0] = -100
_, instruction_embedding_t, _ = instruction_embedding.size()
target_ids = (
torch.full((B, T + instruction_embedding_t), -100).long().to(labels.device)
)
llm_labels = torch.cat((target_ids, llm_labels), dim=1)
return self.decoder(
inputs_embeds=llm_input, labels=llm_labels, return_dict=True
)
@torch.no_grad()
def generate(
self,
inputs: Optional[Dict[str, torch.Tensor]] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs,
) -> Any:
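        """Generate text from the instruction and the cluster-averaged AV
        features; mirrors the input construction in ``forward`` but without
        label embeddings. Decoding is delegated to the wrapped PEFT decoder."""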
output = self.encoder(**inputs)
output["encoder_out"] = self.avfeat_to_llm(output["encoder_out"])
        # Same cluster-run averaging as in forward(): one AV embedding per run.
        cluster_counts = inputs["source"]["cluster_counts"][0]  # 1-D tensor of run lengths
        results_tensor = []
        start_idx = 0
        for cluster_num in cluster_counts:
            end_idx = start_idx + cluster_num
            segment = output["encoder_out"][:, start_idx:end_idx, :]
            mean_tensor = torch.mean(segment, dim=1, keepdim=True)
            results_tensor.append(mean_tensor)
            start_idx = end_idx
assert cluster_counts.sum().item() == output["encoder_out"].size()[1]
reduced_enc_out = torch.cat(results_tensor, dim=1)
B, T, D = reduced_enc_out.size()
instruction = inputs["source"]["text"]
instruction_embedding = self.decoder.model.model.embed_tokens(instruction)
llm_input = torch.cat((instruction_embedding, reduced_enc_out), dim=1)
self.decoder.config.use_cache = True
return self.decoder.generate(
inputs_embeds=llm_input,
            generation_config=generation_config,
**kwargs,
)
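# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the checkpoint path and tensor shapes are
# assumptions, not part of this file):
#
#   model = AVSPLLMModel.from_pretrained("path/to/ViAVSP-LLM_v1.0")
#   inputs = {
#       "source": {
#           "audio": audio_feats,            # [B, audio_feat_dim, T]
#           "video": video_frames,           # [B, C, T, H, W]
#           "cluster_counts": [run_lengths], # per-sample tensor of cluster run lengths
#           "text": instruction_ids,         # tokenized instruction prompt
#       },
#       "padding_mask": padding_mask,        # frame-level padding mask
#   }
#   output_ids = model.generate(inputs, generation_config=GenerationConfig(num_beams=5))
# ---------------------------------------------------------------------------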