# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import logging
import os.path as op
from argparse import Namespace
from collections import OrderedDict

import torch
from fairseq.data import (
    Dictionary,
    encoders,
    PrependTokenDataset,
    AppendTokenDataset,
    data_utils,
    StripTokenDataset,
    TokenBlockDataset,
)
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq import utils
from artst.data.multitask_dataset import MultitaskDataset
from artst.data.speech_to_text_dataset import SpeechToTextDataset
from artst.data.text_to_speech_dataset import TextToSpeechDataset
from artst.data.speech_to_speech_dataset import SpeechToSpeechDataset
from artst.data.speech_to_class_dataset import SpeechToClassDataset
from artst.data.speech_dataset import SpeechPretrainDataset
from artst.data.text_dataset import TextPretrainDataset
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.tasks.hubert_pretraining import LabelEncoder

logger = logging.getLogger(__name__)

TASK_NAME = ["s2t", "t2s", "s2s", "s2c", "pretrain"]
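
# TASK_NAME enumerates the five task modes dispatched below: "s2t"
# (speech-to-text, ASR), "t2s" (text-to-speech), "s2s" (speech-to-speech,
# e.g. voice conversion), "s2c" (speech-to-class, e.g. speaker
# identification), and "pretrain" (joint speech/text denoising).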

@register_task("artst")  # task name assumed from the ArTST repo; register_task is imported above
class ArTSTTask(LegacyFairseqTask):
    @staticmethod
    def add_args(parser):
parser.add_argument("data", help="manifest root path") | |
parser.add_argument( | |
"--config-yaml", | |
type=str, | |
default="config.yaml", | |
help="Configuration YAML filename (under manifest root)", | |
) | |
parser.add_argument( | |
"--max-speech-sample-size", | |
default=None, | |
type=int, | |
metavar="N", | |
help="max speech sample size", | |
) | |
parser.add_argument( | |
"--min-speech-sample-size", | |
default=None, | |
type=int, | |
metavar="N", | |
help="min speech sample size", | |
) | |
parser.add_argument( | |
"--max-speech-positions", | |
default=4000, | |
type=int, | |
metavar="N", | |
help="max number of tokens in the source sequence", | |
) | |
parser.add_argument( | |
"--max-text-positions", | |
default=450, | |
type=int, | |
metavar="N", | |
help="max number of tokens in the target sequence", | |
) | |
parser.add_argument( | |
'--t5-task', | |
choices=TASK_NAME, | |
help='task for training' | |
) | |
parser.add_argument( | |
"--bpe-tokenizer", | |
type=str, | |
default=None, | |
help="bpe tokenizer for s2t", | |
) | |
# Speaker Identification (SID) | |
parser.add_argument( | |
"--finetune-from-modules", | |
default=None, | |
# choices=[ | |
# "encoder-decoder", "encoder", "decoder", | |
# "speech_encoder_prenet-encoder-decoder-text_decoder_prenet-text_decoder_postnet", # ASR, T5 SID | |
# "speech_encoder_prenet-encoder-decoder-text_decoder_prenet-speaker_decoder_postnet", # SID | |
# "speech_encoder_prenet-encoder-decoder-speech_decoder_prenet-speech_decoder_postnet", # VC, SE | |
# "text_encoder_prenet-encoder-decoder-speech_decoder_prenet-speech_decoder_postnet", # TTS | |
# ], | |
help="If set, using part modules of finetune model.", | |
) | |
parser.add_argument( | |
"--finetune-out-of-modules", | |
default=None, | |
# choices=[ | |
# "speaker_decoder_postnet", # SID | |
# "speech_decoder_postnet", # SE with reduction factor 1 | |
# ], | |
help="If set, remove part modules of finetune model.", | |
) | |
        # BART-style denoising arguments (used by the text side of pre-training)
        parser.add_argument(
            "--shorten-method",
            default="none",
            choices=["none", "truncate", "random_crop"],
            help="if not none, shorten sequences that exceed --tokens-per-sample",
        )
        parser.add_argument(
            "--shorten-data-split-list",
            default="",
            help="comma-separated list of dataset splits to apply shortening to, "
            'e.g., "train,valid" (default: all dataset splits)',
        )
        parser.add_argument(
            "--tokens-per-sample",
            default=512,
            type=int,
            help="max number of total tokens over all segments per sample for dataset",
        )
        parser.add_argument(
            "--sample-break-mode",
            default="eos",
            type=str,
            help="mode for breaking sentences into samples",
        )
        parser.add_argument(
            "--mask",
            default=0.3,
            type=float,
            help="fraction of words/subwords that will be masked",
        )
        parser.add_argument(
            "--mask-random",
            default=0.1,
            type=float,
            help="instead of using [MASK], use random token this often",
        )
        parser.add_argument(
            "--insert",
            default=0.0,
            type=float,
            help="insert this percentage of additional random tokens",
        )
        parser.add_argument(
            "--permute",
            default=0.0,
            type=float,
            help="take this proportion of subwords and permute them",
        )
        parser.add_argument(
            "--rotate",
            default=0.0,
            type=float,
            help="rotate this proportion of inputs",
        )
        parser.add_argument(
            "--poisson-lambda",
            default=3.5,
            type=float,
            help="lambda for the Poisson distribution used to sample masked span lengths",
        )
        parser.add_argument(
            "--permute-sentences",
            default=0.0,
            type=float,
            help="shuffle this proportion of sentences in all inputs",
        )
        # parser.add_argument(
        #     "--mask-length",
        #     default="span-poisson",
        #     type=str,
        #     choices=["subword", "word", "span-poisson"],
        #     help="mask length to choose",
        # )
        parser.add_argument(
            "--replace-length",
            default=1,
            type=int,
            help="when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)",
        )
        parser.add_argument(
            "--iid-noise-target",
            action="store_true",
            help="whether to use T5-style (sentinel-token) targets",
        )
        # HuBERT-style arguments (used by the speech side of pre-training)
        parser.add_argument(
            "--hubert-labels",
            nargs="*",
            type=str,
            default=["km"],
            help="extension of the label files to load; frame-level labels for "
            "pre-training, sequence-level labels for fine-tuning",
        )
        parser.add_argument(
            "--hubert-label-dir",
            type=str,
            default=None,
            help="if set, looks for labels in this directory instead",
        )
        parser.add_argument(
            "--sample-rate",
            default=100,
            type=float,
            help="target sample rate. audio files will be up/down sampled to this rate",
        )
        parser.add_argument(
            "--label-rates",
            default=-1,
            type=float,
            help="label frame rate; -1 for sequence-level labels",
        )
        parser.add_argument(
            "--normalize",
            action="store_true",
            help="if set, normalizes input to have 0 mean and unit variance",
        )
        parser.add_argument(
            "--enable-padding",
            action="store_true",
            help="pad shorter samples instead of cropping",
        )
        parser.add_argument(
            "--pad-audio",
            action="store_true",
            help="pad audio to the longest one in the batch if true",
        )
        parser.add_argument(
            "--random-crop",
            action="store_true",
            help="always crop from the beginning if false",
        )
        parser.add_argument(
            "--single-target",
            action="store_true",
            help="if set, AddTargetDatasets outputs same keys as AddTargetDataset",
        )
        parser.add_argument(
            "--batch-ratio",
            default=None,
            type=str,
            help="ratio of batch size for each dataset",
        )
        parser.add_argument(
            "--sample-ratios",
            default=None,
            type=str,
            help="sampling ratio for each dataset",
        )
        parser.add_argument(
            "--ctc-weight",
            type=float,
            default=0.0,
            help="ctc weight for inference",
        )
        parser.add_argument(
            "--inference-speech",
            action="store_true",
            help="run TTS inference",
        )
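
    # The BART-style noising flags above (--mask, --mask-random, --insert,
    # --permute-sentences, --poisson-lambda, --replace-length, ...) only take
    # effect for --t5-task pretrain, where they are forwarded to
    # TextPretrainDataset through args. An illustrative combination, mirroring
    # this file's own defaults (example values, not the published ArTST recipe):
    #
    #   --t5-task pretrain --mask 0.3 --mask-random 0.1 \
    #   --poisson-lambda 3.5 --replace-length 1 --sample-break-mode eos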

    def __init__(self, args, dicts, config):
        super().__init__(args)
        self.dicts = dicts
        self.config = config
        self.t5_task = args.t5_task
        # Used for filtering over-long examples: speech lengths are counted in
        # waveform samples, so max_speech_positions (frames) is scaled by the
        # frame hop (presumably 256 audio samples per feature frame).
        if self.t5_task in ['s2t', 't2s', 's2s', 's2c']:
            self.max_pos = [self.args.max_speech_positions * 256]
        elif self.t5_task == 'pretrain':
            self.max_pos = [self.args.max_speech_positions * 256, self.args.max_text_positions]

        self.mask_idx = self.dicts["text"].add_symbol("<mask>")
        # add blank token for ctc
        # if args.ctc_weight > 0:
        self.blank_symbol_idx = self.dicts["text"].add_symbol("<ctc_blank>")
        self.blank_symbol = "<ctc_blank>"

        # add sentinel mask tokens <mask>0 ... <mask>599 for T5-style targets
        if hasattr(args, "iid_noise_target") and args.iid_noise_target:
            self.uni_mask_idxs = []
            for i in range(600):
                self.uni_mask_idxs.append(self.dicts["text"].add_symbol("<mask>" + str(i)))
            self.uni_mask_idxs = torch.tensor(self.uni_mask_idxs)

        self.seed = args.seed

    @classmethod
    def setup_task(cls, args, **kwargs):
        # load dictionaries and config
        dicts = OrderedDict()
        if args.t5_task == 'pretrain' and not hasattr(args, "shuffle_instance"):
            args.shuffle_instance = False

        # Prepare config (no YAML data config is loaded here)
        config = None
        logger.info('No config file for ' + args.t5_task)

        if args.t5_task == "pretrain":
            dicts["hubert"] = [
                Dictionary.load(f"{args.hubert_label_dir}/dict.{label}.txt")
                for label in args.hubert_labels
            ]
            dicts["text"] = Dictionary.load(op.join(args.data, "dict.txt"))
        else:
            if config is None:
                dicts["text"] = Dictionary.load(op.join(args.data, "dict.txt"))
            else:
                dicts["text"] = Dictionary.load(op.join(args.data, config.vocab_filename))

        return cls(args, dicts, config)
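
    # A minimal launch sketch (assuming the task is registered as "artst";
    # paths and values are placeholders, not the published recipe):
    #
    #   fairseq-train ${DATA_ROOT} \
    #       --task artst --t5-task s2t \
    #       --bpe-tokenizer ${SPM_MODEL} \
    #       --hubert-label-dir ${LABEL_DIR} \
    #       --max-speech-positions 4000 --max-text-positions 450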

    def build_criterion(self, args):
        from fairseq import criterions

        return criterions.build_criterion(args, self)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        sample_ratios = []
        if self.t5_task == "s2t":
            # speech-to-text task
            bpe_tokenizer = self.build_bpe(self.args)
            manifest = f"{self.args.data}/{split}.tsv"
            procs = [LabelEncoder(self.dicts["text"])]
            paths = [f"{self.args.hubert_label_dir}/{split}.txt"]
            # Hawau: view dataset...
            logger.info(f"Manifest: {manifest}")
            # logger.info(f"Paths: {paths}")
            self.datasets[split] = SpeechToTextDataset(
                manifest,
                sample_rate=self.args.sample_rate,
                label_paths=paths,
                label_processors=procs,
                max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size,
                min_keep_sample_size=self.args.min_speech_sample_size,
                normalize=self.args.normalize,
                store_labels=False,
                tgt_dict=self.dicts["text"],
                tokenizer=bpe_tokenizer,
            )
elif self.t5_task == "t2s": | |
## For text to speech task | |
from fairseq.data import ConcatDataset | |
bpe_tokenizer = self.build_bpe(self.args) | |
procs = [LabelEncoder(self.dicts["text"])] | |
t2s_datasets = [ | |
TextToSpeechDataset( | |
manifest_path=f"{self.args.data}/{name}.tsv", | |
sample_rate=self.args.sample_rate, | |
label_paths=[f"{self.args.hubert_label_dir}/{name}.txt"], | |
label_processors=procs, | |
max_keep_sample_size=self.max_pos[0], | |
normalize=self.args.normalize, | |
store_labels=False, | |
src_dict=self.dicts["text"], | |
tokenizer=bpe_tokenizer, | |
reduction_factor=self.args.reduction_factor, | |
inference=self.args.inference_speech, | |
) | |
for name in split.split(",") | |
] | |
self.datasets[split] = ConcatDataset(t2s_datasets) if len(t2s_datasets) > 1 else t2s_datasets[0] | |
elif self.t5_task == "s2s": | |
manifest = f"{self.args.data}/{split}.tsv" | |
self.datasets[split] = SpeechToSpeechDataset( | |
manifest_path=manifest, | |
sample_rate=self.args.sample_rate, | |
max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size, | |
min_keep_sample_size=self.args.min_speech_sample_size, | |
normalize=self.args.normalize, | |
reduction_factor=self.args.reduction_factor, | |
) | |
elif self.t5_task == "s2c": | |
is_train_split = ("train" in split) | |
is_valid_split = ("valid" in split) | |
if is_train_split: | |
max_length = 51200 | |
elif is_valid_split: | |
max_length = 76800 | |
else: | |
max_length = 2560000 | |
manifest = op.join(f"{self.args.data}", f"{split}.tsv") | |
procs = LabelEncoder(self.dicts["text"]) # map speaker to id | |
self.datasets[split] = SpeechToClassDataset( | |
manifest_path=manifest, | |
sample_rate=self.args.sample_rate, | |
label_processors=procs, | |
max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size, | |
min_keep_sample_size=self.args.min_speech_sample_size, | |
normalize=self.args.normalize, | |
tgt_dict=self.dicts["text"], | |
max_length=max_length | |
) | |
elif self.t5_task == "pretrain": | |
is_train_split = ("train" in split) | |
pretrain_datasets = [] | |
speech_split, text_split = split.split('|') | |
## Speech pre-train | |
manifest = f"{self.args.data}/{speech_split}.tsv" | |
dicts = self.dicts["hubert"] | |
pad_list = [dict.pad() for dict in dicts] | |
eos_list = [dict.eos() for dict in dicts] | |
procs = [LabelEncoder(dict) for dict in dicts] | |
paths = [ | |
f"{self.args.hubert_label_dir}/{speech_split}.{l}" for l in self.args.hubert_labels | |
] | |
# hubert v1: pad_audio=True, random_crop=False; | |
self.args.dec_weight = getattr(self.args, "dec_weight", 1.0) | |
pretrain_datasets.append( | |
SpeechPretrainDataset( | |
manifest, | |
sample_rate=self.args.sample_rate, | |
label_paths=paths, | |
label_rates=self.args.label_rates, | |
pad_list=pad_list, | |
eos_list=eos_list, | |
label_processors=procs, | |
max_keep_sample_size=None, | |
min_keep_sample_size=32000, | |
max_sample_size=self.args.max_speech_sample_size, | |
pad_audio=self.args.pad_audio, | |
normalize=self.args.normalize, | |
store_labels=False, | |
random_crop=self.args.random_crop, | |
single_target=self.args.single_target, | |
reduction_factor=self.args.reduction_factor, | |
) | |
) | |
sample_ratios.append(sum([pretrain_datasets[0].size(i) for i in range(len(pretrain_datasets[0]))])) | |
## Text pre-train | |
paths = utils.split_paths(self.args.data) | |
assert len(paths) > 0 | |
data_path = paths[(epoch - 1) % len(paths)] | |
print(f"Loading {text_split} from data_path={data_path}") | |
split_path = op.join(data_path, text_split) | |
print(f"split_path={split_path}") | |
bart_dataset = data_utils.load_indexed_dataset( | |
split_path, | |
self.dicts["text"], | |
self.args.dataset_impl, | |
combine=combine, | |
) | |
if bart_dataset is None: | |
raise FileNotFoundError( | |
"Dataset not found: {} ({})".format(text_split, split_path) | |
) | |
bart_dataset = StripTokenDataset(bart_dataset, self.dicts["text"].eos()) | |
bart_dataset = maybe_shorten_dataset( | |
bart_dataset, | |
text_split, | |
self.args.shorten_data_split_list, | |
self.args.shorten_method, | |
self.args.tokens_per_sample, | |
self.args.seed, | |
) | |
# create continuous blocks of tokens | |
bart_dataset = TokenBlockDataset( | |
bart_dataset, | |
bart_dataset.sizes, | |
self.args.tokens_per_sample - 2, # one less for <s> and one for </s> | |
pad=self.dicts["text"].pad(), | |
eos=self.dicts["text"].eos(), | |
break_mode=self.args.sample_break_mode, | |
document_sep_len=0, | |
) | |
# prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT) | |
bart_dataset = PrependTokenDataset(bart_dataset, self.dicts["text"].bos()) | |
bart_dataset = AppendTokenDataset(bart_dataset, self.dicts["text"].eos()) | |
mask_whole_words = ( | |
get_whole_word_mask(self.args, self.dicts["text"]) | |
if self.args.mask_length != "subword" | |
else None | |
) | |
self.args.bert_weight = getattr(self.args, "bert_weight", 0.0) | |
pretrain_datasets.append( | |
TextPretrainDataset( | |
bart_dataset, | |
bart_dataset.sizes, | |
self.dicts["text"], | |
self.mask_idx, | |
mask_whole_words, | |
shuffle=self.args.shuffle_instance, | |
seed=self.seed, | |
args=self.args, | |
iid_noise_target=self.args.iid_noise_target, | |
uni_mask_idxs=self.uni_mask_idxs if self.args.iid_noise_target else None, | |
) | |
) | |
sample_ratios.append(sum(pretrain_datasets[1].sizes)) | |
logger.info( | |
"Task: {0}, Loaded {1} samples of denoising_dataset".format( | |
'bart', | |
len(pretrain_datasets[1]), | |
) | |
) | |
logger.info('token ratio is ' + str(sample_ratios)) | |
if self.args.batch_ratio is not None: | |
batch_ratio = eval(self.args.batch_ratio) | |
assert len(batch_ratio) == len(sample_ratios) | |
sample_ratios = [sample_ratios[i] / batch_ratio[i] for i in range(len(sample_ratios))] | |
else: | |
batch_ratio = None | |
max_size = max(sample_ratios) | |
sample_ratios = [max_size / r for r in sample_ratios] | |
if hasattr(self.args, "sample_ratios") and self.args.sample_ratios is not None: | |
sample_ratios = eval(self.args.sample_ratios) | |
if is_train_split: | |
self.datasets[split] = MultitaskDataset( | |
pretrain_datasets, sample_ratios, batch_ratio | |
) | |
else: | |
self.datasets[split] = MultitaskDataset( | |
pretrain_datasets, batch_ratio=batch_ratio | |
) | |
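
    # Worked example of the ratio logic above (illustrative numbers): with
    # 8.0M speech tokens and 2.0M text tokens, sample_ratios starts as
    # [8.0e6, 2.0e6]; --batch-ratio "[1, 1]" leaves it unchanged, and the
    # normalization by the max yields [1.0, 4.0], i.e. MultitaskDataset
    # revisits the smaller text corpus 4x per epoch so both sides contribute
    # comparable token counts.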

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()
        model.set_num_updates(update_num)
        # Junyi: do not use sample_size; normalize the loss locally instead.
        # agg_sample_size is fixed to 1.0 so the trainer does not renormalize.
        agg_loss, agg_sample_size, agg_logging_output = 0.0, 1.0, {}
        agg_logging_output['sample_size'] = 1

        def forward_backward(model, samples, weight=1.0):
            nonlocal agg_loss, agg_logging_output
            if samples is None or len(samples) == 0:
                return
            loss, sample_size, logging_output = criterion(model, samples)
            if ignore_grad:
                loss *= 0
            else:
                loss *= weight
            loss = loss / sample_size
            optimizer.backward(loss)
            agg_loss += loss.detach().item()
            # TODO make summing of the sample sizes configurable
            for k in logging_output:
                if k == 'ntokens' or k == 'nsentences':
                    if k not in agg_logging_output:
                        agg_logging_output[k] = 0
                    agg_logging_output[k] += logging_output[k]
            agg_logging_output[samples['task_name']] = logging_output

        forward_backward(model, sample)
        agg_logging_output["loss"] = agg_loss
        return agg_loss, agg_sample_size, agg_logging_output

    def valid_step(self, sample, model, criterion):
        model.eval()
        with torch.no_grad():
            from collections import defaultdict

            agg_loss, agg_sample_size, agg_logging_output = 0.0, 1.0, defaultdict(float)
            agg_logging_output['sample_size'] = 1
            loss, sample_size, logging_output = criterion(model, sample)
            loss = loss / sample_size
            agg_loss += loss.item() if isinstance(loss, torch.Tensor) else loss
            agg_logging_output[sample['task_name']] = logging_output
            agg_logging_output["loss"] = agg_loss
        return agg_loss, agg_sample_size, agg_logging_output

    @property
    def target_dictionary(self):
        return self.dicts["text"]

    @property
    def source_dictionary(self):
        return None

    def build_model(self, args):
        try:
            args.input_feat_per_channel = self.config.input_feat_per_channel
            args.input_channels = self.config.input_channels
        except Exception as e:
            args.input_feat_per_channel = 80
            args.input_channels = 1
            logger.info("Cannot set input_feat_per_channel/input_channels from config:")
            logger.warning(e)
            logger.info(f"Set to: {args.input_feat_per_channel} and {args.input_channels}")

        args.speech_odim = args.input_feat_per_channel * args.input_channels
        args.label_rates = self.args.label_rates
        args.sample_rate = self.args.sample_rate
        self.args.reduction_factor = args.reduction_factor
        return super(ArTSTTask, self).build_model(args)

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        from artst.sequence_generator import SequenceGenerator

        extra_gen_cls_kwargs = {
            "ctc_weight": self.args.ctc_weight,
            **(extra_gen_cls_kwargs or {}),
        }
        return super().build_generator(
            models, args, seq_gen_cls=SequenceGenerator, extra_gen_cls_kwargs=extra_gen_cls_kwargs
        )

    def build_tokenizer(self, args):
        if self.config is None:
            logger.info("pre-tokenizer: None")
            return encoders.build_tokenizer(Namespace(**{"tokenizer": None}))
        else:
            logger.info(f"pre-tokenizer: {self.config.pre_tokenizer}")
            return encoders.build_tokenizer(Namespace(**self.config.pre_tokenizer))

    def build_bpe(self, args):
        if self.config is not None:
            logger.info(f"tokenizer: {self.config.bpe_tokenizer}")
            return encoders.build_bpe(Namespace(**self.config.bpe_tokenizer))
        else:
            logger.info(f"tokenizer: {self.args.bpe_tokenizer}")
            return encoders.build_bpe(
                Namespace(**{"bpe": "sentencepiece", "sentencepiece_model": self.args.bpe_tokenizer})
            )
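
    # Hypothetical config.yaml sketch consumed by build_tokenizer()/build_bpe()
    # when self.config is set (field names follow fairseq's S2T data config;
    # values are placeholders, not an ArTST release artifact):
    #
    #   pre_tokenizer:
    #     tokenizer: null
    #   bpe_tokenizer:
    #     bpe: sentencepiece
    #     sentencepiece_model: spm_char.model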

    def generate_class(self, models, net_input, prefix_tokens, **kwargs):
        with torch.no_grad():
            encoder_input = {
                k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name"
            }
            encoder_input.update(kwargs)
            encoder_input.update({"prev_output_tokens": prefix_tokens})
            return models[0].generate_class(**encoder_input)

    def generate_speech(self, models, net_input, **kwargs):
        with torch.no_grad():
            encoder_input = {
                k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name"
            }
            encoder_input.update(kwargs)
            return models[0].generate_speech(**encoder_input)

    def inference_t2s(self, models, sample):
        with torch.no_grad():
            xs = sample['net_input']['src_tokens']
            spkemb = sample['net_input']['spkembs']
            return models[0].inference(xs, spkemb)

    def inference_s2s(self, models, sample, force_equal_length=False):
        with torch.no_grad():
            x = sample['net_input']['src_tokens']
            xlen = sample['net_input']['src_lengths']
            spkemb = sample['net_input']['spkembs']
            prev_output_tokens = sample['net_input']['prev_output_tokens']
            padding_mask = sample['net_input']['padding_mask']
            tgt_lengths = sample['net_input']['tgt_lengths']
            return models[0].inference_s2s(
                x, xlen, spkemb, prev_output_tokens, tgt_lengths,
                force_equal_length=force_equal_length, padding_mask=padding_mask,
            )

    def inference_s2c(self, models, sample):
        with torch.no_grad():
            x = sample['net_input']['src_tokens']
            xlen = sample['net_input']['src_lengths']
            prev_output_tokens = sample['net_input']['prev_output_tokens']
            padding_mask = sample['net_input']['padding_mask']
            # s2c decodes exactly one class token
            assert prev_output_tokens.size(1) == 1, prev_output_tokens.size()
            return models[0].inference_s2c(x, xlen, prev_output_tokens, padding_mask=padding_mask)

    def filter_indices_by_size(
        self, indices, dataset, max_positions=None, ignore_invalid_inputs=False
    ):
        """
        Filter examples that are too large.

        Args:
            indices (np.array): original array of sample indices
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).

        Returns:
            np.array: array of filtered sample indices
        """
        # Note: max_positions/ignore_invalid_inputs are ignored here; filtering
        # always uses the task-level self.max_pos computed in __init__.
        indices, ignored = dataset.filter_indices_by_size(indices, self.max_pos)
        return indices