import math
import warnings
from typing import Union, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import (
    SequenceClassifierOutput,
    Wav2Vec2BaseModelOutput,
    Seq2SeqModelOutput,
    BaseModelOutput,
)
from transformers.cache_utils import (
    Cache,
    DynamicCache,
    EncoderDecoderCache,
    StaticCache,
)
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperEncoderLayer,
    WhisperDecoderLayer,
    WhisperDecoder,
    _HIDDEN_STATES_START_POSITION,
)

from .configuration_whisper_spkreg import WhisperSpkRegConfig


def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
    """Returns sinusoids for positional embedding.

    Args:
        length: Number of positions (rows of the returned table).
        channels: Embedding width; must be even (half sine, half cosine).
        max_timescale: Largest timescale of the geometric frequency progression.

    Returns:
        A `(length, channels)` tensor: sines in the first half of the columns, cosines in the second.

    Raises:
        ValueError: If `channels` is odd.
    """
    if channels % 2 != 0:
        raise ValueError(
            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
        )
    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)


def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.LongTensor] = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.

    Returns:
        A boolean array of shape `shape`; `True` marks masked positions.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # epsilon is used for probabilistic rounding
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch
    input_lengths = (
        attention_mask.sum(-1).detach().tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        return spec_aug_mask

    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick first sampled index that will serve as a dummy index to pad vector
        # to ensure same dimension for all batches due to probabilistic rounding
        # Picking first sample just pads those vectors twice.
        if len(spec_aug_mask_idx) == 0:
            # this case can only happen if `input_length` is strictly smaller than
            # `sequence_length` in which case the last token has to be a padding
            # token which we can use as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

    # add offset to the starting indexes so that indexes now create a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
        batch_size, max_num_masked_span * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

    # scatter indices to mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return spec_aug_mask


class WhisperSpkRegPreTrainedModel(PreTrainedModel):
    """Base class wiring `WhisperSpkRegConfig` into the Hugging Face `PreTrainedModel` machinery."""

    config_class = WhisperSpkRegConfig
    base_model_prefix = "model"
    main_input_name = "input_features"
    supports_gradient_checkpointing = True
    _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """Initialize the weights of `module` in place.

        Linear/Conv1d/Embedding weights are drawn from N(0, config.init_std), biases and padding rows are
        zeroed, Whisper encoder position tables are filled with `sinusoids`, and `AngularLinear`
        weights are re-drawn with Xavier-normal init.
        """
        std = self.config.init_std
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, WhisperEncoder):
            with torch.no_grad():
                embed_positions = module.embed_positions.weight
                embed_positions.copy_(sinusoids(*embed_positions.shape))
        elif isinstance(module, AngularLinear):
            # `AngularLinear` is not an `nn.Linear`, so without this branch a freshly added
            # margin-loss head would never be initialised when `from_pretrained` disables the
            # `torch.nn.init` calls made in module constructors.
            nn.init.xavier_normal_(module.weight, gain=1)

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """
        input_lengths = (input_lengths - 1) // 2 + 1

        return input_lengths


class WhisperSpkRegModel(WhisperSpkRegPreTrainedModel):
    """Whisper encoder-decoder backbone (no task head) built from `WhisperSpkRegConfig`."""

    def __init__(self, config: WhisperSpkRegConfig):
        super().__init__(config)

        self.encoder = WhisperEncoder(config)
        self.decoder = WhisperDecoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.decoder.embed_tokens = value

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def freeze_encoder(self):
        """
        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
        not be updated during training.
        """
        self.encoder._freeze_parameters()

    def _mask_input_features(
        self,
        input_features: torch.FloatTensor,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).

        Masking is only applied in training mode; `input_features` is modified in place and returned.
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return input_features

        # generate indices & apply SpecAugment along time axis
        batch_size, hidden_size, sequence_length = input_features.size()

        if self.config.mask_time_prob > 0 and self.training:
            # generate indices & apply SpecAugment along time axis
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool)
            mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1)
            input_features[mask_time_indices] = 0

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool)
            input_features[mask_feature_indices] = 0

        return input_features

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
        r"""
        Returns:

        Example:
         ```python
         >>> import torch
         >>> from transformers import AutoFeatureExtractor, WhisperModel
         >>> from datasets import load_dataset

         >>> model = WhisperModel.from_pretrained("openai/whisper-base")
         >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
         >>> input_features = inputs.input_features
         >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
         >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
         >>> list(last_hidden_state.shape)
         [1, 2, 512]
         ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            input_features = self._mask_input_features(input_features, attention_mask=attention_mask)

            encoder_outputs = self.encoder(
                input_features,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


class AngularLinear(nn.Module):
    """Bias-free linear layer returning cosine similarities.

    Both the inputs (along their last dimension) and each weight row are L2-normalised before
    the product, so every output element is cos(theta) between an input vector and a class
    weight vector, in [-1, 1]. This is the input expected by `AMSoftmaxLoss` / `AAMSoftmaxLoss`.
    """

    def __init__(self, in_features: int, out_features: int):
        super(AngularLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # `torch.empty` follows torch's default dtype (which `from_pretrained(torch_dtype=...)`
        # sets), unlike the legacy `torch.FloatTensor(...)` constructor that is always float32.
        self.weight = torch.nn.Parameter(
            torch.empty(out_features, in_features),
            requires_grad=True,
        )
        nn.init.xavier_normal_(self.weight, gain=1)

    def forward(
        self,
        inputs: torch.Tensor,
    ):
        """Return cos(theta) of shape `(*, out_features)` for `inputs` of shape `(*, in_features)`."""
        # Normalising along the last dim (rather than F.normalize's default dim=1) gives the
        # same result for 2-D inputs and stays correct for inputs with extra leading dims.
        cosine = F.linear(F.normalize(inputs, dim=-1), F.normalize(self.weight, dim=-1))
        return cosine

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}'.format(
            self.in_features, self.out_features
        )


class AMSoftmaxLoss(nn.Module):
    """Additive Margin Softmax (CosFace).

    Paper: Wang, Feng, et al. "Additive margin softmax for face verification."
    IEEE Signal Processing Letters 25.7 (2018): 926-930.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.35,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Scaling factor applied to the (margin-adjusted) cosines before cross-entropy (default: 30.0)
            margin: Additive cosine margin subtracted from the target-class cosine (default: 0.35)
            label_smoothing: Passed to `F.cross_entropy` (default: 0.0)
            reduction: Passed to `F.cross_entropy` (default: "mean")
        """
        super(AMSoftmaxLoss, self).__init__()
        self.scale = scale
        self.margin = margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities of shape (batch_size, num_labels), e.g. from `AngularLinear`
            targets: Ground truth class indices of shape (batch_size)

        Returns:
            Loss value (reduced according to `self.reduction`)
        """
        _, num_labels = inputs.shape
        # `inputs` are the outputs from AngularLinear()
        cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
        psi = cos_theta - self.margin

        # apply the margin only to the target-class column, then scale all logits
        one_hot = nn.functional.one_hot(targets, num_labels)
        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
        loss = F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
        return loss


class AAMSoftmaxLoss(nn.Module):
    """Additive Angular Margin Softmax (ArcFace).

    Paper: Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition."
    Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.2,
        easy_margin: bool = False,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Scaling factor applied to the (margin-adjusted) cosines before cross-entropy (default: 30.0)
            margin: Angular margin in radians added to the target-class angle (default: 0.2)
            easy_margin: Only apply the margin where cos(theta) > 0 (default: False)
            label_smoothing: Passed to `F.cross_entropy` (default: 0.0)
            reduction: Passed to `F.cross_entropy` (default: "mean")
        """
        super(AAMSoftmaxLoss, self).__init__()
        self.scale = scale
        self.margin = margin
        self.easy_margin = easy_margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities of shape (batch_size, num_labels), e.g. from `AngularLinear`
            targets: Ground truth class indices of shape (batch_size)

        Returns:
            Loss value (reduced according to `self.reduction`)
        """
        _, num_labels = inputs.shape
        # `inputs` are the outputs from AngularLinear()
        epsilon = 1e-6
        # theta = torch.acos(cos_theta)
        # psi = torch.cos(theta + self.margin)
        cos_theta = torch.clamp(inputs, -1.0 + epsilon, 1.0 - epsilon)
        sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
        sin_theta = torch.clamp(sin_theta, 0.0 + epsilon, 1.0 - epsilon)

        cos_m = math.cos(self.margin)
        sin_m = math.sin(self.margin)
        psi = cos_theta * cos_m - sin_theta * sin_m  # cos(theta + m)

        if self.easy_margin:
            psi = torch.where(cos_theta > 0, psi, cos_theta)
        else:
            # Make the function cos(theta+m) monotonic decreasing while theta in [0°, 180°]:
            # once theta + m would exceed pi, fall back to the additive cosine margin cos(theta) - m.
            psi = torch.where((cos_theta - math.cos(math.pi - self.margin)) > 0, psi, cos_theta - self.margin)

        # apply the margin only to the target-class column, then scale all logits
        one_hot = nn.functional.one_hot(targets, num_labels)
        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
        loss = F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
        return loss


class WhisperSpkRegForSequenceClassification(WhisperSpkRegPreTrainedModel):
    """Whisper encoder + mean-pooled projection + classification head.

    With `config.loss_fct == "cross_entropy"` the head is an `nn.Linear` producing unbounded
    logits. With a margin loss (`"additive_margin"` / `"additive_angular_margin"`) the head is an
    `AngularLinear`, so `logits` are cosine similarities in [-1, 1], which is what those losses
    expect.
    """

    # loss names whose loss functions expect cosine similarities rather than raw logits
    _MARGIN_LOSS_FCTS = ("additive_margin", "additive_angular_margin")

    def __init__(self, config):
        super().__init__(config)

        self.encoder = WhisperEncoder(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        if getattr(config, "loss_fct", "cross_entropy") in self._MARGIN_LOSS_FCTS:
            # margin losses clamp their inputs to [-1, 1]; raw linear logits outside that range
            # would get zero gradient, so the head must produce cosines
            self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
        else:
            self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_encoder(self):
        """
        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
        not be updated during training. Only the projection layers and classification head will be updated.
        """
        self.encoder._freeze_parameters()

    def get_input_embeddings(self) -> nn.Module:
        return self.encoder.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module):
        self.encoder.set_input_embeddings(value)

    def _build_loss_fct(self) -> nn.Module:
        """Instantiate the loss module selected by `config.loss_fct`.

        Raises:
            ValueError: If `config.loss_fct` is not one of the supported names.
        """
        loss_name = self.config.loss_fct
        if loss_name == 'cross_entropy':
            return nn.CrossEntropyLoss(
                label_smoothing=self.config.label_smoothing,
                reduction=self.config.reduction
            )
        if loss_name == 'additive_margin':
            return AMSoftmaxLoss(
                scale=self.config.scale,
                margin=self.config.margin,
                label_smoothing=self.config.label_smoothing,
                reduction=self.config.reduction
            )
        if loss_name == 'additive_angular_margin':
            return AAMSoftmaxLoss(
                scale=self.config.scale,
                margin=self.config.margin,
                easy_margin=self.config.easy_margin,
                label_smoothing=self.config.label_smoothing,
                reduction=self.config.reduction
            )
        raise ValueError(
            f"Unsupported `config.loss_fct`: {loss_name!r}. Expected one of "
            "'cross_entropy', 'additive_margin', 'additive_angular_margin'."
        )

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
            The loss is selected by `config.loss_fct`: `"cross_entropy"` (Cross-Entropy), `"additive_margin"`
            (AM-Softmax / CosFace) or `"additive_angular_margin"` (AAM-Softmax / ArcFace). No regression loss is
            computed, whatever the value of `config.num_labels`.

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, WhisperForAudioClassification
        >>> from datasets import load_dataset

        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
        >>> model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")

        >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
        >>> sample = next(iter(ds))

        >>> inputs = feature_extractor(
        ...     sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt"
        ... )
        >>> input_features = inputs.input_features

        >>> with torch.no_grad():
        ...     logits = model(input_features).logits

        >>> predicted_class_ids = torch.argmax(logits).item()
        >>> predicted_label = model.config.id2label[predicted_class_ids]
        >>> predicted_label
        'Afrikaans'
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        if self.config.use_weighted_layer_sum:
            # the weighted layer sum needs the hidden states of every encoder layer
            output_hidden_states = True
        elif output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_features,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        if self.config.use_weighted_layer_sum:
            hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
            # stack per-layer states on a new dim 1 and combine them with softmax-normalised weights
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = encoder_outputs[0]

        hidden_states = self.projector(hidden_states)
        # unweighted mean over dim 1 (presumably the time axis; padding frames are not excluded)
        pooled_output = hidden_states.mean(dim=1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = self._build_loss_fct()
            loss = loss_fct(
                logits.view(-1, self.config.num_labels),
                labels.view(-1).to(logits.device),
            )

        if not return_dict:
            output = (logits,) + encoder_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )