"""
Inference CTC class derived from HubertForCTC.

Author: Marcely Zanon Boito, 2024
"""

from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers import HubertPreTrainedModel, HubertModel
from transformers.modeling_outputs import CausalLMOutput
class VanillaNN(nn.Module):
    def __init__(self, input_dim, output_dim):
        """
        Simple NN block: linear projection followed by a ReLU activation (no normalization).
        """
        super().__init__()

        self.linear = nn.Linear(input_dim, output_dim)
        self.act_fn = nn.ReLU()

    def forward(self, hidden_states: torch.FloatTensor):
        hidden_states = self.linear(hidden_states)
        hidden_states = self.act_fn(hidden_states)

        return hidden_states


class mHubertForCTC(HubertPreTrainedModel):
    def __init__(self, config, target_lang: Optional[str] = None):
        super().__init__(config)
        self.hubert = HubertModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        output_hidden_size = config.hidden_size

        self.has_interface = config.add_interface

        # optional stack of feed-forward layers between the HuBERT encoder and the CTC head
        if config.add_interface:
            self.interface = nn.ModuleList(
                [VanillaNN(output_hidden_size, output_hidden_size) for _ in range(config.num_interface_layers)]
            )

        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        hidden_states = self.dropout(hidden_states)

        # route the encoder output through the optional interface layers before the CTC head
        if self.has_interface:
            for layer in self.interface:
                hidden_states = layer(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None

        if labels is not None:
            if labels.max() >= self.config.vocab_size:
                raise ValueError(f"Label values must be smaller than the vocabulary size: {self.config.vocab_size}")

            # derive the length of each encoder output sequence from the attention mask
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # labels are padded with negative values; keep only the real targets
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss expects log-probabilities of shape (time, batch, vocab)
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.ctc_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
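

# Minimal usage sketch. It assumes a local checkpoint whose config defines add_interface,
# num_interface_layers and ctc_token_id, with a compatible Wav2Vec2Processor (feature
# extractor + CTC tokenizer) saved alongside it; the checkpoint path below is a placeholder,
# not a published model identifier.
if __name__ == "__main__":
    from transformers import Wav2Vec2Processor

    checkpoint = "path/to/mhubert-ctc-checkpoint"  # placeholder path

    processor = Wav2Vec2Processor.from_pretrained(checkpoint)
    model = mHubertForCTC.from_pretrained(checkpoint).eval()

    # stand-in for real audio: one second of silence sampled at 16 kHz
    waveform = torch.zeros(16000)
    inputs = processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")

    with torch.inference_mode():
        logits = model(inputs.input_values, attention_mask=inputs.get("attention_mask")).logits

    # greedy CTC decoding: argmax per frame, then let the tokenizer collapse repeated tokens
    predicted_ids = torch.argmax(logits, dim=-1)
    print(processor.batch_decode(predicted_ids))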