# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
"""
Modified from: https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/fairseq/sequence_generator.py
"""
import torch
import numpy as np
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.speech_generator import SpeechGenerator


class NonAutoregressiveUnitGenerator(SpeechGenerator):
    @torch.no_grad()
    def generate(self, model, sample, has_targ=False, **kwargs):
        model.eval()

        bsz, max_src_len = sample["net_input"]["src_tokens"].size()
        n_frames_per_step = model.encoder.n_frames_per_step
        out_dim = model.encoder.out_dim
        raw_dim = out_dim // n_frames_per_step

        # Forward pass of the non-autoregressive model: returns per-frame unit
        # logits (before and after the postnet), output lengths, and the
        # predicted log-durations per source token.
        logit, logit_post, out_lens, log_dur_out, _, _ = model(
            src_tokens=sample["net_input"]["src_tokens"],
            src_lengths=sample["net_input"]["src_lengths"],
            speaker=sample["speaker"],
            durations=sample["durations"],
            pitches=sample["pitches"],
            energies=sample["energies"],
        )
        if logit_post is not None:
            logit = logit_post

        # Collapse the n_frames_per_step dimension and take the argmax over the
        # unit vocabulary to obtain the predicted discrete units.
        logit = logit.view(bsz, -1, raw_dim)
        pred = logit.argmax(dim=-1)

        ## get duration prediction
        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        padding_mask = src_tokens.eq(model.encoder.padding_idx)
        d_factor = 1.0  ## set by model
        # Invert the log-duration prediction and round to non-negative integer
        # frame counts; padded positions get zero duration.
        dur_out = torch.clamp(
            torch.round((torch.exp(log_dur_out) - 1) * d_factor).long(), min=0
        )
        dur_out.masked_fill_(padding_mask, 0)

        # Expand each source token by its predicted duration to obtain a
        # frame-aligned source-token sequence.
        x = src_tokens.unsqueeze(-1)
        x, src_out_lens = model.encoder.var_adaptor.length_regulator(x, dur_out)
        fa_src_tokens = x.view(bsz, -1)

        finalized = [
            {
                "unit": pred[b, :out_len],
                "fa_src": fa_src_tokens[b, :out_len],
                "duration": dur_out[b, :src_len],
            }
            for b, out_len, src_len in zip(range(bsz), out_lens, src_lengths)
        ]

        return finalized
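

# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original file). On toy values, it shows
# the duration-based expansion that `length_regulator` performs above: each
# source token is repeated `duration` times, which is the shape of the
# "fa_src" sequence returned in `finalized`. `torch.repeat_interleave` is used
# here only as a stand-in for the model's own length regulator, and all values
# below are made up for demonstration.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    src = torch.tensor([7, 3, 9, 5])  # four source tokens of one sentence
    dur = torch.tensor([2, 0, 3, 1])  # predicted integer duration per token
    fa_src = torch.repeat_interleave(src, dur)
    print(fa_src)  # tensor([7, 7, 9, 9, 9, 5])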