# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
""" | |
Modified form: https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/fairseq/sequence_generator.py | |
""" | |
import torch | |
import numpy as np | |
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig | |
from fairseq.speech_generator import SpeechGenerator | |


class NonAutoregressiveUnitGenerator(SpeechGenerator):
    @torch.no_grad()
    def generate(self, model, sample, has_targ=False, **kwargs):
        model.eval()
bsz, max_src_len = sample["net_input"]["src_tokens"].size() | |
n_frames_per_step = model.encoder.n_frames_per_step | |
out_dim = model.encoder.out_dim | |
raw_dim = out_dim // n_frames_per_step | |
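        ## Forward the FastSpeech2-style model: it returns per-frame unit
        ## logits, an optional postnet-refined version of those logits, the
        ## predicted output lengths, and per-token log-durations from the
        ## variance adaptor.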
logit, logit_post, out_lens, log_dur_out, _, _ = model( | |
src_tokens=sample["net_input"]["src_tokens"], | |
src_lengths=sample["net_input"]["src_lengths"], | |
speaker=sample["speaker"], | |
durations=sample["durations"], | |
pitches=sample["pitches"], | |
energies=sample["energies"], | |
) | |
        if logit_post is not None:
            logit = logit_post  ## prefer the postnet-refined prediction
        logit = logit.view(bsz, -1, raw_dim)  ## one logit vector per output frame
        pred = logit.argmax(dim=-1)  ## greedy discrete-unit choice per frame
## get duration prediction | |
src_tokens = sample["net_input"]["src_tokens"] | |
src_lengths = sample["net_input"]["src_lengths"] | |
padding_mask = src_tokens.eq(model.encoder.padding_idx) | |
        d_factor = 1.0  ## duration scaling factor; set by the model
        ## invert the log(1 + d) duration target: d = exp(log_dur) - 1,
        ## then scale, round to integer frames, and clamp at zero
        dur_out = torch.clamp(
            torch.round((torch.exp(log_dur_out) - 1) * d_factor).long(), min=0
        )
        dur_out.masked_fill_(padding_mask, 0)  ## padded positions get zero duration
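        ## Length regulation expands each source token by its predicted
        ## duration; illustratively, tokens [a, b, c] with durations [2, 1, 3]
        ## yield [a, a, b, c, c, c], frame-aligned with `pred` above.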
x = src_tokens.unsqueeze(-1) | |
x, src_out_lens = model.encoder.var_adaptor.length_regulator(x, dur_out) | |
fa_src_tokens = x.view(bsz, -1) | |
        finalized = [
            {
                "unit": pred[b, :out_len],
                "fa_src": fa_src_tokens[b, :out_len],
                "duration": dur_out[b, :src_len],
            }
            for b, out_len, src_len in zip(range(bsz), out_lens, src_lengths)
        ]
return finalized | |
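

## ----------------------------------------------------------------------------
## Minimal usage sketch (an illustration, not part of the upstream file),
## assuming fairseq's SpeechGenerator.__init__(model, vocoder, data_cfg)
## signature with an S2TDataConfig; `model`, `vocoder`, `data_cfg`, and
## `dataloader` are hypothetical names standing in for a trained unit model,
## its vocoder, the data config, and a dataloader over the task's samples.
##
##     generator = NonAutoregressiveUnitGenerator(model, vocoder, data_cfg)
##     for sample in dataloader:  ## batches shaped like the model's S2T samples
##         for hypo in generator.generate(model, sample):
##             units = hypo["unit"]      ## predicted discrete units per frame
##             fa_src = hypo["fa_src"]   ## source tokens repeated per duration
##             durs = hypo["duration"]   ## integer duration per source token
## ----------------------------------------------------------------------------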