Spaces:
Runtime error
Runtime error
# --------------------------------------------------------
# The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task (https://arxiv.org/abs/2206.05777)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/YiTrans
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
""" | |
Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/audio/multi_modality_dataset.py | |
""" | |
from typing import Optional | |
import numpy as np | |
import torch | |
from fairseq.data import ( | |
LanguagePairDataset, | |
) | |
from fairseq.data.audio.multi_modality_dataset import LangPairMaskDataset as FairseqLangPairMaskDataset | |
class LangPairMaskDataset(FairseqLangPairMaskDataset):
    """Language-pair dataset wrapper that overwrites some source tokens with a noise id.

    Construction is delegated entirely to the fairseq base class; this
    subclass only overrides ``mask_src_tokens`` and ``collater``.  The
    attributes read below (``dataset``, ``src_bos``, ``src_eos``,
    ``noise_id``, ``mask_ratio``, ``mask_type``) are expected to be set by
    the base-class constructor from the arguments of the same name
    (NOTE(review): confirm against the installed fairseq version).
    """

    def __init__(
        self,
        dataset: LanguagePairDataset,
        src_eos: int,
        src_bos: Optional[int] = None,
        noise_id: Optional[int] = -1,
        mask_ratio: Optional[float] = 0,
        mask_type: Optional[str] = "random",
    ):
        """Forward every argument unchanged to the fairseq base class.

        Args:
            dataset: the wrapped language-pair dataset.
            src_eos: source token id that is never masked in last position.
            src_bos: source token id that is never masked in first position.
            noise_id: id written into the masked source positions.
            mask_ratio: fraction of source positions to mask.
            mask_type: ``"random"`` masks each position independently; any
                other value masks a contiguous tail of the sequence.
        """
        # ``super`` has to be *called* to get a proxy bound to this instance;
        # the bare ``super.__init__(...)`` form raises TypeError.
        super().__init__(
            dataset,
            src_eos,
            src_bos,
            noise_id,
            mask_ratio,
            mask_type,
        )

    def mask_src_tokens(self, sample):
        """Replace part of ``sample["source"]`` with ``self.noise_id``.

        The ``"source"`` entry of ``sample`` is overwritten in place and the
        same ``sample`` object is returned, so any other keys it carries are
        preserved.

        Args:
            sample: mapping whose ``"source"`` entry is a 1-D token tensor.

        Returns:
            ``sample`` with its ``"source"`` entry replaced by the masked copy.
        """
        src_item = sample["source"]
        length = len(src_item)
        # Build the mask on the same device as the tokens so that
        # ``masked_fill`` also works for non-CPU tensors.
        device = src_item.device

        if self.mask_type == "random":
            # Each position is masked independently with probability
            # ``mask_ratio``.
            mask = torch.rand(length, device=device).le(self.mask_ratio)
        else:
            # Keep the leading ``1 - mask_ratio`` share, mask the tail.
            mask = torch.ones(length, device=device)
            mask[: int(length * (1 - self.mask_ratio))] = 0
            mask = mask.eq(1)

        # An empty source has no boundary tokens to protect; indexing it
        # would raise IndexError.
        if length > 0:
            # Never overwrite a leading BOS or a trailing EOS token.
            if self.src_bos is not None and src_item[0] == self.src_bos:
                mask[0] = False
            if src_item[-1] == self.src_eos:
                mask[-1] = False

        sample["source"] = src_item.masked_fill(mask, self.noise_id)
        return sample

    def collater(self, samples, pad_to_length=None):
        """Collate ``samples`` with the wrapped dataset, forwarding ``pad_to_length``."""
        return self.dataset.collater(samples, pad_to_length=pad_to_length)