|
import os
|
|
import torch
|
|
import math
|
|
from torch import nn
|
|
from typing import Optional, Tuple, Union
|
|
|
|
from transformers.modeling_utils import PreTrainedModel
|
|
from transformers.activations import ACT2FN
|
|
from transformers.modeling_outputs import BaseModelOutput, Wav2Vec2BaseModelOutput, CausalLMOutput
|
|
from safetensors.torch import load_file
|
|
|
|
from .configuration_bestrq_conformer import MeralionBestRqConformerEncoderConfig
|
|
|
|
|
|
# Index of the first optional entry (hidden states / attentions) in the tuple
# that MeralionBestRqModel returns when `return_dict=False`; positions 0 and 1
# hold the last hidden state and the convolutional features.
_HIDDEN_STATES_START_POSITION = 2
|
|
|
|
def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor:
    """Build a boolean padding mask from a batch of sequence lengths.

    Args:
        lens: Tensor of per-sequence lengths with ``batch`` elements.

    Returns:
        Boolean tensor of shape ``(batch, max(lens))`` that is ``True`` at
        padded positions (index >= length) and ``False`` at valid positions.
    """
    batch_size = lens.size(0)
    longest = torch.max(lens).item()
    # Column indices 0..longest-1, repeated for every row of the batch.
    positions = torch.arange(longest).to(lens.device).view(1, longest).expand(batch_size, -1)
    # Each row's length, repeated across all columns.
    limits = lens.view(batch_size, 1).expand(-1, longest)
    return positions >= limits
|
|
|
|
|
|
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make a boolean mask that is ``True`` on the padded part of each sequence.

    Args:
        lengths (torch.Tensor): Batch of sequence lengths, shape ``(B,)``. A
            plain sequence of ints is also accepted and converted to a tensor.
        max_len (int): Width of the returned mask. When it is not positive,
            the largest value in ``lengths`` is used (0 for an empty batch).

    Returns:
        torch.Tensor: Boolean mask of shape ``(B, max_len)``; entry ``[b, t]``
        is ``True`` when ``t >= lengths[b]``.

    Examples:
        >>> make_pad_mask([5, 3, 2]).int().tolist()
        [[0, 0, 0, 0, 0], [0, 0, 0, 1, 1], [0, 0, 1, 1, 1]]
    """
    # Accept lists/tuples as well as tensors; tensors are passed through as-is.
    lengths = torch.as_tensor(lengths)
    batch_size = lengths.size(0)
    if max_len <= 0:
        # `max()` is undefined on an empty tensor, so fall back to width 0.
        max_len = lengths.max().item() if lengths.numel() > 0 else 0
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    # Broadcast each length against the position index along the time axis.
    return seq_range_expand >= lengths.unsqueeze(-1)
|
|
|
|
|
|
class Conv2dSubsampling(nn.Module):
    """
    Convolutional 2D subsampling (to roughly 1/4 length).
    For feature extraction/downsampling of an input mel spectrogram.

    Two ``Conv2d(kernel_size=3, stride=2)`` + ``ReLU`` stages shrink both the
    time axis and the feature axis of the input.

    Args:
        config: Configuration object providing ``input_channels`` (input
            channels of the first convolution; ``forward`` feeds exactly one
            channel) and ``hidden_size`` (channels produced by both
            convolutions).

    Inputs:
        inputs (batch, time, dim): Tensor containing sequence of inputs
        input_lengths (batch): Tensor containing input_length for each item in
            batch. Optional; when ``None`` every item is treated as spanning
            the full ``time`` axis.

    Returns:
        outputs (batch, time', channels * dim'): Tensor produced by the
            convolutions, with channel and feature axes flattened together
        output_lengths (batch): Tensor containing output_length for each item
            in batch
    """

    def __init__(self, config):
        super().__init__()
        self.sequential = nn.Sequential(
            nn.Conv2d(config.input_channels, config.hidden_size, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(config.hidden_size, config.hidden_size, kernel_size=3, stride=2),
            nn.ReLU(),
        )

    def forward(self, inputs: torch.Tensor, input_lengths: Optional[torch.Tensor] = None):
        _, max_seq_len, _ = inputs.size()
        # Add a singleton channel axis so the spectrogram is convolved as an image.
        outputs = self.sequential(inputs.unsqueeze(1))
        batch_size, channels, subsampled_len, subsampled_dim = outputs.size()

        # (batch, channels, time', dim') -> (batch, time', channels * dim')
        outputs = outputs.permute(0, 2, 1, 3)
        outputs = outputs.contiguous().view(batch_size, subsampled_len, channels * subsampled_dim)

        if input_lengths is None:
            # No lengths supplied: assume no padding at all.
            input_lengths = torch.full((batch_size,), max_seq_len, dtype=torch.long, device=inputs.device)

        # Effective time reduction factor, rounded to the nearest integer.
        subsampling_factor = int(max_seq_len * 1.0 / subsampled_len + 0.5)
        output_lengths = (input_lengths.float() / subsampling_factor).ceil().long()
        # A length can never exceed the number of frames actually produced.
        output_lengths = output_lengths.clamp(max=subsampled_len)

        return outputs, output_lengths
|
|
|
|
|
|
class ConformerRelPositionalEmbedding(nn.Module):
    """Sinusoidal relative positional encoding table.

    Args:
        config: Configuration object providing ``max_source_positions``
            (number of positions the table is pre-built for) and
            ``hidden_size`` (embedding dimension).
    """

    def __init__(self, config):
        super().__init__()
        self.max_len = config.max_source_positions
        self.d_model = config.hidden_size
        self.pe = None
        # Pre-build the table with a dummy (1, max_len) input.
        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))

    def extend_pe(self, x):
        """Build the cached table, or refresh it when it is too small for ``x``.

        The table covers relative offsets ``-(L - 1) .. L - 1`` where
        ``L = x.size(1)``, and is stored in ``self.pe`` with shape
        ``(1, 2 * L - 1, d_model)`` on the dtype and device of ``x``.
        """
        length = x.size(1)
        cached = self.pe
        if cached is not None and cached.size(1) >= length * 2 - 1:
            # Large enough already: only align dtype/device with the input.
            if cached.dtype != x.dtype or cached.device != x.device:
                self.pe = cached.to(dtype=x.dtype, device=x.device)
            return

        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
        )
        angles = position * div_term

        # Even channels hold sines, odd channels hold cosines; one table for
        # non-negative offsets and one for their negations.
        pe_positive = torch.zeros(length, self.d_model)
        pe_negative = torch.zeros(length, self.d_model)
        pe_positive[:, 0::2] = torch.sin(angles)
        pe_positive[:, 1::2] = torch.cos(angles)
        pe_negative[:, 0::2] = torch.sin(-angles)
        pe_negative[:, 1::2] = torch.cos(-angles)

        # Order the rows from the largest positive offset down to 0, then
        # continue with -1 .. -(L - 1); offset 0 is kept only once.
        positive_part = torch.flip(pe_positive, [0]).unsqueeze(0)
        negative_part = pe_negative[1:].unsqueeze(0)
        table = torch.cat([positive_part, negative_part], dim=1)
        self.pe = table.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Return the relative positional encodings matching the length of ``x``.

        Args:
            x : Input tensor T X B X C.
        Returns:
            torch.Tensor: Encodings of shape (2 * T - 1) X 1 X C, centred on
            relative offset 0.
        """
        x = x.transpose(0, 1)
        self.extend_pe(x)
        centre = self.pe.size(1) // 2
        length = x.size(1)
        # Take the window of offsets (length - 1) .. -(length - 1) around the centre.
        window = self.pe[:, centre - length + 1 : centre + length]
        return window.transpose(0, 1)
