import torch
import torch.nn as nn
import torch.nn.functional as F
from easydict import EasyDict as edict
from xml.model_components import BertAttention, TrainablePositionalEncoding


class TextEncoder(nn.Module):
    """Self-attention text encoder that pools token features into a single sentence vector."""

    def __init__(self, hidden_size, drop, input_drop, nheads, max_position_embeddings):
        super().__init__()
        # Single self-attention block over the token sequence.
        self.transformer_encoder = BertAttention(edict(
            hidden_size=hidden_size,
            intermediate_size=hidden_size,
            hidden_dropout_prob=drop,
            attention_probs_dropout_prob=drop,
            num_attention_heads=nheads,
        ))
        # Learned absolute position embeddings added to the input features.
        self.pos_embed = TrainablePositionalEncoding(
            max_position_embeddings=max_position_embeddings,
            hidden_size=hidden_size,
            dropout=input_drop,
        )
        # Projects each token to a scalar score used for attention pooling.
        self.modular_vector_mapping = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, feat, mask):
        """
        Args:
            feat: (N, L, D=hidden_size)
            mask: (N, L), with 1 indicating a valid token
        Returns:
            (N, D) pooled sentence representation
        """
        feat = self.pos_embed(feat)  # (N, L, D)
        feat = self.transformer_encoder(feat, mask.unsqueeze(1))  # (N, L, D)
        att_scores = self.modular_vector_mapping(feat)  # (N, L, 1)
        # Mask out padded positions before the softmax so they get ~zero weight.
        att_scores = F.softmax(mask_logits(att_scores, mask.unsqueeze(2)), dim=1)
        # Attention-weighted sum of token features.
        pooled_feat = torch.einsum("blm,bld->bmd", att_scores, feat)  # (N, 1, D)
        return pooled_feat.squeeze(1)  # (N, D)


def mask_logits(target, mask):
    """Replace logits at masked (mask == 0) positions with a large negative value."""
    return target * mask + (1 - mask) * (-1e10)


def build_text_encoder(args):
    """Construct a TextEncoder from parsed command-line arguments."""
    return TextEncoder(
        hidden_size=args.hidden_dim,
        drop=args.dropout,
        input_drop=args.input_dropout,
        nheads=args.nheads,
        max_position_embeddings=args.max_q_l,
    )
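

# Minimal usage sketch: shows how the encoder might be built and called.
# The hyperparameter values below are illustrative assumptions, not values
# taken from any particular configuration of this project.
if __name__ == "__main__":
    dummy_args = edict(
        hidden_dim=256,
        dropout=0.1,
        input_dropout=0.1,
        nheads=4,
        max_q_l=32,  # assumed maximum query length (number of token positions)
    )
    encoder = build_text_encoder(dummy_args)
    feat = torch.randn(2, 32, 256)  # (N, L, D) token features
    mask = torch.ones(2, 32)        # (N, L), 1 = valid token
    pooled = encoder(feat, mask)    # (N, D)
    print(pooled.shape)             # torch.Size([2, 256])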