# modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/concat_dataset.py

import bisect

import numpy as np
from torch.utils.data.dataloader import default_collate

from fairseq.data import FairseqDataset


class ConcatDataset(FairseqDataset):
    @staticmethod
    def cumsum(sequence, sample_ratios):
        # Running total of each dataset's length, scaled by its sample ratio.
        r, s = [], 0
        for e, ratio in zip(sequence, sample_ratios):
            curr_len = int(ratio * len(e))
            r.append(curr_len + s)
            s += curr_len
        return r
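
    # Example with illustrative values: dataset lengths [3, 5] and
    # sample_ratios [2, 1] yield cumulative sizes [6, 11], so virtual
    # indices 0-5 map to the first dataset (each real example appears
    # twice) and 6-10 to the second.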

    def __init__(self, datasets, sample_ratios=1):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, "datasets should not be an empty iterable"
        self.datasets = list(datasets)
        if isinstance(sample_ratios, int):
            # A single integer ratio applies uniformly to every dataset.
            sample_ratios = [sample_ratios] * len(self.datasets)
        self.sample_ratios = sample_ratios
        self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
        self.real_sizes = [len(d) for d in self.datasets]

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx][sample_idx]

    def _get_dataset_and_sample_index(self, idx: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        # Modulo wraps up-sampled virtual indices back onto the dataset's
        # real size.
        sample_idx = sample_idx % self.real_sizes[dataset_idx]
        return dataset_idx, sample_idx
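
    # Continuing the example above (cumulative sizes [6, 11], real sizes
    # [3, 5]): idx=4 lands in dataset 0 and resolves to real example
    # 4 % 3 == 1, which is how up-sampled indices wrap around.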

    def collater(self, samples, **extra_args):
        # For now this only supports datasets whose underlying collater
        # implementations are identical: the first dataset's collater is
        # applied to every batch.
        if hasattr(self.datasets[0], "collater"):
            return self.datasets[0].collater(samples, **extra_args)
        else:
            return default_collate(samples, **extra_args)

    def size(self, idx: int):
        """
        Return an example's size as a float or tuple.
        """
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx].size(sample_idx)

    def num_tokens(self, index: int):
        return np.max(self.size(index))

    def attr(self, attr: str, index: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
        return getattr(self.datasets[dataset_idx], attr, None)

    @property
    def sizes(self):
        _dataset_sizes = []
        for ds, sr in zip(self.datasets, self.sample_ratios):
            if isinstance(ds.sizes, np.ndarray):
                _dataset_sizes.append(np.tile(ds.sizes, sr))
            else:
                # Only supports underlying datasets with a single size array.
                assert isinstance(ds.sizes, list)
                _dataset_sizes.append(np.tile(ds.sizes[0], sr))
        return np.concatenate(_dataset_sizes)
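
    # Note: np.tile repeats each size array `sr` times, matching the cyclic
    # order produced by the modulo in _get_dataset_and_sample_index, so
    # sizes[i] describes self[i]. This assumes integer sample ratios, since
    # np.tile does not accept fractional repetition counts.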

    @property
    def supports_prefetch(self):
        return all(d.supports_prefetch for d in self.datasets)

    def ordered_indices(self):
        """
        Return an ordered list of example indices. Sorting by size lets
        batches be built with less padding.
        """
        if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
            # special handling for concatenating lang_pair_datasets
            if getattr(self.datasets[0], "shuffle", False):
                indices = np.random.permutation(len(self)).astype(np.int64)
            else:
                indices = np.arange(len(self), dtype=np.int64)
            sizes = self.sizes
            tgt_sizes = (
                sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
            )
            src_sizes = (
                sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
            )
            # sort by target length, then source length; mergesort is stable,
            # so the second pass preserves the target-length order among ties
            if tgt_sizes is not None:
                indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(src_sizes[indices], kind="mergesort")]
        else:
            return np.argsort(self.sizes)
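
    # Worked example for the two-key sort above (illustrative values):
    # src sizes [5, 3, 5] with tgt sizes [2, 9, 1] yield the order
    # [1, 2, 0]: index 1 first (shortest source), then the src == 5 tie
    # is broken by target length (tgt 1 before tgt 2).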

    def prefetch(self, indices):
        frm = 0
        for to, ds in zip(self.cumulative_sizes, self.datasets):
            real_size = len(ds)
            if getattr(ds, "supports_prefetch", False):
                # Map global indices in [frm, to) back to local indices.
                ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
            frm = to

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)
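

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file). `_ToyDataset` is
    # a hypothetical stand-in providing only the pieces exercised here:
    # __len__, __getitem__, and a `sizes` array.
    class _ToyDataset(FairseqDataset):
        def __init__(self, items):
            self.items = items
            self.sizes = np.array([1] * len(items))

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]

    concat = ConcatDataset(
        [_ToyDataset(["a", "b", "c"]), _ToyDataset(["d", "e", "f", "g", "h"])],
        sample_ratios=[2, 1],
    )
    print(len(concat))  # 11: the first dataset counts twice (6), plus 5
    print(concat[4])    # "b": virtual index 4 wraps to real index 4 % 3 == 1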