|
|
|
|
|
|
class ConformerRotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding.

    Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf

    Args:
        config: Configuration object providing ``hidden_size``,
            ``num_attention_heads`` (their quotient is the rotary dimension)
            and ``rotary_embedding_base``.
    """

    def __init__(self, config):
        super().__init__()
        dim = config.hidden_size // config.num_attention_heads
        base = config.rotary_embedding_base

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_sequence_length = None
        self.cached_rotary_positional_embedding = None

    def forward(self, hidden_states):
        """Return stacked cos/sin tables for the sequence length of the input.

        Args:
            hidden_states: Tensor whose dimension 1 is the sequence length;
                its dtype and device determine those of the result.

        Returns:
            Tensor of shape ``(2, sequence_length, 1, 1, dim)`` where index 0
            along the first axis holds the cosines and index 1 the sines.
        """
        sequence_length = hidden_states.shape[1]

        cached = self.cached_rotary_positional_embedding
        # Reuse the cache only when it matches the request in length, dtype AND
        # device; otherwise a stale tensor from an earlier call (e.g. before a
        # dtype cast or device move) would be handed back.
        if (
            cached is not None
            and sequence_length == self.cached_sequence_length
            and cached.dtype == hidden_states.dtype
            and cached.device == hidden_states.device
        ):
            return cached

        self.cached_sequence_length = sequence_length

        time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
        # Outer product: one row of angles per time step.
        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
        # Duplicate so the table spans the full head dimension.
        embeddings = torch.cat((freqs, freqs), dim=-1)

        cos_embeddings = embeddings.cos()[:, None, None, :]
        sin_embeddings = embeddings.sin()[:, None, None, :]

        self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
        return self.cached_rotary_positional_embedding
|
|
|
|
|
|
class ConformerInputFeatureProjection(nn.Module):
    """Project flattened subsampled features to ``hidden_size`` and apply dropout.

    Args:
        config: Configuration object providing ``hidden_size``, ``input_dim``
            and ``feat_proj_dropout``.
    """

    def __init__(self, config):
        super().__init__()
        # Feature-axis size after two successive "(n - 1) // 2" reductions
        # (the length arithmetic of an unpadded kernel-3, stride-2 convolution).
        dim_after_first = (config.input_dim - 1) // 2
        dim_after_second = (dim_after_first - 1) // 2

        self.projection = nn.Linear(config.hidden_size * dim_after_second, config.hidden_size)
        self.dropout = nn.Dropout(config.feat_proj_dropout)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: Input Tensor of shape T X B X C
        Returns:
            Tensor of shape T X B X C, with the last dimension projected to
            ``hidden_size``
        """
        return self.dropout(self.projection(hidden_states))
|
|
|
|
|
|
class ConformerFeedForward(nn.Module):
    """Position-wise feed-forward block used in the conformer layer.

    Args:
        config: Configuration object providing ``hidden_size``, ``ffn_dim``,
            ``hidden_act`` (a key of ``ACT2FN`` or a callable),
            ``activation_dropout`` and ``hidden_dropout``.
    """

    def __init__(self, config):
        super().__init__()

        self.intermediate_dropout = nn.Dropout(config.activation_dropout)

        self.intermediate_dense = nn.Linear(config.hidden_size, config.ffn_dim)
        # `hidden_act` may name an activation or already be one.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

        self.output_dense = nn.Linear(config.ffn_dim, config.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: Input Tensor of shape T X B X C
        Returns:
            Tensor of shape T X B X C
        """
        # Expand to ffn_dim, activate, regularise.
        expanded = self.intermediate_act_fn(self.intermediate_dense(hidden_states))
        expanded = self.intermediate_dropout(expanded)
        # Project back to hidden_size, regularise.
        return self.output_dropout(self.output_dense(expanded))
|
|
|
|
|
|
class ConformerConvolutionModule(nn.Module):
    """Convolution block used in the conformer block.

    Applies, in order: layer norm, a pointwise convolution that doubles the
    channels, GLU, a depthwise convolution, batch norm, the configured
    activation, a second pointwise convolution and dropout.

    Args:
        config: Configuration object providing ``hidden_size``,
            ``conv_depthwise_kernel_size`` (must be odd), ``hidden_act`` (a key
            of ``ACT2FN`` or a callable) and ``conformer_conv_dropout``.

    Raises:
        ValueError: If ``config.conv_depthwise_kernel_size`` is even.
    """

    def __init__(self, config):
        super().__init__()
        kernel_size = config.conv_depthwise_kernel_size
        if (kernel_size - 1) % 2 == 1:
            raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.pointwise_conv1 = nn.Conv1d(
            config.hidden_size,
            2 * config.hidden_size,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # GLU halves the channel axis again, gating one half with the other.
        self.glu = nn.GLU(dim=1)
        self.depthwise_conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size,
            stride=1,
            # Symmetric padding keeps the time axis length unchanged.
            padding=(kernel_size - 1) // 2,
            groups=config.hidden_size,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm1d(config.hidden_size)
        # Accept either an activation name or a callable, like the
        # feed-forward module of this file does.
        if isinstance(config.hidden_act, str):
            self.activation = ACT2FN[config.hidden_act]
        else:
            self.activation = config.hidden_act
        self.pointwise_conv2 = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.dropout = nn.Dropout(config.conformer_conv_dropout)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: Input of shape B X T X C
        Returns:
            Tensor of shape B X T X C
        """
        hidden_states = self.layer_norm(hidden_states)
        # Conv1d expects channels before time: B X T X C -> B X C X T
        hidden_states = hidden_states.transpose(1, 2)

        # B X C X T -> B X 2C X T -> (GLU) -> B X C X T
        hidden_states = self.pointwise_conv1(hidden_states)
        hidden_states = self.glu(hidden_states)

        # Per-channel temporal convolution.
        hidden_states = self.depthwise_conv(hidden_states)
        hidden_states = self.batch_norm(hidden_states)
        hidden_states = self.activation(hidden_states)

        hidden_states = self.pointwise_conv2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # Back to B X T X C
        hidden_states = hidden_states.transpose(1, 2)
        return hidden_states
|
|
|
|
|
|
class ConformerSelfAttention(nn.Module):
    """ConformerSelfAttention object.
    Can be enhanced with rotary or relative position embeddings.

    Args:
        config: Configuration object providing ``hidden_size``,
            ``num_attention_heads``, ``attention_dropout`` and
            ``position_embeddings_type`` (``"rotary"``, ``"relative"``, or any
            other value for plain scaled dot-product attention).
    """

    def __init__(self, config):
        super().__init__()

        self.head_size = config.hidden_size // config.num_attention_heads
        self.num_heads = config.num_attention_heads
        self.position_embeddings_type = config.position_embeddings_type

        self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
        self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
        self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
        self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)

        self.dropout = nn.Dropout(p=config.attention_dropout)

        if self.position_embeddings_type == "relative":
            # Linear transformation for the positional encodings.
            self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
            # Learnable biases added to the query in the content (u) and
            # position (v) score terms.
            self.pos_bias_u = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
            self.pos_bias_v = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        relative_position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run multi-head self-attention.

        Args:
            hidden_states: Input tensor of shape T X B X C.
            attention_mask: Optional B X T mask; non-zero / ``True`` entries
                mark key positions that must not be attended to.
            relative_position_embeddings: Position information. For
                ``"relative"`` a tensor of shape (2T - 1) X 1 X C; for
                ``"rotary"`` stacked cos/sin tables of shape
                (2, T, 1, 1, head_size). Unused otherwise.
            output_attentions: Accepted for interface compatibility; the
                attention probabilities are always returned.

        Returns:
            Tuple of the attended tensor (T X B X C) and the attention
            probabilities (B X num_heads X T X T).

        Raises:
            ValueError: If position embeddings are required by the configured
                type but ``relative_position_embeddings`` is ``None``.
        """
        # T X B X C -> B X T X C
        hidden_states = hidden_states.transpose(0, 1)
        batch_size = hidden_states.size(0)

        # Queries and keys may be position-rotated; values never are.
        query_key_states = hidden_states
        value_states = hidden_states

        if self.position_embeddings_type == "rotary":
            if relative_position_embeddings is None:
                raise ValueError(
                    "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
                )
            # The rotary table is indexed as (cos/sin, time, ...) by
            # `_apply_rotary_embedding`, so it is passed through untransposed.
            query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)

        # Project and split into heads: B X T X C -> B X T X H X head_size
        query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
        key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
        value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)

        # B X T X H X head_size -> B X H X T X head_size
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        if self.position_embeddings_type == "relative":
            if relative_position_embeddings is None:
                raise ValueError(
                    "`relative_position_embeddings` has to be defined when `self.position_embeddings_type =="
                    " 'relative'"
                )
            # Relative encodings arrive time-major; make them batch-major
            # only after the None check so the ValueError above is reachable.
            scores = self._apply_relative_embeddings(
                query=query,
                key=key,
                relative_position_embeddings=relative_position_embeddings.transpose(0, 1),
            )
        else:
            scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)

        # Exclude masked key positions from the softmax.
        if attention_mask is not None:
            scores = scores.masked_fill(
                attention_mask.unsqueeze(1).unsqueeze(2).to(bool),
                float("-inf"),
            )

        probs = torch.softmax(scores, dim=-1)
        probs = self.dropout(probs)

        # B X H X T X T @ B X H X T X head_size -> B X H X T X head_size
        hidden_states = torch.matmul(probs, value)

        # Merge heads: -> B X T X C
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
        hidden_states = self.linear_out(hidden_states)

        # B X T X C -> T X B X C
        hidden_states = hidden_states.transpose(0, 1)

        return hidden_states, probs

    def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
        """Rotate B X T X C states with stacked cos/sin tables of shape (2, T', 1, 1, head_size)."""
        batch_size, sequence_length, hidden_size = hidden_states.size()
        hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)

        cos = relative_position_embeddings[0, :sequence_length, ...]
        sin = relative_position_embeddings[1, :sequence_length, ...]

        # Put time first so the (T, 1, 1, head_size) tables broadcast.
        hidden_states = hidden_states.transpose(0, 1)
        rotated_states_begin = hidden_states[..., : self.head_size // 2]
        rotated_states_end = hidden_states[..., self.head_size // 2 :]
        rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
        hidden_states = (hidden_states * cos) + (rotated_states * sin)
        hidden_states = hidden_states.transpose(0, 1)

        hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)

        return hidden_states

    def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
        """Compute scaled attention scores with relative position terms.

        Args:
            query, key: B X H X T X head_size tensors.
            relative_position_embeddings: 1 X (2T - 1) X C encodings.

        Returns:
            Scores of shape B X H X T X T.
        """
        # 1. Project the positional encodings: -> 1 X H X head_size X (2T - 1)
        proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
        proj_relative_position_embeddings = proj_relative_position_embeddings.view(
            relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
        )
        proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
        proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)

        # 2. Add the per-head biases to the query (done in B X T X H X head_size
        #    layout so the (H, head_size) biases broadcast), then go back.
        query = query.transpose(1, 2)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)

        # 3. Content term: B X H X T X T
        scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))

        # 4. Position term: B X H X T X (2T - 1)
        scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)

        # 5. Relative shift: pad one column, reinterpret the memory with the
        #    last two axes swapped, drop the first row and view back, so each
        #    query row ends up aligned to its own relative offsets.
        zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
        scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
        scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
        scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
        scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
        scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]

        # 6. Sum both terms and scale.
        scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)

        return scores
