# Author: Serhiy Stetskovych
# Commit: 2ccf6b5 ("Initial commit"), file size 7.75 kB
from typing import Tuple
import numpy as np
import torch
from torch import nn, Tensor
from torch.nn import Module
import torch.nn.functional as F
from einops import rearrange, repeat
from beartype import beartype
from beartype.typing import Optional
def exists(val):
    """Return True when *val* is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or an empty tensor still count as
    existing; only ``None`` does not.
    """
    if val is None:
        return False
    return True
class AlignerNet(Module):
    """alignment model https://arxiv.org/pdf/2108.10447.pdf

    Encodes keys and queries with small 1-D convolution stacks and scores
    every (query, key) pair by the L2 distance between their encodings.

    Args:
        dim_in: channel count of ``queries`` (input of ``query_layers``).
        dim_hidden: channel count of ``keys`` (input of ``key_layers``).
        attn_channels: channel count of both encodings.
        temperature: stored as ``self.temperature``; currently NOT used by
            ``forward`` (see the NOTE there).
    """

    def __init__(
        self,
        dim_in=80,
        dim_hidden=512,
        attn_channels=80,
        temperature=0.0005,
    ):
        super().__init__()
        self.temperature = temperature

        # Key encoder: dim_hidden -> 2*dim_hidden -> attn_channels.
        self.key_layers = nn.ModuleList([
            nn.Conv1d(
                dim_hidden,
                dim_hidden * 2,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(dim_hidden * 2, attn_channels, kernel_size=1, padding=0, bias=True)
        ])

        # Query encoder: dim_in -> 2*dim_in -> dim_in -> attn_channels.
        self.query_layers = nn.ModuleList([
            nn.Conv1d(
                dim_in,
                dim_in * 2,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(dim_in * 2, dim_in, kernel_size=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv1d(dim_in, attn_channels, kernel_size=1, padding=0, bias=True)
        ])

    @beartype
    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        mask: Optional[Tensor] = None
    ):
        """Score every query frame against every key.

        Args:
            queries: Conv1d input of shape ``(b, dim_in, t_q)``.
            keys: Conv1d input of shape ``(b, dim_hidden, t_k)``.
            mask: optional tensor whose last axis has length ``t_k``; nonzero
                marks valid keys. It is broadcast over the query axis
                (presumably ``(b, 1, t_k)``, as ``Aligner`` passes; verify
                for other callers).

        Returns:
            ``(attn, attn_logp)``, both of shape ``(b, 1, t_q, t_k)``.
            ``attn`` is the softmax of ``attn_logp`` over the key axis;
            masked keys get ``-finfo.max`` in ``attn_logp``.
        """
        key_out = keys
        for layer in self.key_layers:
            key_out = layer(key_out)

        query_out = queries
        for layer in self.query_layers:
            query_out = layer(query_out)

        # cdist wants (batch, points, features).
        key_out = rearrange(key_out, 'b c t -> b t c')
        query_out = rearrange(query_out, 'b c t -> b t c')

        # NOTE(review): these logits are the raw (non-negative) L2 distance,
        # so the softmax below gives the most weight to the *farthest* key,
        # and `self.temperature` is never applied. The referenced paper
        # scores pairs with the negative distance. Changing this alters the
        # outputs of any model trained with this code, so confirm before
        # changing.
        attn_logp = torch.cdist(query_out, key_out)
        attn_logp = rearrange(attn_logp, 'b ... -> b 1 ...')

        if exists(mask):
            mask = rearrange(mask.bool(), '... c -> ... 1 c')
            # Out-of-place on purpose: cdist saves its output for backward.
            # Writing into it via `.data` would bypass autograd's in-place
            # checks and feed altered values to cdist's backward.
            attn_logp = attn_logp.masked_fill(~mask, -torch.finfo(attn_logp.dtype).max)

        attn = attn_logp.softmax(dim = -1)
        return attn, attn_logp
def pad_tensor(input, pad, value=0):
    """Constant-pad *input* with one ``(before, after)`` pair per dimension.

    ``torch.nn.functional.pad`` takes a flat sequence ordered from the *last*
    dimension backwards. Here ``pad`` instead lists the pairs in natural
    first-to-last order and must cover every dimension of ``input``.

    Args:
        input: tensor to pad.
        pad: sequence of ``(before, after)`` pairs, one per dimension.
        value: fill value for the padded region.

    Returns:
        The padded tensor.

    Raises:
        ValueError: if ``pad`` does not give exactly one pair per dimension.
    """
    # Raise rather than assert: asserts are stripped under `python -O`.
    if len(pad) != input.dim() or any(len(pair) != 2 for pair in pad):
        raise ValueError('Padding dimensions do not match input dimensions')
    # Flatten last-dimension-first, as F.pad expects: (last_before, last_after, ...).
    flat_pad = [amount for pair in reversed(pad) for amount in pair]
    return F.pad(input, flat_pad, mode='constant', value=value)
def maximum_path(value, mask, const=None):
    """Monotonic alignment search over a batch of score matrices.

    Vectorized dynamic program: for each batch item, find the path that
    maximizes the sum of ``value`` along it. At each step along ``t_y`` the
    path either stays on the same ``t_x`` row or advances by exactly one row.

    Args:
        value: scores of shape ``(b, t_x, t_y)``.
        mask: expected to have the same ``(b, t_x, t_y)`` shape as ``value``;
            nonzero entries mark valid cells. The number of valid ``t_x`` rows
            per item is read from ``mask[:, :, 0]``.
        const: fill value for unreachable cells; defaults to ``-inf``.

    Returns:
        A 0/1 path tensor with ``value``'s shape and dtype, zeroed outside
        ``mask``.
    """
    device = value.device
    dtype = value.dtype
    if not exists(const):
        const = torch.tensor(float('-inf'), device=device)

    # The search only selects indices, so no gradients are needed. Without
    # no_grad, the t_y-step loop would record every intermediate for backward.
    with torch.no_grad():
        value = value * mask
        b, t_x, t_y = value.shape

        # Forward pass: v[:, x] is the best cumulative score of a path ending
        # at (x, j). direction[:, x, j] == 1 means that path stayed on row x;
        # 0 means it came from row x - 1.
        direction = torch.zeros(value.shape, dtype=torch.int64, device=device)
        v = torch.zeros((b, t_x), dtype=torch.float32, device=device)
        x_range = torch.arange(t_x, dtype=torch.float32, device=device).view(1, -1)
        for j in range(t_y):
            # v0[:, x] = v[:, x - 1]; row 0 has no predecessor, so it gets `const`.
            v0 = pad_tensor(v, ((0, 0), (1, 0)), value=const)[:, :-1]
            v1 = v
            max_mask = v1 >= v0
            v_max = torch.where(max_mask, v1, v0)
            direction[:, :, j] = max_mask

            # Row x is unreachable before column x.
            index_mask = x_range <= j
            v = torch.where(index_mask.view(1, -1), v_max + value[:, :, j], const)
        # Outside the mask, "stay" (direction 1) so backtracking holds its row
        # while walking through padded columns.
        direction = torch.where(mask.bool(), direction, 1)

        # Backtrack from the last valid row in the last column.
        path = torch.zeros(value.shape, dtype=torch.float32, device=device)
        index = mask[:, :, 0].sum(1).long() - 1
        index_range = torch.arange(b, device=device)
        for j in reversed(range(t_y)):
            path[index_range, index, j] = 1
            index = index + direction[index_range, index, j] - 1

    path = path * mask.float()
    return path.to(dtype=dtype)
class ForwardSumLoss(Module):
    """CTC-based forward-sum loss for soft alignments.

    Treats every query frame as a CTC time step and every key as a class.
    The target for each batch item is the key sequence ``1 .. key_len``.

    Args:
        blank_logprob: logit assigned to the prepended blank class.
    """

    def __init__(
        self,
        blank_logprob = -1
    ):
        super().__init__()
        self.blank_logprob = blank_logprob
        # forward() prepends the blank column at class index 0, hence blank=0.
        self.ctc_loss = torch.nn.CTCLoss(blank = 0, zero_infinity = True)

    def forward(self, attn_logprob, key_lens, query_lens):
        """Return the CTC loss of ``attn_logprob`` (shape ``(b, 1, t_q, t_k)``)."""
        dev = attn_logprob.device
        num_keys = attn_logprob.size(-1)

        # (b, 1, t_q, t_k) -> (t_q, b, t_k): CTCLoss wants time first.
        logits = rearrange(attn_logprob, 'b 1 c t -> c b t')
        # Prepend the blank class as column 0 of the last axis.
        logits = F.pad(logits, (1, 0, 0, 0, 0, 0), value = self.blank_logprob)

        # Class k >= 1 stands for key k - 1, so classes above key_len are invalid.
        class_ids = torch.arange(num_keys + 1, device=dev, dtype=torch.long).view(1, 1, -1)
        invalid = class_ids > key_lens.view(1, -1, 1)
        logits = logits.masked_fill(invalid, -torch.finfo(logits.dtype).max)
        log_probs = logits.log_softmax(dim = -1)

        # Every item's target is 1, 2, ..., num_keys; CTC reads only key_lens of it.
        base = torch.arange(1, num_keys + 1, device=dev, dtype=torch.long)
        targets = repeat(base, 'n -> b n', b = key_lens.numel())

        return self.ctc_loss(log_probs, targets, query_lens, key_lens)
class BinLoss(Module):
    """Log-likelihood of a hard alignment under a soft alignment."""

    def forward(self, attn_hard, attn_logprob, key_lens):
        """Score ``attn_hard`` against ``attn_logprob``.

        Args:
            attn_hard: hard alignment of shape ``(b, t_k, t_q)``.
            attn_logprob: alignment logits of shape ``(b, 1, t_q, t_k)``.
            key_lens: valid key count per batch item, shape ``(b,)``.

        Returns:
            Sum over all cells of ``attn_hard * log_softmax(attn_logprob)``,
            divided by the batch size.
        """
        batch, device = attn_logprob.shape[0], attn_logprob.device
        max_key_len = attn_logprob.size(-1)

        # Reorder both inputs to [query_len, batch_size, key_len]
        attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t')
        attn_hard = rearrange(attn_hard, 'b t c -> c b t')

        # Keys 0 .. key_len-1 are valid. There is no blank column here, so
        # index key_len itself must be masked too, hence `>=`.
        key_ids = torch.arange(max_key_len, device=device, dtype=torch.long).view(1, 1, -1)
        invalid = key_ids >= key_lens.view(1, -1, 1)
        mask_value = -torch.finfo(attn_logprob.dtype).max
        # Out-of-place: `rearrange` can return a view of the caller's tensor.
        # An in-place fill would mutate it and can break autograd for it.
        attn_logprob = attn_logprob.masked_fill(invalid, mask_value)
        attn_logprob = attn_logprob.log_softmax(dim = -1)

        # NOTE(review): this is a (non-positive) log-likelihood, not its
        # negation. Minimizing it directly pushes the soft alignment away from
        # attn_hard, so confirm that callers negate it.
        return (attn_hard * attn_logprob).sum() / batch
class Aligner(Module):
    """Soft alignment via ``AlignerNet`` plus hard alignment via ``maximum_path``.

    Args:
        dim_in: channel count of ``y`` (the queries).
        dim_hidden: channel count of ``x`` (the keys).
        attn_channels: width of the shared encoding space.
        temperature: forwarded to ``AlignerNet``.
    """

    def __init__(
        self,
        dim_in,
        dim_hidden,
        attn_channels=80,
        temperature=0.0005
    ):
        super().__init__()
        # Hyper-parameters are kept as attributes for inspection.
        self.dim_in = dim_in
        self.dim_hidden = dim_hidden
        self.attn_channels = attn_channels
        self.temperature = temperature

        self.aligner = AlignerNet(
            dim_in=dim_in,
            dim_hidden=dim_hidden,
            attn_channels=attn_channels,
            temperature=temperature,
        )

    def forward(
        self,
        x,
        x_mask,
        y,
        y_mask
    ):
        """Align keys ``x`` with queries ``y``.

        Args:
            x: keys as ``(b, t_x, dim_hidden)``. They are transposed to
                channels-first here before entering the conv encoder.
            x_mask: ``(b, 1, t_x)`` validity mask for ``x``.
            y: queries as ``(b, dim_in, t_y)``.
            y_mask: ``(b, 1, t_y)`` validity mask for ``y``.

        Returns:
            ``(durations, soft, logprob, hard)``:
            - ``durations``: ``(b, t_x)`` int frame counts per key.
            - ``soft``: ``(b, t_x, t_y)`` soft alignment.
            - ``logprob``: ``(b, 1, t_y, t_x)`` logits from ``AlignerNet``.
            - ``hard``: ``(b, t_x, t_y)`` 0/1 path.
        """
        keys = rearrange(x, 'b d t -> b t d')
        soft, logprob = self.aligner(y, keys, x_mask)

        # Outer product of the two masks -> (b, t_x, t_y) cell mask.
        pair_mask = rearrange(x_mask, '... i -> ... i 1') * rearrange(y_mask, '... j -> ... 1 j')
        pair_mask = rearrange(pair_mask, 'b 1 i j -> b i j')

        soft = rearrange(soft, 'b 1 c t -> b t c')
        # NOTE(review): MAS runs on softmax probabilities here, so it
        # maximizes their *sum*. MAS is commonly run on log-probabilities
        # (maximizing the product) — confirm this is intended.
        hard = maximum_path(soft, pair_mask)

        durations = hard.sum(dim=-1).int()
        return durations, soft, logprob, hard
if __name__ == '__main__':
    # Smoke test on random data: checks that shapes line up end to end.
    batch_size = 10
    seq_len_y = 200   # number of query frames (y)
    seq_len_x = 35    # number of key tokens (x)
    feature_dim = 80  # channel count of y (Aligner dim_in)
    hidden_dim = 512  # channel count of x (Aligner dim_hidden)

    # Aligner.forward takes x as (batch, time, channels) and transposes it
    # to channels-first itself, so after this transpose dim 1 is time.
    x = torch.randn(batch_size, hidden_dim, seq_len_x).transpose(1, 2)
    # y goes straight into Conv1d, so dim 1 must be channels: (batch, channels, time).
    y = torch.randn(batch_size, seq_len_y, feature_dim).transpose(1, 2)

    # All-ones masks: every position is valid.
    x_mask = torch.ones(batch_size, 1, seq_len_x)
    y_mask = torch.ones(batch_size, 1, seq_len_y)

    align = Aligner(dim_in=feature_dim, dim_hidden=hidden_dim, attn_channels=80)
    alignment_hard, alignment_soft, alignment_logprob, alignment_mask = align(x, x_mask, y, y_mask)

    print('alignment_hard   ', tuple(alignment_hard.shape))
    print('alignment_soft   ', tuple(alignment_soft.shape))
    print('alignment_logprob', tuple(alignment_logprob.shape))
    print('alignment_mask   ', tuple(alignment_mask.shape))