UniMax Language Dataset Sampler with DDP support
This repository contains an unofficial PyTorch implementation of the UniMax sampling algorithm. UniMax, introduced in "UniMax: Fairer and More Effective Language Sampling for Large-Scale Multilingual Pretraining" by H. W. Chung et al. (ICLR 2023), generates a sampling distribution over languages based on their character counts, a total character budget, and a cap on the number of epochs per language. This is useful for training language models on datasets with imbalanced language distributions.
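In brief, UniMax visits languages from the smallest corpus to the largest; each language receives the lesser of its epoch cap (the epoch limit times its character count) and an equal share of the budget still remaining, and the final sampling probabilities are proportional to these allocations. A minimal sketch of that allocation loop (the function name and NumPy usage are illustrative, not taken from this repository):

import numpy as np

def unimax_distribution(char_counts, total_budget, max_epochs):
    counts = np.asarray(char_counts, dtype=np.float64)
    order = np.argsort(counts)                 # visit smallest corpora first
    alloc = np.zeros_like(counts)
    remaining = float(total_budget)
    for i, idx in enumerate(order):
        share = remaining / (len(counts) - i)  # equal share of the leftover budget
        alloc[idx] = min(max_epochs * counts[idx], share)  # respect the per-language epoch cap
        remaining -= alloc[idx]
    return alloc / alloc.sum()                 # normalize into sampling probabilities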
Contents
unimax_sampler.py: Contains the UnimaxSampler class, a PyTorch Sampler that uses the UniMax algorithm.
test_unimax_sampler.py: Contains a unit test for the UnimaxSampler class to verify its correct functionality.
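Assuming the tests are written with Python's built-in unittest framework, they can be run from the repository root with python -m unittest test_unimax_sampler.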
Usage
from torch.utils.data import Dataset, DataLoader
from unimax_sampler import UnimaxSampler
# Define your parameters
language_character_counts = [100, 200, 300, 400, 500]  # characters available per language
total_character_budget = 1000                          # total characters to sample
num_epochs = 2                                         # maximum epochs per language
# Create the UnimaxSampler
unimax_sampler = UnimaxSampler(language_character_counts, total_character_budget, num_epochs)
Then, use the sampler as the sampler argument when creating a DataLoader.
# Shuffling must be disabled when a custom sampler is supplied
data_loader = DataLoader(my_dataset, batch_size=2, shuffle=False, sampler=unimax_sampler)
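Here my_dataset can be any map-style Dataset. Continuing the snippet above, a toy stand-in for experimentation (ToyDataset is illustrative, not part of this repository) might be:

class ToyDataset(Dataset):
    def __init__(self, num_items):
        self.num_items = num_items
    def __len__(self):
        return self.num_items
    def __getitem__(self, idx):
        return idx  # return the index itself as a stand-in for real data

my_dataset = ToyDataset(100)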
For DDP, select the distributed variant only when a process group has been initialized:
import torch.distributed

if torch.distributed.is_initialized():
    sampler = DistributedUnimaxSampler(...)
else:
    sampler = unimax_sampler
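Like torch.utils.data.DistributedSampler, a distributed sampler is typically told the current epoch at the start of each pass so that all processes reshuffle in step. Assuming DistributedUnimaxSampler inherits the set_epoch convention from DistributedSamplerWrapper (an assumption, not confirmed by this README), a training loop would look like:

for epoch in range(num_training_epochs):  # num_training_epochs is illustrative; distinct from the UniMax epoch cap
    sampler.set_epoch(epoch)              # assumed API, mirroring torch.utils.data.DistributedSampler
    for batch in data_loader:
        ...                               # forward/backward pass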
Note
The initial version of this code was generated by GPT-4 from the pseudocode provided in the UniMax paper, then manually revised to support PyTorch's Distributed Data Parallel (DDP) framework. The DistributedSamplerWrapper implementation is derived from an earlier version found in the Catalyst project.
License
This project is licensed under the MIT License.