|
|
|
|
|
|
class ConformerEncoderLayer(nn.Module):
    """Conformer block based on https://arxiv.org/abs/2005.08100.

    Chains a half-weighted feed-forward module, self-attention, a convolution
    module and a second half-weighted feed-forward module, each wrapped in a
    residual connection, followed by a final layer norm.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size

        # First feed-forward module
        self.ffn1_layer_norm = nn.LayerNorm(width)
        self.ffn1 = ConformerFeedForward(config)

        # Self-attention module
        self.self_attn_layer_norm = nn.LayerNorm(width)
        self.self_attn_dropout = nn.Dropout(config.attention_dropout)
        self.self_attn = ConformerSelfAttention(config)

        # Convolution module
        self.conv_module = ConformerConvolutionModule(config)

        # Second feed-forward module and closing normalisation
        self.ffn2_layer_norm = nn.LayerNorm(width)
        self.ffn2 = ConformerFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(width)

    def forward(
        self,
        hidden_states,
        attention_mask: Optional[torch.Tensor] = None,
        relative_position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        """Apply the block to a time-major (T X B X C) tensor.

        Returns:
            Tuple of the transformed tensor (same layout as the input) and the
            attention weights returned by the self-attention module.
        """
        # 1. Feed-forward, added to the input with weight 0.5.
        skip = hidden_states
        ffn_out = self.ffn1(self.ffn1_layer_norm(hidden_states))
        hidden_states = ffn_out * 0.5 + skip

        # 2. Self-attention with residual connection.
        skip = hidden_states
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.self_attn_layer_norm(hidden_states),
            attention_mask=attention_mask,
            relative_position_embeddings=relative_position_embeddings,
            output_attentions=output_attentions,
        )
        hidden_states = self.self_attn_dropout(attn_out) + skip

        # 3. Convolution module; it works on the (0, 1)-transposed layout, so
        #    swap the first two axes on the way in and out.
        skip = hidden_states
        conv_out = self.conv_module(hidden_states.transpose(0, 1)).transpose(0, 1)
        hidden_states = skip + conv_out

        # 4. Second feed-forward, again weighted by 0.5, then the final norm.
        skip = hidden_states
        ffn_out = self.ffn2(self.ffn2_layer_norm(hidden_states))
        hidden_states = self.final_layer_norm(ffn_out * 0.5 + skip)

        return hidden_states, attn_weights
|
|
|
|
|
|
class ConformerEncoder(nn.Module):
    """Stack of conformer layers operating on time-major (T X B X C) features.

    Args:
        config: Configuration object providing ``hidden_size``,
            ``no_scale_embedding``, ``position_embeddings_type``,
            ``num_hidden_layers`` and ``layerdrop``, plus everything the
            sub-modules read from it.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Inputs are scaled by sqrt(hidden_size) unless scaling is disabled.
        self.embed_scale = math.sqrt(config.hidden_size)
        if config.no_scale_embedding:
            self.embed_scale = 1.0

        if config.position_embeddings_type == "relative":
            self.embed_positions = ConformerRelPositionalEmbedding(config)
        elif config.position_embeddings_type == "rotary":
            self.embed_positions = ConformerRotaryPositionalEmbedding(config)
        else:
            self.embed_positions = None

        self.input_projection = ConformerInputFeatureProjection(config)

        self.layers = nn.ModuleList([ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Encode a time-major feature tensor.

        Args:
            hidden_states: Input tensor of shape T X B X C.
            attention_mask: Optional padding mask forwarded to every layer.
            output_attentions: Collect the attention weights of every layer.
            output_hidden_states: Collect the input of every layer and the
                final output, each transposed to B X T X C.
            return_dict: Return a ``BaseModelOutput`` instead of a tuple.

        Returns:
            ``BaseModelOutput`` (or a tuple of its non-``None`` fields) whose
            ``last_hidden_state`` has shape B X T X C.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        hidden_states = self.embed_scale * hidden_states

        relative_position_embeddings = None
        if self.embed_positions is not None:
            if self.config.position_embeddings_type == "rotary":
                # The rotary module takes the sequence length from dimension 1
                # of its input, so hand it a batch-major view; passing the
                # time-major tensor would size the table by the batch size.
                relative_position_embeddings = self.embed_positions(hidden_states.transpose(0, 1))
            else:
                relative_position_embeddings = self.embed_positions(hidden_states)

        hidden_states = self.input_projection(hidden_states)
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.transpose(0, 1),)

            # LayerDrop: during training, skip the layer with probability
            # `config.layerdrop`.
            dropout_probability = torch.rand([])
            skip_the_layer = self.training and bool(dropout_probability < self.config.layerdrop)

            layer_attentions = None
            if not skip_the_layer:
                hidden_states, layer_attentions = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    relative_position_embeddings=relative_position_embeddings,
                    output_attentions=output_attentions,
                )

            if output_attentions:
                # A skipped layer contributes `None` so indices stay aligned.
                all_self_attentions = all_self_attentions + (layer_attentions,)

        # T X B X C -> B X T X C
        hidden_states = hidden_states.transpose(0, 1)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
|
|
|
|
|
|
class MeralionBestRqModel(PreTrainedModel):
    """Convolutional subsampling front-end followed by a conformer encoder."""

    config_class = MeralionBestRqConformerEncoderConfig
    base_model_prefix = "bestrq_encoder"

    def __init__(self, config: MeralionBestRqConformerEncoderConfig):
        super().__init__(config)
        self.config = config
        self.conv_subsample = Conv2dSubsampling(config)
        self.encoder = ConformerEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        input_lengths: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
        """Subsample the input features and encode them.

        Args:
            input_values: Input features; axes 1 and 2 are swapped before the
                subsampling module, which documents a (batch, time, dim)
                input, so presumably (batch, dim, time) here.
            attention_mask: Not used by this method; padding is derived from
                ``input_lengths``.
            mask_time_indices: Not used by this method.
            output_attentions, output_hidden_states, return_dict: Fall back to
                the corresponding config values when ``None``.
            input_lengths: Per-item lengths handed to the subsampling module.

        Returns:
            With ``return_dict`` false, the tuple ``(last_hidden_state,
            conv_outputs, *optional encoder extras)``. Otherwise a
            ``Wav2Vec2BaseModelOutput`` that additionally carries the
            subsampled lengths under the key ``"output_lengths"``.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        features = input_values.transpose(2, 1)
        conv_outputs, output_lengths = self.conv_subsample(features, input_lengths)
        # The encoder consumes the (0, 1)-transposed convolution output.
        time_major = conv_outputs.transpose(0, 1)

        encoder_padding_mask = make_pad_mask(output_lengths, max_len=time_major.shape[0])

        encoder_outputs = self.encoder(
            time_major,
            attention_mask=encoder_padding_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = encoder_outputs[0]

        if not return_dict:
            return (hidden_states, conv_outputs) + encoder_outputs[1:]

        output = Wav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=conv_outputs,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
        # Extra entry so downstream heads can recover the valid frame counts.
        output["output_lengths"] = output_lengths
        return output
|
|
|
|
|
|
|
|
class MeralionBestRqModelForCTC(PreTrainedModel):
    """``MeralionBestRqModel`` with dropout and a linear head trained with CTC loss."""

    config_class = MeralionBestRqConformerEncoderConfig
    base_model_prefix = "bestrq_encoder"

    def __init__(self, config, target_lang: Optional[str] = None, **kwargs):
        """
        Args:
            config: Model configuration; must define ``vocab_size``.
            target_lang: Stored on the instance as ``self.target_lang``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``config.vocab_size`` is ``None``.
        """
        super().__init__(config)

        self.bestrq_encoder = MeralionBestRqModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        self.target_lang = target_lang

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                f"instantiate the model as follows: `{self.__class__.__name__}.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        input_lengths: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        attention_mask (`torch.Tensor`, *optional*):
            Not used by this method; padding is derived from `input_lengths`.
        input_lengths (`torch.Tensor`, *optional*):
            Per-item input lengths forwarded to the base model.
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

        # Always request the dict-style output from the base model: its tuple
        # form exposes neither `.last_hidden_state` nor the `output_lengths`
        # required by the CTC loss. The tuple for the caller is built below.
        outputs = self.bestrq_encoder(
            input_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            input_lengths=input_lengths,
        )

        hidden_states = self.dropout(outputs.last_hidden_state)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Negative label ids (e.g. -100) mark padding and are dropped
            # from the flattened target sequence.
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss does not support fp16 and expects time-major log-probs.
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    outputs.output_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            # Same layout as before: logits followed by whichever of the
            # optional hidden states / attentions were produced.
            extras = tuple(v for v in (outputs.hidden_states, outputs.attentions) if v is not None)
            output = (logits,) + extras
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
|
|
|