|
import torch |
|
import logging |
|
import contextlib |
|
import numpy as np |
|
import torch.nn as nn |
|
from .resnet import ResNetEncoder |
|
from .utils import compute_mask_indices |
|
from .encoder import TransformerEncoder |
|
from .configuration import AVHubertConfig, AVSPLLMConfig |
|
from typing import Optional, Tuple, List, Dict, Any |
|
from peft import get_peft_model, LoraConfig |
|
from fairseq.modules import GradMultiply, LayerNorm |
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
from transformers import ( |
|
FeatureExtractionMixin, |
|
PreTrainedModel, |
|
BitsAndBytesConfig, |
|
AutoModelForCausalLM, |
|
GenerationConfig, |
|
) |
|
|
|
|
|
class AVHubertFeatureExtractor(FeatureExtractionMixin): |
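    """Stub feature extractor: exists so AV-HuBERT checkpoints can ship a
    preprocessor config through `FeatureExtractionMixin`; no preprocessing
    is implemented here."""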
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
|
|
|
|
class AVSPLLMFeatureExtractor(AVHubertFeatureExtractor): |
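    """Stub feature extractor for AVSP-LLM; see `AVHubertFeatureExtractor`."""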
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
|
|
|
|
class AVHubertVideoFeatureEncoder(nn.Module): |
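    """Video front-end: a ResNet over the frame sequence, a linear projection
    to `encoder_embed_dim`, and an optional Transformer sub-encoder when
    `config.sub_encoder_layers > 0`."""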
|
def __init__(self, config: AVHubertConfig) -> None: |
|
super().__init__() |
|
self.resnet = ResNetEncoder(relu_type=config.resnet_relu_type) |
|
self.proj = nn.Linear(self.resnet.backend_out, config.encoder_embed_dim) |
|
self.encoder = ( |
|
TransformerEncoder(config) |
|
if config.sub_encoder_layers > 0 |
|
else None |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = self.resnet(x) |
|
x = self.proj(x.transpose(1, 2)) |
|
if self.encoder is not None: |
|
x = self.encoder(x)[0].transpose(1, 2) |
|
else: |
|
x = x.transpose(1, 2) |
|
return x |
|
|
|
|
|
class AVHubertAudioFeatureEncoder(nn.Module): |
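    """Audio front-end: a linear projection from `audio_feat_dim` to
    `encoder_embed_dim`, plus the same optional Transformer sub-encoder."""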
|
def __init__(self, config: AVHubertConfig) -> None: |
|
super().__init__() |
|
self.proj = nn.Linear(config.audio_feat_dim, config.encoder_embed_dim) |
|
self.encoder = ( |
|
TransformerEncoder(config) |
|
if config.sub_encoder_layers > 0 |
|
else None |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = self.proj(x.transpose(1, 2)) |
|
if self.encoder is not None: |
|
x = self.encoder(x)[0].transpose(1, 2) |
|
else: |
|
x = x.transpose(1, 2) |
|
return x |
|
|
|
|
|
class AVHubertModel(PreTrainedModel): |
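    """Audio-visual HuBERT. Modality-specific front-ends produce frame
    features that are fused (concatenated or added, per
    `config.modality_fuse`), masked at the input or feature level, and passed
    through a shared Transformer encoder; during pre-training the output is
    scored against learned cluster-label embeddings."""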
|
config_class = AVHubertConfig |
|
|
|
def __init__( |
|
self, |
|
        config: Optional[AVHubertConfig] = None,
        dictionaries: Optional[List] = None,
    ) -> None:
        # Avoid shared mutable default arguments: build fresh values per call.
        config = config if config is not None else AVHubertConfig()
        dictionaries = dictionaries if dictionaries is not None else [None]
        super().__init__(config=config)
|
label_rate = config.label_rate |
|
feature_ds_rate = config.feature_ds_rate |
|
sample_rate = config.sample_rate |
|
        self.feat2tar_ratio = label_rate * feature_ds_rate / sample_rate
|
|
|
self.feature_extractor_video = AVHubertVideoFeatureEncoder(config) |
|
self.feature_extractor_audio = AVHubertAudioFeatureEncoder(config) |
|
|
|
if config.modality_fuse == "concat": |
|
self.encoder_embed_dim = config.encoder_embed_dim * 2 |
|
elif config.modality_fuse == "add": |
|
self.encoder_embed_dim = config.encoder_embed_dim |
|
|
|
self.post_extract_proj = ( |
|
nn.Linear(self.encoder_embed_dim, config.encoder_embed_dim) |
|
if self.encoder_embed_dim != config.encoder_embed_dim |
|
else None |
|
) |
|
|
|
self.dropout_input = nn.Dropout(config.dropout_input) |
|
self.dropout_features = nn.Dropout(config.dropout_features) |
|
|
|
if self.config.final_dim > 0: |
|
final_dim = config.final_dim |
|
else: |
|
final_dim = config.encoder_embed_dim |
|
|
|
self.mask_emb = nn.Parameter( |
|
torch.FloatTensor(config.audio_feat_dim).uniform_() |
|
if config.masking_type == "input" |
|
else torch.FloatTensor(config.encoder_embed_dim).uniform_() |
|
) |
|
|
|
self.encoder = TransformerEncoder(self.config) |
|
self.layer_norm = LayerNorm(self.encoder_embed_dim) |
|
|
|
self.target_glu = None |
|
if config.target_glu: |
|
self.target_glu = nn.Sequential( |
|
nn.Linear(config.final_dim, config.final_dim * 2), |
|
nn.GLU(), |
|
) |
|
|
|
if config.untie_final_proj: |
|
self.final_proj = nn.Linear( |
|
config.encoder_embed_dim, |
|
final_dim * len(dictionaries), |
|
) |
|
else: |
|
self.final_proj = nn.Linear(config.encoder_embed_dim, final_dim) |
|
|
|
|
|
        # Keep per-dictionary class counts as a list so the logits can later
        # be split per target stream (`split`, `chunk`, and `zip` below expect
        # a list of sizes, not a single int).
        if any(d is None for d in dictionaries):
            self.num_classes = [config.num_classes]
        else:
            self.num_classes = [len(d) for d in dictionaries]
        self.label_embs_concat = nn.Parameter(
            torch.FloatTensor(sum(self.num_classes), final_dim)
        )
        nn.init.uniform_(self.label_embs_concat)
|
|
|
def apply_input_mask( |
|
self, |
|
x: torch.Tensor, |
|
padding_mask: torch.Tensor, |
|
target_list: List[torch.Tensor], |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
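        """Mask the raw input stream before the front-ends. Audio spans are
        replaced with the learned mask embedding; video spans are filled with
        frames from another sequence (`same_other_seq`) or another segment of
        the same sequence (`same_seq`). Returns the masked tensor and the
        boolean mask indices (None when the mask probability is 0)."""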
|
        B, C, T = x.shape[:3]
        is_audio = x.dim() == 3  # audio is (B, C, T); video carries extra spatial dims
|
|
|
if is_audio: |
|
mask_prob = self.config.mask_prob_audio |
|
mask_length = self.config.mask_length_audio |
|
else: |
|
mask_prob = self.config.mask_prob_image |
|
mask_length = self.config.mask_length_image |
|
|
|
if mask_prob > 0: |
|
mask_indices, starts, ends, batch_indexes = compute_mask_indices( |
|
(B, T), |
|
padding_mask, |
|
mask_prob, |
|
mask_length, |
|
self.config.mask_selection, |
|
self.config.mask_other, |
|
min_masks=2, |
|
no_overlap=self.config.no_mask_overlap, |
|
min_space=self.config.mask_min_space, |
|
) |
|
mask_indices = torch.from_numpy(mask_indices).to(x.device) |
|
x = x.transpose(1, 2).contiguous() |
|
if B == 1: |
|
x[mask_indices] = 0 |
|
elif is_audio: |
|
x[mask_indices] = self.mask_emb |
|
elif self.config.selection_type == "same_other_seq": |
|
perm = (torch.arange(B) + torch.randint(low=1, high=B, size=(1,))) % B |
|
x_perm = x[perm] |
|
x[mask_indices] = x_perm[mask_indices] |
|
elif self.config.selection_type == "same_seq": |
|
batch_indexes_, other_indexes = [], [] |
|
for batch_index, start, end in zip(batch_indexes, starts, ends): |
|
length = end - start |
|
other_start = np.setdiff1d( |
|
np.arange(T), np.arange(max(0, start - length), end) |
|
) |
|
if len(other_start) > 0: |
|
other_start = np.random.choice(other_start, size=1) |
|
else: |
|
other_start = 0 |
|
other_end = other_start + length |
|
other_indexes.append( |
|
np.arange(other_start, other_end).clip(max=T - 1) |
|
) |
|
batch_indexes_.append( |
|
np.zeros([length], dtype=np.int64) + batch_index |
|
) |
|
batch_indexes = np.concatenate(batch_indexes_) |
|
other_indexes = np.concatenate(other_indexes) |
|
x[mask_indices] = x[batch_indexes, other_indexes] |
|
x = x.transpose(1, 2).contiguous() |
|
else: |
|
mask_indices = None |
|
|
|
        if self.config.mask_channel_prob > 0:
            logging.info("channel masking is not applied when masking_type == 'input'")
|
return x, mask_indices |
|
|
|
def apply_feature_mask( |
|
self, |
|
x: torch.Tensor, |
|
padding_mask: torch.Tensor, |
|
target_list: List[torch.Tensor], |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
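        """Mask fused features along time with the learned mask embedding and,
        optionally, zero out whole channels; audio and image mask settings
        must agree (see the assert below)."""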
|
B, T, C = x.shape |
|
        assert all((
            self.config.mask_prob_audio == self.config.mask_prob_image,
            self.config.mask_length_audio == self.config.mask_length_image,
        )), "masking prob/length must match across audio and image for feature masking"

        mask_prob = self.config.mask_prob_audio
        mask_length = self.config.mask_length_audio  # equal to the image setting per the assert
|
if mask_prob > 0: |
|
mask_indices, _, _, _ = compute_mask_indices( |
|
(B, T), |
|
padding_mask, |
|
mask_prob, |
|
mask_length, |
|
self.config.mask_selection, |
|
self.config.mask_other, |
|
min_masks=2, |
|
no_overlap=self.config.no_mask_overlap, |
|
min_space=self.config.mask_min_space, |
|
) |
|
mask_indices = torch.from_numpy(mask_indices).to(x.device) |
|
x[mask_indices] = self.mask_emb |
|
else: |
|
mask_indices = None |
|
|
|
if self.config.mask_channel_prob > 0: |
|
mask_channel_indices, _, _, _ = compute_mask_indices( |
|
(B, C), |
|
None, |
|
self.config.mask_channel_prob, |
|
self.config.mask_channel_length, |
|
self.config.mask_channel_selection, |
|
self.config.mask_channel_other, |
|
no_overlap=self.config.no_mask_channel_overlap, |
|
min_space=self.config.mask_channel_min_space, |
|
) |
|
mask_channel_indices = ( |
|
torch.from_numpy(mask_channel_indices) |
|
.to(x.device) |
|
.unsqueeze(1) |
|
.expand(-1, T, -1) |
|
) |
|
x[mask_channel_indices] = 0 |
|
|
|
return x, mask_indices |
|
|
|
def forward_features( |
|
self, |
|
source: Dict[str, torch.Tensor], |
|
modality: str, |
|
) -> torch.Tensor: |
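        """Run the front-end for `modality`, scaling its gradient by
        `feature_grad_mult` (a value of 0 disables front-end gradients)."""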
|
        extractor = getattr(self, f"feature_extractor_{modality}")
|
if self.config.feature_grad_mult > 0: |
|
features = extractor(source) |
|
if self.config.feature_grad_mult != 1.0: |
|
features = GradMultiply.apply(features, self.config.feature_grad_mult) |
|
else: |
|
with torch.no_grad(): |
|
features = extractor(source) |
|
return features |
|
|
|
def forward_targets( |
|
self, |
|
features: torch.Tensor, |
|
mask_indices: torch.Tensor, |
|
target_list: List[torch.Tensor], |
|
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: |
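        """Trim features, mask indices, and targets to a common length so that
        feature frames align with label frames under `feat2tar_ratio`."""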
|
|
|
feat_tsz = features.size(2) |
|
targ_tsz = min([t.size(1) for t in target_list]) |
|
if self.feat2tar_ratio * feat_tsz > targ_tsz: |
|
feat_tsz = int(targ_tsz / self.feat2tar_ratio) |
|
features = features[..., :feat_tsz] |
|
if mask_indices is not None: |
|
mask_indices = mask_indices[..., :feat_tsz] |
|
target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio |
|
target_list = [t[:, target_inds.long()] for t in target_list] |
|
return features, mask_indices, target_list |
|
|
|
def forward_padding_mask( |
|
self, |
|
features: torch.Tensor, |
|
padding_mask: torch.Tensor, |
|
) -> torch.Tensor: |
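        """Downsample a sample-level padding mask to the feature frame rate; a
        frame counts as padding only if every sample it covers is padding."""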
|
extra = padding_mask.size(1) % features.size(1) |
|
if extra > 0: |
|
padding_mask = padding_mask[:, :-extra] |
|
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) |
|
padding_mask = padding_mask.all(-1) |
|
return padding_mask |
|
|
|
def compute_logits(self, feats: torch.Tensor, emb_mat: torch.Tensor) -> torch.Tensor: |
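        """Score frame features against label embeddings by dot product or
        cosine similarity, scaled by `logit_temp`."""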
|
|
|
if self.config.sim_type == "dot": |
|
logits = torch.matmul(feats, emb_mat.transpose(0, 1)) |
|
elif self.config.sim_type == "cosine": |
|
batch_size, timesteps, emb_dim = feats.size() |
|
feats_ = feats.view(-1, emb_dim) |
|
|
|
nom = (feats_.unsqueeze(dim=1) * emb_mat.unsqueeze(dim=0)).sum(dim=-1) |
|
|
|
denom = ( |
|
(feats_**2).sum(dim=-1).sqrt().unsqueeze(dim=1) |
|
* (emb_mat**2).sum(dim=-1).sqrt().unsqueeze(dim=0) |
|
) |
|
logits = (nom / denom.clamp(min=1e-6)).view(batch_size, timesteps, -1) |
|
else: |
|
raise NotImplementedError |
|
logits = logits / self.config.logit_temp |
|
return logits |
|
|
|
def forward( |
|
self, |
|
source: Dict[str, torch.Tensor], |
|
target_list: Optional[List[torch.Tensor]] = None, |
|
padding_mask: Optional[torch.Tensor] = None, |
|
mask: bool = True, |
|
features_only: bool = False, |
|
output_layer: Optional[int] = None, |
|
) -> Dict[str, torch.Tensor]: |
|
"""output layer is 1-based""" |
|
src_audio, src_video = source["audio"], source["video"] |
|
        if mask and self.config.masking_type == "input":
|
src_video, mask_indices_video = self.apply_input_mask( |
|
src_video, padding_mask, target_list |
|
) |
|
src_audio, mask_indices_audio = self.apply_input_mask( |
|
src_audio, padding_mask, target_list |
|
) |
|
mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video) |
|
        else:
            mask_indices = None
|
|
|
|
|
features_audio = self.forward_features(src_audio, modality="audio") |
|
features_video = self.forward_features(src_video, modality="video") |
|
|
|
if self.training: |
|
modality_drop_prob, audio_drop_prob = np.random.random(), np.random.random() |
|
if modality_drop_prob < self.config.modality_dropout: |
|
if audio_drop_prob < self.config.audio_dropout: |
|
features_audio = 0 * features_audio |
|
else: |
|
features_video = 0 * features_video |
|
|
|
if self.config.modality_fuse == "concat": |
|
features = torch.cat([features_audio, features_video], dim=1) |
|
elif self.config.modality_fuse == "add": |
|
features = features_audio + features_video |
|
|
|
if target_list is not None: |
|
features, mask_indices, target_list = self.forward_targets( |
|
features, mask_indices, target_list |
|
) |
|
|
|
features_pen = features.float().pow(2).mean() |
|
|
|
features = features.transpose(1, 2) |
|
features = self.layer_norm(features) |
|
|
|
if padding_mask is not None: |
|
padding_mask = self.forward_padding_mask(features, padding_mask) |
|
|
|
if self.post_extract_proj is not None: |
|
features = self.post_extract_proj(features) |
|
|
|
features = self.dropout_input(features) |
|
if self.config.masking_type == "feature" and mask: |
|
x, mask_indices = self.apply_feature_mask( |
|
features, padding_mask, target_list |
|
) |
|
else: |
|
x = features |
|
|
|
x, _ = self.encoder( |
|
x, |
|
padding_mask=padding_mask, |
|
layer=None if output_layer is None else output_layer - 1, |
|
) |
|
|
|
if features_only: |
|
return {"x": x, "padding_mask": padding_mask, "features": features} |
|
|
|
label_embs_list = self.label_embs_concat.split(self.num_classes, 0) |
|
proj_x = self.final_proj(x) |
|
if self.config.untie_final_proj: |
|
proj_x_list = proj_x.chunk(len(self.num_classes), dim=-1) |
|
else: |
|
proj_x_list = [proj_x for _ in self.num_classes] |
|
|
|
|
|
logit_list = [ |
|
self.compute_logits(proj, emb).view(-1, num_class) |
|
for proj, emb, num_class in zip( |
|
proj_x_list, label_embs_list, self.num_classes |
|
) |
|
] |
|
|
|
        if padding_mask is None:  # no padding mask given: treat every frame as valid
            padding_mask = torch.zeros_like(mask_indices)
        mask = torch.logical_and(mask_indices, ~padding_mask).view(-1)
        unmask = torch.logical_and(~mask_indices, ~padding_mask).view(-1)
|
logit_m_list = [logit[mask] for logit in logit_list] |
|
logit_u_list = [logit[unmask] for logit in logit_list] |
|
target_m_list = [target.view(-1)[mask].long() for target in target_list] |
|
target_u_list = [target.view(-1)[unmask].long() for target in target_list] |
|
|
|
return { |
|
"logit_m_list": logit_m_list, |
|
"logit_u_list": logit_u_list, |
|
"target_m_list": target_m_list, |
|
"target_u_list": target_u_list, |
|
"padding_mask": padding_mask, |
|
"features_pen": features_pen, |
|
} |
|
|
|
def extract_features( |
|
self, |
|
source: Dict[str, torch.Tensor], |
|
padding_mask: Optional[torch.Tensor] = None, |
|
mask: bool = False, |
|
ret_conv: bool = False, |
|
output_layer: Optional[int] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
res = self.forward( |
|
source, |
|
padding_mask=padding_mask, |
|
mask=mask, |
|
features_only=True, |
|
output_layer=output_layer, |
|
) |
|
feature = res["features"] if ret_conv else res["x"] |
|
return feature, res["padding_mask"] |
|
|
|
def extract_units( |
|
self, |
|
source: Dict[str, torch.Tensor], |
|
        padding_mask: Optional[torch.Tensor] = None,
|
mask: bool = False, |
|
ret_conv: bool = False, |
|
output_layer: Optional[int] = None, |
|
    ) -> torch.Tensor:
|
res = self.forward( |
|
source, |
|
padding_mask=padding_mask, |
|
mask=mask, |
|
features_only=True, |
|
output_layer=None, |
|
) |
|
|
|
feature = res["features"] if ret_conv else res["x"] |
|
proj_x = self.final_proj(feature) |
|
|
|
        units = torch.matmul(
            proj_x, self.label_embs_concat.transpose(0, 1)
        ).argmax(dim=-1)
|
return units |
|
|
|
def extract_finetune( |
|
self, |
|
source: Dict[str, torch.Tensor], |
|
        padding_mask: Optional[torch.Tensor] = None,
|
mask: bool = False, |
|
ret_conv: bool = False, |
|
output_layer: Optional[int] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
src_audio, src_video = source["audio"], source["video"] |
|
if mask and self.config.masking_type == "input": |
|
src_video, _ = self.apply_input_mask( |
|
src_video, padding_mask, target_list=None |
|
) |
|
src_audio, _ = self.apply_input_mask( |
|
src_audio, padding_mask, target_list=None |
|
) |
|
|
|
|
|
|
if src_audio is not None and src_video is None: |
|
features_audio = self.forward_features( |
|
src_audio, modality="audio" |
|
) |
|
features_video = features_audio.new_zeros( |
|
features_audio.size(0), |
|
self.encoder_embed_dim, |
|
features_audio.size(-1) |
|
) |
|
elif src_audio is None and src_video is not None: |
|
features_video = self.forward_features(src_video, modality="video") |
|
features_audio = features_video.new_zeros( |
|
features_video.size(0), |
|
self.encoder_embed_dim, |
|
features_video.size(-1) |
|
) |
|
        elif src_audio is not None and src_video is not None:
            features_video = self.forward_features(src_video, modality="video")
            features_audio = self.forward_features(src_audio, modality="audio")
        else:
            raise ValueError("at least one of audio and video inputs must be provided")
|
|
|
if self.config.modality_fuse == "concat": |
|
features = torch.cat([features_audio, features_video], dim=1) |
|
elif self.config.modality_fuse == "add": |
|
features = features_audio + features_video |
|
|
|
features = features.transpose(1, 2) |
|
features = self.layer_norm(features) |
|
unmasked_features = features.clone() |
|
|
|
if padding_mask is not None: |
|
padding_mask = self.forward_padding_mask(features, padding_mask) |
|
|
|
if self.post_extract_proj is not None: |
|
features = self.post_extract_proj(features) |
|
|
|
features = self.dropout_input(features) |
|
unmasked_features = self.dropout_features(unmasked_features) |
|
|
|
x, _ = self.encoder( |
|
features, |
|
padding_mask=padding_mask, |
|
layer=None if output_layer is None else output_layer - 1, |
|
) |
|
|
|
return x, padding_mask |
|
|
|
def get_extra_losses( |
|
self, |
|
net_output: Dict[str, torch.Tensor], |
|
) -> Tuple[List[torch.Tensor], List[str]]: |
|
extra_losses = [] |
|
names = [] |
|
if "features_pen" in net_output: |
|
extra_losses.append(net_output["features_pen"]) |
|
names.append("features_pen") |
|
|
|
return extra_losses, names |
|
|
|
def remove_pretraining_modules(self) -> None: |
|
self.target_glu = None |
|
self.final_proj = None |
|
|
|
def compute_nce( |
|
self, |
|
x: torch.Tensor, |
|
pos: torch.Tensor, |
|
negs: torch.Tensor, |
|
) -> torch.Tensor: |
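        """InfoNCE-style logits: cosine similarity of `x` against one positive
        and a set of negatives, with negatives identical to the positive
        masked to -inf so they cannot act as false targets."""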
|
neg_is_pos = (pos == negs).all(-1) |
|
pos = pos.unsqueeze(0) |
|
targets = torch.cat([pos, negs], dim=0) |
|
|
|
logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) |
|
logits /= self.config.logit_temp |
|
if neg_is_pos.any(): |
|
logits[1:][neg_is_pos] = float("-inf") |
|
logits = logits.transpose(0, 1) |
|
return logits |
|
|
|
|
|
class HubertEncoderWrapper(nn.Module): |
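    """Thin wrapper that exposes `AVHubertModel.extract_finetune` through a
    fairseq-style encoder interface (`encoder_out` / `encoder_padding_mask`)."""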
|
def __init__( |
|
self, |
|
config: AVHubertConfig, |
|
        dictionaries: Optional[List] = None,
|
) -> None: |
|
super().__init__() |
|
self.w2v_model = AVHubertModel(config, dictionaries) |
|
|
|
def forward( |
|
self, |
|
source: Dict[str, torch.Tensor], |
|
padding_mask: torch.Tensor, |
|
**kwargs, |
|
) -> Dict[str, torch.Tensor]: |
|
w2v_args = { |
|
"source": source, |
|
"padding_mask": padding_mask, |
|
} |
|
x, padding_mask = self.w2v_model.extract_finetune(**w2v_args) |
|
return { |
|
"encoder_out": x, |
|
"encoder_padding_mask": padding_mask, |
|
"padding_mask": padding_mask, |
|
} |
|
|
|
def reorder_encoder_out( |
|
self, |
|
encoder_out: Dict[str, torch.Tensor], |
|
new_order: torch.Tensor, |
|
) -> Dict[str, torch.Tensor]: |
|
if encoder_out["encoder_out"] is not None: |
|
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( |
|
1, new_order |
|
) |
|
if encoder_out["encoder_padding_mask"] is not None: |
|
encoder_out["encoder_padding_mask"] = encoder_out[ |
|
"encoder_padding_mask" |
|
].index_select(0, new_order) |
|
if encoder_out["padding_mask"] is not None: |
|
encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select( |
|
0, new_order |
|
) |
|
return encoder_out |
|
|
|
|
|
class AVSPLLMModel(PreTrainedModel): |
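    """AVSP-LLM: an AV-HuBERT encoder feeding a 4-bit-quantized, LoRA-adapted
    causal LM. Encoder frames are projected to the LLM embedding width and
    mean-pooled within each run of identical cluster labels before being
    spliced between the instruction and label embeddings."""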
|
config_class = AVSPLLMConfig |
|
|
|
def __init__( |
|
self, |
|
        config: Optional[AVSPLLMConfig] = None,
        dictionaries: Optional[List] = None,
    ) -> None:
        config = config if config is not None else AVSPLLMConfig()
        super().__init__(config=config)
|
self.encoder = HubertEncoderWrapper(config, dictionaries) |
|
self.encoder.w2v_model.remove_pretraining_modules() |
|
|
|
self.avfeat_to_llm = nn.Linear( |
|
config.encoder_embed_dim, config.decoder_embed_dim |
|
) |
|
|
|
bnb_config = BitsAndBytesConfig( |
|
load_in_4bit=True, |
|
bnb_4bit_use_double_quant=True, |
|
bnb_4bit_quant_type="nf4", |
|
bnb_4bit_compute_dtype=torch.bfloat16, |
|
) |
|
decoder_4bit = AutoModelForCausalLM.from_pretrained( |
|
config.llm_ckpt_path, |
|
quantization_config=bnb_config, |
|
) |
|
lora_config = LoraConfig( |
|
r=16, |
|
lora_alpha=32, |
|
target_modules=["q_proj", "v_proj", "k_proj"], |
|
lora_dropout=0.05, |
|
bias="none", |
|
task_type="CAUSAL_LM", |
|
) |
|
self.decoder = get_peft_model(decoder_4bit, lora_config) |
|
self.decoder.print_trainable_parameters() |
|
|
|
def forward( |
|
self, |
|
source: Dict[str, torch.Tensor], |
|
target_list: torch.Tensor, |
|
padding_mask: torch.Tensor, |
|
**kwargs, |
|
) -> CausalLMOutputWithPast: |
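        """Training forward pass: encode the audio-visual input (under
        `no_grad` until `freeze_finetune_updates` is reached), pool encoder
        frames within each cluster run, then train the decoder on
        [instruction | pooled features | labels] with -100 masking everything
        before the label tokens."""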
|
ft = self.config.freeze_finetune_updates <= kwargs.get("num_updates", -1) |
|
with torch.no_grad() if not ft else contextlib.ExitStack(): |
|
output = self.encoder(source, padding_mask, **kwargs) |
|
|
|
output["encoder_out"] = self.avfeat_to_llm(output["encoder_out"]) |
|
cluster_counts = source["cluster_counts"][0] |
|
|
|
results_tensor = [] |
|
start_idx = 0 |
|
        for cluster_num in cluster_counts:
            end_idx = start_idx + cluster_num
            segment = output["encoder_out"][:, start_idx:end_idx, :]  # avoid shadowing builtin `slice`
            mean_tensor = torch.mean(segment, dim=1, keepdim=True)
            results_tensor.append(mean_tensor)
            start_idx = end_idx
|
|
|
assert cluster_counts.sum().item() == output["encoder_out"].size()[1], \ |
|
f"{cluster_counts.sum().item()} != {output['encoder_out'].size()[1]}" |
|
|
|
reduced_enc_out = torch.cat(results_tensor, dim=1) |
|
B, T, D = reduced_enc_out.size() |
|
|
|
instruction = source["text"] |
|
instruction_embedding = self.decoder.model.model.embed_tokens(instruction) |
|
|
|
labels = target_list.clone() |
|
labels_embedding = self.decoder.model.model.embed_tokens(labels) |
|
|
|
llm_input = torch.cat( |
|
(instruction_embedding, reduced_enc_out, labels_embedding), dim=1 |
|
) |
|
llm_labels = labels.clone() |
|
llm_labels[llm_labels == 0] = -100 |
|
|
|
_, instruction_embedding_t, _ = instruction_embedding.size() |
|
target_ids = ( |
|
torch.full((B, T + instruction_embedding_t), -100).long().to(labels.device) |
|
) |
|
llm_labels = torch.cat((target_ids, llm_labels), dim=1) |
|
return self.decoder( |
|
inputs_embeds=llm_input, labels=llm_labels, return_dict=True |
|
) |
|
|
|
@torch.no_grad() |
|
def generate( |
|
self, |
|
        inputs: Optional[Dict[str, Any]] = None,
|
generation_config: Optional[GenerationConfig] = None, |
|
**kwargs, |
|
) -> Any: |
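        """Autoregressive inference: encode and pool exactly as in `forward`,
        prepend the instruction embedding, and generate from `inputs_embeds`."""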
|
output = self.encoder(**inputs) |
|
output["encoder_out"] = self.avfeat_to_llm(output["encoder_out"]) |
|
cluster_counts = inputs["source"]["cluster_counts"][0] |
|
|
|
results_tensor = [] |
|
start_idx = 0 |
|
|
|
        for cluster_num in cluster_counts:
            end_idx = start_idx + cluster_num
            segment = output["encoder_out"][:, start_idx:end_idx, :]
            mean_tensor = torch.mean(segment, dim=1, keepdim=True)
            results_tensor.append(mean_tensor)
            start_idx = end_idx
|
|
|
assert cluster_counts.sum().item() == output["encoder_out"].size()[1] |
|
|
|
reduced_enc_out = torch.cat(results_tensor, dim=1) |
|
B, T, D = reduced_enc_out.size() |
|
instruction = inputs["source"]["text"] |
|
instruction_embedding = self.decoder.model.model.embed_tokens(instruction) |
|
llm_input = torch.cat((instruction_embedding, reduced_enc_out), dim=1) |
|
|
|
self.decoder.config.use_cache = True |
|
        return self.decoder.generate(
            inputs_embeds=llm_input,
            generation_config=generation_config,
            **kwargs,
        )
|
|
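
# Minimal usage sketch (illustrative only): `llm_ckpt_path` and the batch
# layout below are assumptions about the surrounding training pipeline, not
# part of this module.
#
#   config = AVSPLLMConfig(llm_ckpt_path="path/to/causal-lm")
#   model = AVSPLLMModel(config)
#   out = model(
#       source={"audio": audio, "video": video,
#               "cluster_counts": cluster_counts, "text": instruction_ids},
#       target_list=label_ids,
#       padding_mask=padding_mask,
#       num_updates=0,
#   )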