mzboito's picture
files upload
ffa317c
raw
history blame
4.05 kB
"""
Inference CTC class derived from HubertForCTC.
Author: Marcely Zanon Boito, 2024
"""
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers import HubertPreTrainedModel, HubertModel
from transformers.modeling_outputs import CausalLMOutput, SequenceClassifierOutput
class VanillaNN(nn.Module):
    """One fully-connected projection followed by a ReLU activation (no normalisation)."""

    def __init__(self, input_dim, output_dim):
        """
        Args:
            input_dim: size of the last dimension of the incoming tensor.
            output_dim: size of the last dimension of the returned tensor.
        """
        super().__init__()
        # The attribute names below become the state_dict keys ("linear.weight",
        # "linear.bias"), so they must not change or saved checkpoints stop loading.
        self.linear = nn.Linear(input_dim, output_dim)
        self.act_fn = nn.ReLU()

    def forward(self, hidden_states: torch.FloatTensor):
        """Apply the linear projection to ``hidden_states``, then ReLU element-wise."""
        projected = self.linear(hidden_states)
        return self.act_fn(projected)
class mHubertForCTC(HubertPreTrainedModel):
    """HuBERT encoder + optional stack of ReLU "interface" layers + CTC output head.

    Follows ``transformers.HubertForCTC``, except that when ``config.add_interface``
    is true, ``config.num_interface_layers`` ``VanillaNN`` layers are applied between
    the dropout-regularised encoder output and ``lm_head``.
    """

    def __init__(self, config, target_lang: Optional[str] = None):
        """
        Args:
            config: HuBERT model config. Besides the fields read by ``HubertModel``,
                this class reads ``final_dropout``, ``hidden_size``, ``vocab_size``,
                ``add_interface`` and (when ``add_interface`` is true)
                ``num_interface_layers``.
            target_lang: accepted for constructor-signature compatibility; it is
                not used by this class.
        """
        super().__init__(config)

        # NOTE: submodule attribute names and creation order are kept stable because
        # they define checkpoint (state_dict) keys and the weight-initialisation order.
        self.hubert = HubertModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        output_hidden_size = config.hidden_size
        self.has_interface = config.add_interface
        # NN layers on top of the trainable stack
        if config.add_interface:
            self.interface = nn.ModuleList(
                [VanillaNN(output_hidden_size, output_hidden_size) for _ in range(config.num_interface_layers)]
            )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Hugging Face hook that must run once all submodules exist (e.g. weight init).
        self.post_init()

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        """Encode audio, project to per-frame vocabulary logits and optionally compute CTC loss.

        Args:
            input_values: raw input passed straight to ``self.hubert`` (presumably a
                batch of waveforms; TODO confirm shape with the feature extractor).
            attention_mask: mask over ``input_values``; when ``labels`` is given and the
                mask is None, every input position is treated as valid.
            output_attentions: forwarded to the encoder.
            output_hidden_states: forwarded to the encoder; falls back to
                ``config.output_hidden_states`` when None.
            return_dict: return a ``CausalLMOutput`` when true, a tuple when false;
                falls back to ``config.use_return_dict`` when None.
            labels: target token ids; padded positions must be negative (e.g. -100).

        Returns:
            ``CausalLMOutput(loss, logits, hidden_states, attentions)`` or, when
            ``return_dict`` is false, the tuple ``(loss?, logits, *encoder_extras)``
            where ``loss`` is present only if ``labels`` was given. ``logits`` are
            batch-first: (batch, frames, vocab_size).

        Raises:
            ValueError: if any label id is >= ``config.vocab_size``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Honour an explicit caller choice; only fall back to the config default.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # Reject out-of-range targets before paying for the encoder pass.
        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

        outputs = self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)
        if self.has_interface:
            for layer in self.interface:
                hidden_states = layer(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            # Map input-sample counts to encoder frame counts.
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16; it also expects time-major (T, N, C) input.
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.ctc_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            # With return_dict=False the encoder returns a plain tuple (no
            # .hidden_states/.attentions attributes); element 0 is the last hidden
            # state and the rest are passed through, matching HubertForCTC.
            output = (logits,) + tuple(outputs[1:])
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )