Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # coding=utf-8 | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from model.module.biaffine import Biaffine | |
| class AnchorClassifier(nn.Module): | |
| def __init__(self, dataset, args, initialize: bool, bias=True, mode="anchor"): | |
| super(AnchorClassifier, self).__init__() | |
| self.token_f = nn.Linear(args.hidden_size, args.hidden_size_anchor) | |
| self.label_f = nn.Linear(args.hidden_size, args.hidden_size_anchor) | |
| self.dropout = nn.Dropout(args.dropout_anchor) | |
| if bias and initialize: | |
| bias_init = torch.tensor([getattr(dataset, f"{mode}_freq")]) | |
| bias_init = (bias_init / (1.0 - bias_init)).log() | |
| else: | |
| bias_init = None | |
| self.output = Biaffine(args.hidden_size_anchor, 1, bias=bias, bias_init=bias_init) | |
| def forward(self, label, tokens, encoder_mask): | |
| tokens = self.dropout(F.elu(self.token_f(tokens))) # shape: (B, T_w, H) | |
| label = self.dropout(F.elu(self.label_f(label))) # shape: (B, T_l, H) | |
| anchor = self.output(label, tokens).squeeze(-1) # shape: (B, T_l, T_w) | |
| anchor = anchor.masked_fill(encoder_mask.unsqueeze(1), float("-inf")) # shape: (B, T_l, T_w) | |
| return anchor | |