|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class Model(nn.Module): |
|
|
|
def __init__(self, dataset_max_length: int, null_label: int): |
|
super().__init__() |
|
self.max_length = dataset_max_length + 1 |
|
self.null_label = null_label |
|
|
|
def _get_length(self, logit, dim=-1): |
|
""" Greed decoder to obtain length from logit""" |
|
out = (logit.argmax(dim=-1) == self.null_label) |
|
abn = out.any(dim) |
|
out = ((out.cumsum(dim) == 1) & out).max(dim)[1] |
|
out = out + 1 |
|
out = torch.where(abn, out, out.new_tensor(logit.shape[1], device=out.device)) |
|
return out |
|
|
|
@staticmethod |
|
def _get_padding_mask(length, max_length): |
|
length = length.unsqueeze(-1) |
|
grid = torch.arange(0, max_length, device=length.device).unsqueeze(0) |
|
return grid >= length |
|
|
|
@staticmethod |
|
def _get_location_mask(sz, device=None): |
|
mask = torch.eye(sz, device=device) |
|
mask = mask.float().masked_fill(mask == 1, float('-inf')) |
|
return mask |
|
|