Spaces:
Runtime error
Runtime error
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import torch | |
class RelativePositionalEncoding(torch.nn.Module):
    """Learned embedding lookup for clipped relative position offsets.

    Offsets are clipped to ``[-maxlen, maxlen - 1]`` and shifted by
    ``maxlen`` so they index an embedding table with ``2 * maxlen`` rows.

    Args:
        d_model: Size of each embedding vector.
        maxlen: Clipping bound for offsets; each table has ``2 * maxlen`` rows.
        embed_v: When true, a second table (``pe_v``) is created and its
            lookup is returned alongside the ``pe_k`` lookup.
    """

    def __init__(self, d_model, maxlen=1000, embed_v=False):
        super().__init__()
        self.d_model = d_model
        self.maxlen = maxlen
        self.pe_k = torch.nn.Embedding(2 * maxlen, d_model)
        # The second table only exists when requested, so ``pe_v`` is absent
        # from the module (and its state dict) otherwise.
        if embed_v:
            self.pe_v = torch.nn.Embedding(2 * maxlen, d_model)
        self.embed_v = embed_v

    def forward(self, pos_seq, incremental_state=None):
        """Look up embeddings for a tensor of relative offsets.

        Args:
            pos_seq: Tensor of relative offsets (presumably an integer tensor,
                as required by ``torch.nn.Embedding`` -- confirm with caller).
                It is NOT modified by this call.
            incremental_state: When not ``None``, only the last entry along
                the first dimension of ``pos_seq`` is embedded.

        Returns:
            A ``(k, v)`` tuple: ``k`` is the ``pe_k`` lookup, and ``v`` is the
            ``pe_v`` lookup when ``embed_v`` is set, otherwise ``None``.
        """
        if incremental_state is not None:
            # Clipping and shifting are element-wise, so slicing first yields
            # the same result while touching fewer elements.
            pos_seq = pos_seq[-1:]

        # ``torch.clamp`` returns a new tensor; masked in-place assignment
        # would overwrite the caller's ``pos_seq`` as a side effect.
        indices = torch.clamp(pos_seq, min=-self.maxlen, max=self.maxlen - 1)
        # Shift [-maxlen, maxlen - 1] to the valid row range [0, 2*maxlen - 1].
        indices = indices + self.maxlen

        if self.embed_v:
            return self.pe_k(indices), self.pe_v(indices)
        return self.pe_k(indices), None