import torch
import torch.nn as nn
import torch.nn.functional as F
from easydict import EasyDict as edict

from xml.model_components import BertAttention, TrainablePositionalEncoding


class TextEncoder(nn.Module):
    """Encodes a token sequence with self-attention and pools it into a
    single vector via learned attention weights."""

    def __init__(self, hidden_size, drop, input_drop, nheads, max_position_embeddings):
        super().__init__()
        # BERT-style self-attention block over the input sequence.
        self.transformer_encoder = BertAttention(edict(
            hidden_size=hidden_size,
            intermediate_size=hidden_size,
            hidden_dropout_prob=drop,
            attention_probs_dropout_prob=drop,
            num_attention_heads=nheads,
        ))
        # Learned positional embeddings added to the input features.
        self.pos_embed = TrainablePositionalEncoding(
            max_position_embeddings=max_position_embeddings,
            hidden_size=hidden_size,
            dropout=input_drop,
        )
        # Maps each token to a scalar attention score used for pooling.
        self.modular_vector_mapping = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, feat, mask):
        """
        Args:
            feat: (N, L, D=hidden_size) input token features.
            mask: (N, L), with 1 indicating a valid position.

        Returns:
            (N, D) attention-pooled sequence representation.
        """
        feat = self.pos_embed(feat)  # add positional embeddings, (N, L, D)
        feat = self.transformer_encoder(feat, mask.unsqueeze(1))  # self-attention, (N, L, D)
        att_scores = self.modular_vector_mapping(feat)  # pooling logits, (N, L, 1)
        # Mask out padded positions, then normalize over the sequence dimension.
        att_scores = F.softmax(mask_logits(att_scores, mask.unsqueeze(2)), dim=1)
        # Weighted sum over tokens -> (N, 1, D).
        pooled_feat = torch.einsum("blm,bld->bmd", att_scores, feat)
        return pooled_feat.squeeze(1)  # (N, D)


def mask_logits(target, mask):
    """Set logits at masked (mask == 0) positions to a large negative value
    so they receive ~0 weight after softmax."""
    return target * mask + (1 - mask) * (-1e10)


def build_text_encoder(args):
    return TextEncoder(
        hidden_size=args.hidden_dim,
        drop=args.dropout,
        input_drop=args.input_dropout,
        nheads=args.nheads,
        max_position_embeddings=args.max_q_l,
    )
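

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). It assumes
    # `xml.model_components` is importable and uses made-up hyperparameters;
    # adjust hidden_dim / max_q_l to match your configuration.
    args = edict(hidden_dim=256, dropout=0.1, input_dropout=0.1, nheads=4, max_q_l=32)
    encoder = build_text_encoder(args)

    batch_size, seq_len = 2, 20
    feat = torch.randn(batch_size, seq_len, args.hidden_dim)  # (N, L, D) dummy token features
    mask = torch.ones(batch_size, seq_len)                    # (N, L), 1 = valid token
    mask[1, 15:] = 0  # mark the tail of the second sequence as padding

    pooled = encoder(feat, mask)
    print(pooled.shape)  # expected: torch.Size([2, 256])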