#############################
#   Imports
#############################

# Python modules
from typing import Optional, Tuple

# Remote modules
import torch
from torch import nn
from transformers import BartConfig
from transformers.activations import ACT2FN

# Local modules
from .bart_attention import BartCustomAttention
from .bart_mask_attention import BartCustomMaskAttention
from .config import BartCustomConfig


class BartCustomEncoderLayer(nn.Module):
    def __init__(self, config: BartCustomConfig, heads_mask: Optional[torch.Tensor]):
        super().__init__()
        self.embed_dim = config.d_model
        is_simple_mask_commonsense = config.is_simple_mask_commonsense
        if not is_simple_mask_commonsense:
            # Relation-aware attention with learned relation embeddings
            print("Selecting complex relation attention")
            self.self_attn = BartCustomAttention(
                embed_dim=self.embed_dim,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                num_relation_kinds=config.num_relation_kinds,
                use_same_relation_kv_emb=config.use_same_relation_kv_emb,
                heads_mask=heads_mask,
            )
        else:
            # Simpler variant that only masks attention based on relations
            print("Selecting simple (MASK) relation attention")
            self.self_attn = BartCustomMaskAttention(
                embed_dim=self.embed_dim,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                num_relation_kinds=config.num_relation_kinds,
                heads_mask=heads_mask,
            )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        layer_head_mask: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
        relation_inputs: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            relation_inputs (`torch.Tensor`, *optional*):
                relation-kind ids for each token pair, forwarded to the relation-aware self-attention.
""" residual = hidden_states hidden_states, attn_weights, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, relation_inputs=relation_inputs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) if hidden_states.dtype == torch.float16 and ( torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs