import contextlib
import logging
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from fairseq.modules import GradMultiply, LayerNorm
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    FeatureExtractionMixin,
    GenerationConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration import AVHubertConfig, AVSPLLMConfig
from .encoder import TransformerEncoder
from .resnet import ResNetEncoder
from .utils import compute_mask_indices


class AVHubertFeatureExtractor(FeatureExtractionMixin):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class AVSPLLMFeatureExtractor(AVHubertFeatureExtractor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class AVHubertVideoFeatureEncoder(nn.Module):
    def __init__(self, config: AVHubertConfig) -> None:
        super().__init__()
        self.resnet = ResNetEncoder(relu_type=config.resnet_relu_type)
        self.proj = nn.Linear(self.resnet.backend_out, config.encoder_embed_dim)
        self.encoder = (
            TransformerEncoder(config)
            if config.sub_encoder_layers > 0
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.resnet(x)
        x = self.proj(x.transpose(1, 2))
        if self.encoder is not None:
            x = self.encoder(x)[0].transpose(1, 2)
        else:
            x = x.transpose(1, 2)
        return x


class AVHubertAudioFeatureEncoder(nn.Module):
    def __init__(self, config: AVHubertConfig) -> None:
        super().__init__()
        self.proj = nn.Linear(config.audio_feat_dim, config.encoder_embed_dim)
        self.encoder = (
            TransformerEncoder(config)
            if config.sub_encoder_layers > 0
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x.transpose(1, 2))
        if self.encoder is not None:
            x = self.encoder(x)[0].transpose(1, 2)
        else:
            x = x.transpose(1, 2)
        return x


class AVHubertModel(PreTrainedModel):
    config_class = AVHubertConfig

    def __init__(
        self,
        config: AVHubertConfig = AVHubertConfig(),
        dictionaries: List = [None],
    ) -> None:
        super().__init__(config=config)
        label_rate = config.label_rate
        feature_ds_rate = config.feature_ds_rate
        sample_rate = config.sample_rate
        self.feat2tar_ratio = label_rate * feature_ds_rate / sample_rate

        self.feature_extractor_video = AVHubertVideoFeatureEncoder(config)
        self.feature_extractor_audio = AVHubertAudioFeatureEncoder(config)

        if config.modality_fuse == "concat":
            self.encoder_embed_dim = config.encoder_embed_dim * 2
        elif config.modality_fuse == "add":
            self.encoder_embed_dim = config.encoder_embed_dim
        self.post_extract_proj = (
            nn.Linear(self.encoder_embed_dim, config.encoder_embed_dim)
            if self.encoder_embed_dim != config.encoder_embed_dim
            else None
        )

        self.dropout_input = nn.Dropout(config.dropout_input)
        self.dropout_features = nn.Dropout(config.dropout_features)

        if self.config.final_dim > 0:
            final_dim = config.final_dim
        else:
            final_dim = config.encoder_embed_dim

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(config.audio_feat_dim).uniform_()
            if config.masking_type == "input"
            else torch.FloatTensor(config.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(self.config)
        self.layer_norm = LayerNorm(self.encoder_embed_dim)

        self.target_glu = None
        if config.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(config.final_dim, config.final_dim * 2),
                nn.GLU(),
            )

        if config.untie_final_proj:
            self.final_proj = nn.Linear(
                config.encoder_embed_dim,
                final_dim * len(dictionaries),
            )
        else:
            self.final_proj = nn.Linear(config.encoder_embed_dim, final_dim)

        # modules below are not needed during fine-tuning
        if any(d is None for d in dictionaries):
            self.num_classes = config.num_classes
        else:
            self.num_classes = sum(len(d) for d in dictionaries)
        self.label_embs_concat = nn.Parameter(
            torch.FloatTensor(self.num_classes, final_dim)
        )
        nn.init.uniform_(self.label_embs_concat)

    def apply_input_mask(
        self,
        x: torch.Tensor,
        padding_mask: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        B, C, T = x.shape[:3]
        is_audio = len(x.shape) == 3
        if is_audio:
            mask_prob = self.config.mask_prob_audio
            mask_length = self.config.mask_length_audio
        else:
            mask_prob = self.config.mask_prob_image
            mask_length = self.config.mask_length_image

        if mask_prob > 0:
            mask_indices, starts, ends, batch_indexes = compute_mask_indices(
                (B, T),
                padding_mask,
                mask_prob,
                mask_length,
                self.config.mask_selection,
                self.config.mask_other,
                min_masks=2,
                no_overlap=self.config.no_mask_overlap,
                min_space=self.config.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = x.transpose(1, 2).contiguous()  # [B, T, C, H, W]
            if B == 1:
                x[mask_indices] = 0
            elif is_audio:
                x[mask_indices] = self.mask_emb
            elif self.config.selection_type == "same_other_seq":
                perm = (torch.arange(B) + torch.randint(low=1, high=B, size=(1,))) % B
                x_perm = x[perm]
                x[mask_indices] = x_perm[mask_indices]
            elif self.config.selection_type == "same_seq":
                batch_indexes_, other_indexes = [], []
                for batch_index, start, end in zip(batch_indexes, starts, ends):
                    length = end - start
                    other_start = np.setdiff1d(
                        np.arange(T), np.arange(max(0, start - length), end)
                    )
                    if len(other_start) > 0:
                        other_start = np.random.choice(other_start, size=1)
                    else:
                        other_start = 0
                    other_end = other_start + length
                    other_indexes.append(
                        np.arange(other_start, other_end).clip(max=T - 1)
                    )
                    batch_indexes_.append(
                        np.zeros([length], dtype=np.int64) + batch_index
                    )
                batch_indexes = np.concatenate(batch_indexes_)
                other_indexes = np.concatenate(other_indexes)
                x[mask_indices] = x[batch_indexes, other_indexes]
            x = x.transpose(1, 2).contiguous()
        else:
            mask_indices = None

        if self.config.mask_channel_prob > 0:
            logging.info("No mask channel prob for input masking")
        return x, mask_indices

    def apply_feature_mask(
        self,
        x: torch.Tensor,
        padding_mask: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        B, T, C = x.shape
        assert all((
            self.config.mask_prob_audio == self.config.mask_prob_image,
            self.config.mask_length_audio == self.config.mask_length_image,
        )), "masking prob/length must be the same for image/audio in feature masking"

        mask_prob = self.config.mask_prob_audio
        mask_length = self.config.mask_length_image
        if mask_prob > 0:
            mask_indices, _, _, _ = compute_mask_indices(
                (B, T),
                padding_mask,
                mask_prob,
                mask_length,
                self.config.mask_selection,
                self.config.mask_other,
                min_masks=2,
                no_overlap=self.config.no_mask_overlap,
                min_space=self.config.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.config.mask_channel_prob > 0:
            mask_channel_indices, _, _, _ = compute_mask_indices(
                (B, C),
                None,
                self.config.mask_channel_prob,
                self.config.mask_channel_length,
                self.config.mask_channel_selection,
                self.config.mask_channel_other,
                no_overlap=self.config.no_mask_channel_overlap,
                min_space=self.config.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0
        return x, mask_indices
    def forward_features(
        self,
        source: Dict[str, torch.Tensor],
        modality: str,
    ) -> torch.Tensor:
        extractor = getattr(self, f"feature_extractor_{modality}")
        if self.config.feature_grad_mult > 0:
            features = extractor(source)
            if self.config.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.config.feature_grad_mult)
        else:
            with torch.no_grad():
                features = extractor(source)
        return features

    def forward_targets(
        self,
        features: torch.Tensor,
        mask_indices: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        # Trim features to ensure labels exist and then get aligned labels
        feat_tsz = features.size(2)
        targ_tsz = min(t.size(1) for t in target_list)
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
            if mask_indices is not None:
                mask_indices = mask_indices[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, mask_indices, target_list

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def compute_logits(
        self,
        feats: torch.Tensor,
        emb_mat: torch.Tensor,
    ) -> torch.Tensor:
        # feats: [B, T, F], emb_mat: [V, F]
        if self.config.sim_type == "dot":
            logits = torch.matmul(feats, emb_mat.transpose(0, 1))
        elif self.config.sim_type == "cosine":
            batch_size, timesteps, emb_dim = feats.size()
            feats_ = feats.view(-1, emb_dim)
            # [B*T, V]
            nom = (feats_.unsqueeze(dim=1) * emb_mat.unsqueeze(dim=0)).sum(dim=-1)
            # [B*T, V]
            denom = (
                (feats_ ** 2).sum(dim=-1).sqrt().unsqueeze(dim=1)
                * (emb_mat ** 2).sum(dim=-1).sqrt().unsqueeze(dim=0)
            )
            logits = (nom / denom.clamp(min=1e-6)).view(batch_size, timesteps, -1)
        else:
            raise NotImplementedError
        logits = logits / self.config.logit_temp
        return logits

    def forward(
        self,
        source: Dict[str, torch.Tensor],
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """output layer is 1-based"""
        src_audio, src_video = source["audio"], source["video"]
        if mask and self.config.masking_type == "input":
            src_video, mask_indices_video = self.apply_input_mask(
                src_video, padding_mask, target_list
            )
            src_audio, mask_indices_audio = self.apply_input_mask(
                src_audio, padding_mask, target_list
            )
            mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video)
        else:
            mask_indices = None

        # [B, F, T]
        features_audio = self.forward_features(src_audio, modality="audio")
        features_video = self.forward_features(src_video, modality="video")
        if self.training:
            modality_drop_prob, audio_drop_prob = np.random.random(), np.random.random()
            if modality_drop_prob < self.config.modality_dropout:
                if audio_drop_prob < self.config.audio_dropout:
                    features_audio = 0 * features_audio
                else:
                    features_video = 0 * features_video

        if self.config.modality_fuse == "concat":
            features = torch.cat([features_audio, features_video], dim=1)
        elif self.config.modality_fuse == "add":
            features = features_audio + features_video

        if target_list is not None:
            features, mask_indices, target_list = self.forward_targets(
                features, mask_indices, target_list
            )
        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        if self.config.masking_type == "feature" and mask:
            x, mask_indices = self.apply_feature_mask(
                features, padding_mask, target_list
            )
        else:
            x = features

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )

        if features_only:
            return {"x": x, "padding_mask": padding_mask, "features": features}

        # num_classes is a single merged label set, so splitting yields one
        # embedding table and one projection
        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
        proj_x = self.final_proj(x)
        if self.config.untie_final_proj:
            proj_x_list = proj_x.chunk(len(label_embs_list), dim=-1)
        else:
            proj_x_list = [proj_x for _ in label_embs_list]

        # [[B*T, V]]
        logit_list = [
            self.compute_logits(proj, emb).view(-1, emb.size(0))
            for proj, emb in zip(proj_x_list, label_embs_list)
        ]

        mask = torch.logical_and(mask_indices, ~padding_mask).view(-1)
        unmask = torch.logical_and(~mask_indices, ~padding_mask).view(-1)  # [B*T]

        logit_m_list = [logit[mask] for logit in logit_list]
        logit_u_list = [logit[unmask] for logit in logit_list]
        target_m_list = [target.view(-1)[mask].long() for target in target_list]
        target_u_list = [target.view(-1)[unmask].long() for target in target_list]

        return {
            "logit_m_list": logit_m_list,
            "logit_u_list": logit_u_list,
            "target_m_list": target_m_list,
            "target_u_list": target_u_list,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }

    def extract_features(
        self,
        source: Dict[str, torch.Tensor],
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer,
        )
        feature = res["features"] if ret_conv else res["x"]
        return feature, res["padding_mask"]

    def extract_units(
        self,
        source: Dict[str, torch.Tensor],
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> torch.Tensor:
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=None,
        )
        feature = res["features"] if ret_conv else res["x"]
        proj_x = self.final_proj(feature)
        # [B, T]
        units = (
            torch
            .matmul(proj_x, self.label_embs_concat.transpose(0, 1))
            .argmax(dim=-1)
        )
        return units

    def extract_finetune(
        self,
        source: Dict[str, torch.Tensor],
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        src_audio, src_video = source["audio"], source["video"]
        if mask and self.config.masking_type == "input":
            src_video, _ = self.apply_input_mask(
                src_video, padding_mask, target_list=None
            )
            src_audio, _ = self.apply_input_mask(
                src_audio, padding_mask, target_list=None
            )

        # features: [B, F, T]
        if src_audio is not None and src_video is None:
            features_audio = self.forward_features(src_audio, modality="audio")
            features_video = features_audio.new_zeros(
                features_audio.size(0),
                self.encoder_embed_dim,
                features_audio.size(-1),
            )
        elif src_audio is None and src_video is not None:
            features_video = self.forward_features(src_video, modality="video")
            features_audio = features_video.new_zeros(
                features_video.size(0),
                self.encoder_embed_dim,
                features_video.size(-1),
            )
        elif src_audio is not None and src_video is not None:
            features_video = self.forward_features(src_video, modality="video")
            features_audio = self.forward_features(src_audio, modality="audio")

        if self.config.modality_fuse == "concat":
            features = torch.cat([features_audio, features_video], dim=1)
        elif self.config.modality_fuse == "add":
            features = features_audio + features_video

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, _ = self.encoder(
            features,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )
        return x, padding_mask

    def get_extra_losses(
        self,
        net_output: Dict[str, torch.Tensor],
    ) -> Tuple[List[torch.Tensor], List[str]]:
        extra_losses = []
        names = []
        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")
        return extra_losses, names

    def remove_pretraining_modules(self) -> None:
        self.target_glu = None
        self.final_proj = None

    def compute_nce(
        self,
        x: torch.Tensor,
        pos: torch.Tensor,
        negs: torch.Tensor,
    ) -> torch.Tensor:
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
        logits /= self.config.logit_temp
        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)  # (num_x, num_cls+1)
        return logits


class HubertEncoderWrapper(nn.Module):
    def __init__(
        self,
        config: AVHubertConfig,
        dictionaries: List = [None],
    ) -> None:
        super().__init__()
        self.w2v_model = AVHubertModel(config, dictionaries)

    def forward(
        self,
        source: Dict[str, torch.Tensor],
        padding_mask: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
        }
        x, padding_mask = self.w2v_model.extract_finetune(**w2v_args)
        return {
            "encoder_out": x,  # T x B x C
            "encoder_padding_mask": padding_mask,  # B x T
            "padding_mask": padding_mask,
        }

    def reorder_encoder_out(
        self,
        encoder_out: Dict[str, torch.Tensor],
        new_order: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
                1, new_order
            )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)
        if encoder_out["padding_mask"] is not None:
            encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select(
                0, new_order
            )
        return encoder_out


class AVSPLLMModel(PreTrainedModel):
    config_class = AVSPLLMConfig

    def __init__(
        self,
        config: AVSPLLMConfig = AVSPLLMConfig(),
        dictionaries: List = [None],
    ) -> None:
        super().__init__(config=config)
        self.encoder = HubertEncoderWrapper(config, dictionaries)
        self.encoder.w2v_model.remove_pretraining_modules()
        self.avfeat_to_llm = nn.Linear(
            config.encoder_embed_dim, config.decoder_embed_dim
        )

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        decoder_4bit = AutoModelForCausalLM.from_pretrained(
            config.llm_ckpt_path,
            quantization_config=bnb_config,
        )

        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj", "k_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self.decoder = get_peft_model(decoder_4bit, lora_config)
        self.decoder.print_trainable_parameters()

    def forward(
        self,
        source: Dict[str, torch.Tensor],
        target_list: torch.Tensor,
        padding_mask: torch.Tensor,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        ft = self.config.freeze_finetune_updates <= kwargs.get("num_updates", -1)
        with torch.no_grad() if not ft else contextlib.ExitStack():
            output = self.encoder(source, padding_mask, **kwargs)

        output["encoder_out"] = self.avfeat_to_llm(output["encoder_out"])

        # Average-pool consecutive encoder frames that belong to the same
        # deduplicated cluster, so each cluster contributes one embedding.
        cluster_counts = source["cluster_counts"][0]  # tensor list
        results_tensor = []
        start_idx = 0
        for cluster_num in cluster_counts:
            end_idx = start_idx + cluster_num
            segment = output["encoder_out"][:, start_idx:end_idx, :]
            mean_tensor = torch.mean(segment, dim=1, keepdim=True)
            results_tensor.append(mean_tensor)
            start_idx = end_idx

        assert cluster_counts.sum().item() == output["encoder_out"].size()[1], \
            f"{cluster_counts.sum().item()} != {output['encoder_out'].size()[1]}"

        reduced_enc_out = torch.cat(results_tensor, dim=1)
        B, T, D = reduced_enc_out.size()

        instruction = source["text"]
        instruction_embedding = self.decoder.model.model.embed_tokens(instruction)

        labels = target_list.clone()
        labels_embedding = self.decoder.model.model.embed_tokens(labels)

        llm_input = torch.cat(
            (instruction_embedding, reduced_enc_out, labels_embedding), dim=1
        )

        llm_labels = labels.clone()
        llm_labels[llm_labels == 0] = -100

        _, instruction_embedding_t, _ = instruction_embedding.size()
        target_ids = (
            torch.full((B, T + instruction_embedding_t), -100).long().to(labels.device)
        )
        llm_labels = torch.cat((target_ids, llm_labels), dim=1)

        return self.decoder(
            inputs_embeds=llm_input, labels=llm_labels, return_dict=True
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[Dict[str, torch.Tensor]] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ) -> Any:
        output = self.encoder(**inputs)
        output["encoder_out"] = self.avfeat_to_llm(output["encoder_out"])

        # Same cluster-wise average pooling as in forward().
        cluster_counts = inputs["source"]["cluster_counts"][0]  # tensor list
        results_tensor = []
        start_idx = 0
        for cluster_num in cluster_counts:
            end_idx = start_idx + cluster_num
            segment = output["encoder_out"][:, start_idx:end_idx, :]
            mean_tensor = torch.mean(segment, dim=1, keepdim=True)
            results_tensor.append(mean_tensor)
            start_idx = end_idx

        assert cluster_counts.sum().item() == output["encoder_out"].size()[1]

        reduced_enc_out = torch.cat(results_tensor, dim=1)
        B, T, D = reduced_enc_out.size()

        instruction = inputs["source"]["text"]
        instruction_embedding = self.decoder.model.model.embed_tokens(instruction)

        llm_input = torch.cat((instruction_embedding, reduced_enc_out), dim=1)

        self.decoder.config.use_cache = True
        return self.decoder.generate(
            inputs_embeds=llm_input,
            generation_config=generation_config,
            **kwargs,
        )
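

if __name__ == "__main__":
    # Minimal, self-contained sketch of the cluster-wise pooling used in
    # AVSPLLMModel.forward()/generate() above, run on dummy tensors so it does
    # not require checkpoints or the ResNet/Transformer sub-modules. The shapes
    # (batch of 1, 10 encoder frames, 4-dim features) and the example
    # cluster_counts are illustrative assumptions, not values taken from the
    # real data pipeline.
    torch.manual_seed(0)
    encoder_out = torch.randn(1, 10, 4)          # [B, T, D] encoder features
    cluster_counts = torch.tensor([3, 2, 4, 1])  # run lengths, sum == T

    pooled = []
    start_idx = 0
    for cluster_num in cluster_counts:
        end_idx = start_idx + cluster_num
        # average the frames belonging to one deduplicated cluster
        pooled.append(encoder_out[:, start_idx:end_idx, :].mean(dim=1, keepdim=True))
        start_idx = end_idx

    assert cluster_counts.sum().item() == encoder_out.size(1)
    reduced_enc_out = torch.cat(pooled, dim=1)
    print(reduced_enc_out.shape)  # torch.Size([1, 4, 4]): one vector per cluster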