# --------------------------------------------------------
# The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task (https://arxiv.org/abs/2206.05777)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/YiTrans
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/audio/multi_modality_dataset.py
"""
from typing import Optional
import numpy as np
import torch
from fairseq.data import (
LanguagePairDataset,
)
from fairseq.data.audio.multi_modality_dataset import LangPairMaskDataset as FairseqLangPairMaskDataset


class LangPairMaskDataset(FairseqLangPairMaskDataset):
    """Wrapper around fairseq's LangPairMaskDataset whose collater forwards
    ``pad_to_length`` to the wrapped LanguagePairDataset."""
def __init__(
self,
dataset: LanguagePairDataset,
src_eos: int,
src_bos: Optional[int] = None,
noise_id: Optional[int] = -1,
mask_ratio: Optional[float] = 0,
mask_type: Optional[str] = "random",
):
        super().__init__(
            dataset,
            src_eos,
            src_bos,
            noise_id,
            mask_ratio,
            mask_type,
        )

    def mask_src_tokens(self, sample):
        # Replace a subset of source tokens with ``noise_id``, either at random
        # positions ("random") or as a contiguous tail, leaving BOS/EOS intact.
        src_item = sample["source"]
        mask = None
        if self.mask_type == "random":
            # Each position is masked independently with probability mask_ratio.
            mask = torch.rand(len(src_item)).le(self.mask_ratio)
        else:
            # Tail masking: mask the last mask_ratio fraction of the sequence.
            mask = torch.ones(len(src_item))
            mask[: int(len(src_item) * (1 - self.mask_ratio))] = 0
            mask = mask.eq(1)
if src_item[0] == self.src_bos:
mask[0] = False
if src_item[-1] == self.src_eos:
mask[-1] = False
mask_src_item = src_item.masked_fill(mask, self.noise_id)
smp = sample
smp["source"] = mask_src_item
return smp

    def collater(self, samples, pad_to_length=None):
        # Forward ``pad_to_length`` through to the wrapped LanguagePairDataset's
        # collater so batches can be padded to a fixed length.
        return self.dataset.collater(samples, pad_to_length=pad_to_length)
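

if __name__ == "__main__":
    # Minimal smoke test; not part of the original YiTrans module.  It only
    # exercises mask_src_tokens on a toy sample.  ``dataset=None`` is assumed
    # to be acceptable because the fairseq base __init__ merely stores the
    # reference; in real use, pass an actual LanguagePairDataset built by a
    # fairseq task, with indices taken from its source dictionary.
    toy = LangPairMaskDataset(
        dataset=None,
        src_eos=2,       # hypothetical EOS index
        src_bos=0,       # hypothetical BOS index
        noise_id=3,      # hypothetical noise/<mask> token index
        mask_ratio=0.5,
        mask_type="random",
    )
    sample = {"source": torch.tensor([0, 10, 11, 12, 13, 2])}
    # BOS (0) and EOS (2) are preserved; roughly half the inner tokens become 3.
    print(toy.mask_src_tokens(sample)["source"])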