# --------------------------------------------------------
# The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task (https://arxiv.org/abs/2206.05777)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/YiTrans
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
import logging
from typing import Any, List, Optional, Union
import numpy as np
import torch
from fairseq.data import data_utils, Dictionary
from fairseq.data.audio.hubert_dataset import HubertDataset
logger = logging.getLogger(__name__)
class Speech2cDataset(HubertDataset):
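    """HubertDataset variant used by YiTrans/Speech2C that can additionally
    produce decoder targets: `collater` builds `decoder_target` and
    `prev_output_tokens`, optionally re-tokenizing reduced cluster-id labels
    with SentencePiece and attaching a target-language id token."""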
def __init__(
self,
manifest_path: str,
sample_rate: float,
label_paths: List[str],
label_rates: Union[List[float], float], # -1 for sequence labels
pad_list: List[str],
eos_list: List[str],
label_processors: Optional[List[Any]] = None,
max_keep_sample_size: Optional[int] = None,
min_keep_sample_size: Optional[int] = None,
max_sample_size: Optional[int] = None,
shuffle: bool = True,
pad_audio: bool = False,
normalize: bool = False,
store_labels: bool = True,
random_crop: bool = False,
single_target: bool = False,
tgt_dict: Optional[Dictionary] = None,
add_decoder: bool = False,
fine_tuning: bool = False,
        tokenizer=None,
        tgt_lang_idx: Optional[int] = None,
mbart_style_lang_id: bool = False,
retry_times: int = 5,
reduce_label_for_dec: bool = True,
):
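        """
        Args beyond HubertDataset:
            tgt_dict: target-side dictionary supplying pad/eos for decoder targets.
            add_decoder: if True, collate decoder targets and decoder inputs.
            fine_tuning: labels are final target-token sequences rather than
                frame-level cluster ids.
            tokenizer: optional SentencePiece-style tokenizer; applied to raw
                labels when fine-tuning, or to character-mapped cluster ids
                when pre-training.
            tgt_lang_idx: target-language id token for multilingual decoding.
            mbart_style_lang_id: if True, append tgt_lang_idx after EOS
                (mBART-style) instead of replacing the leading EOS of
                prev_output_tokens.
            retry_times: stored on the instance; not referenced in this class.
            reduce_label_for_dec: collapse consecutive duplicate labels
                (unique_consecutive) before building decoder targets.
        """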
super().__init__(
manifest_path,
sample_rate,
label_paths,
label_rates,
pad_list,
eos_list,
label_processors,
max_keep_sample_size,
min_keep_sample_size,
max_sample_size,
shuffle,
pad_audio,
normalize,
store_labels,
random_crop,
single_target
)
self.tgt_dict = tgt_dict
self.add_decoder = add_decoder
self.fine_tuning = fine_tuning
self.tokenizer = tokenizer
self.tgt_lang_idx = tgt_lang_idx
self.mbart_style_lang_id = mbart_style_lang_id
self.retry_times = retry_times
self.reduce_label_for_dec = reduce_label_for_dec
logger.info(
f"tgt_lang_idx={self.tgt_lang_idx}, reduce_label_for_dec={reduce_label_for_dec}, "
f"mbart_style_lang_id={mbart_style_lang_id}"
)
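        # HubertDataset keeps `sizes` as a Python list; convert it to an
        # ndarray so size-based ordering and filtering can index it efficiently.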
self.sizes = np.array(self.sizes)
def get_label(self, index, label_idx):
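        """Return label `label_idx` of sample `index`, read from memory or via
        byte offsets into the label file; during fine-tuning the tokenizer runs
        before any label processors."""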
if self.store_labels:
label = self.label_list[label_idx][index]
else:
with open(self.label_paths[label_idx]) as f:
offset_s, offset_e = self.label_offsets_list[label_idx][index]
f.seek(offset_s)
label = f.read(offset_e - offset_s)
if self.tokenizer is not None and self.fine_tuning:
label = self.tokenizer.encode(label)
if self.label_processors is not None:
label = self.label_processors[label_idx](label)
return label
def collater(self, samples):
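        """Collate samples into a batch of cropped/padded audio plus label
        targets; when `add_decoder` is set, also build `decoder_target` and
        `prev_output_tokens` for the decoder."""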
        # pad_audio=True : pad/crop every clip to min(max(sizes), max_sample_size),
        #                  so random_crop only triggers for clips above max_sample_size
        # pad_audio=False: crop every clip to min(min(sizes), max_sample_size)
samples = [s for s in samples if s["source"] is not None]
if len(samples) == 0:
return {}
audios = [s["source"] for s in samples]
audio_sizes = [len(s) for s in audios]
if self.pad_audio:
audio_size = min(max(audio_sizes), self.max_sample_size)
else:
audio_size = min(min(audio_sizes), self.max_sample_size)
collated_audios, padding_mask, audio_starts = self.collater_audio(
audios, audio_size
)
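        # Re-group labels: one list per label stream, each holding len(samples) entries.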
targets_by_label = [
[s["label_list"][i] for s in samples] for i in range(self.num_labels)
]
targets_list, lengths_list, ntokens_list = self.collater_label(
targets_by_label, audio_size, audio_starts
)
        if self.add_decoder:
            if self.fine_tuning:
                # Fine-tuning: labels are already tgt_dict tokens; just append EOS.
                decoder_label = [
                    torch.cat(
                        (targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0
                    ).long()
                    for i in range(targets_list[0].size(0))
                ]
            else:
                def dec_labels(i):
                    # Pre-training labels are frame-level cluster ids; optionally
                    # collapse runs of identical ids before building decoder targets.
                    labels = targets_list[0][i, :lengths_list[0][i]]
                    return labels.unique_consecutive() if self.reduce_label_for_dec else labels

                if self.tokenizer is not None:
                    # Offset each cluster id by 48 so it maps to a printable
                    # character (avoiding control characters such as \n), encode
                    # the resulting string with SentencePiece, then append EOS.
                    decoder_label = [
                        torch.cat(
                            (
                                torch.tensor(
                                    self.tokenizer.sp.Encode(
                                        "".join(chr(j + 48) for j in dec_labels(i).tolist()),
                                        out_type=int,
                                    )
                                ),
                                torch.tensor([self.tgt_dict.eos()]),
                            ),
                            dim=0,
                        ).long()
                        for i in range(targets_list[0].size(0))
                    ]
                else:
                    decoder_label = [
                        torch.cat((dec_labels(i), torch.tensor([self.tgt_dict.eos()])), 0).long()
                        for i in range(targets_list[0].size(0))
                    ]
            if self.mbart_style_lang_id:
                # mBART-style: append the target-language id after EOS.
                decoder_label = [
                    torch.cat((decoder_label[i], torch.tensor([self.tgt_lang_idx])), 0).long()
                    for i in range(targets_list[0].size(0))
                ]
dec_ntokens = sum(x.size(0) for x in decoder_label)
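            # Right-pad every decoder target to the batch max length.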
decoder_target = data_utils.collate_tokens(
decoder_label,
self.tgt_dict.pad(),
self.tgt_dict.eos() if not self.mbart_style_lang_id else self.tgt_lang_idx,
left_pad=False,
move_eos_to_beginning=False,
)
decoder_target_lengths = torch.tensor(
[x.size(0) for x in decoder_label], dtype=torch.long
)
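            # Decoder inputs: targets shifted right by one, with EOS (or the
            # language id, mBART-style) moved to the first position.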
prev_output_tokens = data_utils.collate_tokens(
decoder_label,
self.tgt_dict.pad(),
self.tgt_dict.eos() if not self.mbart_style_lang_id else self.tgt_lang_idx,
left_pad=False,
move_eos_to_beginning=True,
)
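            # Non-mBART multilingual setup: overwrite the leading EOS of every
            # decoder input with the target-language id.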
if self.tgt_lang_idx is not None and not self.mbart_style_lang_id:
assert (prev_output_tokens[:, 0] != self.tgt_dict.eos()).sum() == 0
prev_output_tokens[:, 0] = self.tgt_lang_idx
net_input = {
"source": collated_audios,
"padding_mask": padding_mask,
"prev_output_tokens": prev_output_tokens,
}
batch = {
"id": torch.LongTensor([s["id"] for s in samples]),
"net_input": net_input,
"decoder_target": decoder_target,
"decoder_target_lengths": decoder_target_lengths,
"dec_ntokens": dec_ntokens,
"lang_idx": self.tgt_lang_idx,
}
else:
net_input = {"source": collated_audios, "padding_mask": padding_mask}
batch = {
"id": torch.LongTensor([s["id"] for s in samples]),
"net_input": net_input,
}
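        # Encoder-side label targets (e.g. for masked prediction), either a
        # single stream or one entry per label stream.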
if self.single_target:
batch["target_lengths"] = lengths_list[0]
batch["ntokens"] = ntokens_list[0]
batch["target"] = targets_list[0]
else:
batch["target_lengths_list"] = lengths_list
batch["ntokens_list"] = ntokens_list
batch["target_list"] = targets_list
return batch