# UniMax Language Dataset Sampler with DDP Support

This repository contains an unofficial PyTorch implementation of the UniMax sampling algorithm from ["UniMax: Fairer and more Effective Language Sampling for Large-Scale Multilingual Pretraining" by H. W. Chung et al. (ICLR 2023)](https://arxiv.org/abs/2304.09151). UniMax takes three inputs: the character count of each language, a total character budget, and a maximum number of epochs per language. From these it produces a sampling distribution over languages. The budget is spread as uniformly as possible across languages, but no language's data is repeated more than the given number of epochs. This is useful for training language models on datasets whose language distribution is heavily imbalanced.

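The allocation step can be sketched in a few lines. This is a minimal sketch written from the paper's description. It is not the code in `unimax_sampler.py`, and the function name `unimax_distribution` and its argument names are hypothetical. The function visits languages from smallest to largest and gives each one an even share of the remaining budget, capped at `max_epochs` passes over its data. It then normalizes the allocations into probabilities.

```python
from typing import List, Sequence


def unimax_distribution(
    char_counts: Sequence[float], budget: float, max_epochs: float
) -> List[float]:
    """Hypothetical sketch of the UniMax allocation (not this repo's API).

    Returns one sampling probability per language, in the input order.
    """
    num_langs = len(char_counts)
    if num_langs == 0 or budget <= 0:
        raise ValueError("need at least one language and a positive budget")
    # Visit languages from the smallest corpus to the largest.
    order = sorted(range(num_langs), key=lambda i: char_counts[i])
    allocation = [0.0] * num_langs
    remaining = float(budget)
    for visited, lang in enumerate(order):
        # Split what is left of the budget evenly over the unvisited languages.
        uniform_share = remaining / (num_langs - visited)
        # Never allocate more than max_epochs passes over this language's data.
        allocation[lang] = min(uniform_share, char_counts[lang] * max_epochs)
        remaining -= allocation[lang]
    total = sum(allocation)
    if total == 0:
        raise ValueError("all languages have zero characters")
    # Sampling probabilities are proportional to the allocated characters.
    return [a / total for a in allocation]
```

Consider `unimax_distribution([100, 200, 300, 400, 500], 2000, 2)`. The two smallest languages hit their two-epoch cap at 200 and 400 characters. The remaining 1400 characters are split evenly over the other three languages. The result is approximately `[0.1, 0.2, 0.233, 0.233, 0.233]`.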
## Contents

1. `unimax_sampler.py`: contains the `UnimaxSampler` class, a PyTorch `Sampler` that uses the UniMax algorithm.
2. `test_unimax_sampler.py`: contains a unit test that checks the `UnimaxSampler` class works correctly.

## Usage

```python
from torch.utils.data import DataLoader

from unimax_sampler import UnimaxSampler

# Define your parameters
language_character_counts = [100, 200, 300, 400, 500]  # characters per language
total_character_budget = 1000
num_epochs = 2  # cap on how many times any one language's data is repeated

# Create the UnimaxSampler
unimax_sampler = UnimaxSampler(language_character_counts, total_character_budget, num_epochs)
```
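Suppose `UnimaxSampler` follows the paper's procedure: sort languages by size, split the remaining budget evenly, and cap each language at $c_l N$ characters. Under that assumption, the example parameters above give a uniform distribution. The budget per language never exceeds the cap, so every language receives the same share.

```math
\begin{aligned}
&C = 1000,\quad N = 2,\quad c = (100, 200, 300, 400, 500) \\
&b_1 = 1000/5 = 200 \le c_1 N = 200 \;\Rightarrow\; U_1 = 200,\ B = 800 \\
&b_2 = 800/4 = 200 < c_2 N = 400 \;\Rightarrow\; U_2 = 200,\ B = 600 \\
&b_3 = 600/3 = 200,\quad b_4 = 400/2 = 200,\quad b_5 = 200/1 = 200 \;\Rightarrow\; U_3 = U_4 = U_5 = 200 \\
&p_l = \frac{U_l}{\sum_k U_k} = \frac{200}{1000} = 0.2 \quad \text{for every } l
\end{aligned}
```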

Then pass the sampler as the `sampler` argument when creating a `DataLoader`:

```python
# `my_dataset` is your own torch.utils.data.Dataset.
# Keep shuffle off: DataLoader raises a ValueError if shuffle=True is combined with a sampler.
data_loader = DataLoader(my_dataset, batch_size=2, shuffle=False, sampler=unimax_sampler)
```
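This README does not document how `UnimaxSampler` turns the language distribution into dataset indices. The hypothetical sketch below shows one plausible approach. The class name `LanguageWeightedSampler`, its arguments, and its `set_epoch` method are illustrative assumptions, not this repository's API. The sketch draws a language with `random.choices`, weighted by its UniMax probability. It then picks an example uniformly from that language's indices.

```python
import random
from typing import Iterator, List, Sequence


class LanguageWeightedSampler:
    """Hypothetical sampler: pick a language by probability, then an index in it.

    Illustration only; not the repository's UnimaxSampler.
    """

    def __init__(self, language_indices: Sequence[Sequence[int]],
                 probabilities: Sequence[float], num_samples: int, seed: int = 0):
        if len(language_indices) != len(probabilities):
            raise ValueError("need exactly one probability per language")
        # language_indices[l] lists the dataset indices belonging to language l.
        self.language_indices: List[List[int]] = [list(ix) for ix in language_indices]
        self.probabilities = list(probabilities)
        self.num_samples = num_samples
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Changing the epoch changes the random stream, so each epoch differs.
        self.epoch = epoch

    def __iter__(self) -> Iterator[int]:
        rng = random.Random(self.seed + self.epoch)
        # Choose a language for every sample, weighted by its probability.
        langs = rng.choices(range(len(self.language_indices)),
                            weights=self.probabilities, k=self.num_samples)
        for lang in langs:
            # Then choose an example uniformly within that language.
            yield rng.choice(self.language_indices[lang])

    def __len__(self) -> int:
        return self.num_samples
```

`DataLoader` accepts any iterable of indices as its `sampler`, so an object like this can be passed in the same way as `unimax_sampler`. The random generator is re-seeded in `__iter__`, so one epoch's order is reproducible. Calling `set_epoch` varies the order between epochs.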
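
For DDP, choose the sampler based on whether a process group has been initialized:

```python
import torch.distributed as dist

# Assumes DistributedUnimaxSampler is importable from unimax_sampler.py.
from unimax_sampler import DistributedUnimaxSampler, UnimaxSampler


def build_sampler():
    if dist.is_available() and dist.is_initialized():
        # DDP-aware variant for multi-process training
        return DistributedUnimaxSampler(...)  # constructor arguments elided
    return UnimaxSampler(...)  # constructor arguments elided
```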
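Under DDP, each process should consume its own slice of the sampled indices. PyTorch's `torch.utils.data.DistributedSampler` does this by padding the index list and striding over it by rank. The standalone function below sketches that step with shuffling off and `drop_last=False`. The name `shard_indices` is hypothetical, and the function is not the repository's wrapper.

```python
import math
from typing import List, Sequence


def shard_indices(indices: Sequence[int], num_replicas: int, rank: int) -> List[int]:
    """Return the part of `indices` that process `rank` should consume.

    Hypothetical helper mirroring DistributedSampler's pad-and-stride scheme
    (shuffle=False, drop_last=False).
    """
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must be in [0, num_replicas)")
    indices = list(indices)
    if not indices:
        return []
    num_samples = math.ceil(len(indices) / num_replicas)
    total_size = num_samples * num_replicas
    # Pad by repeating from the start so every rank gets the same count.
    padding = total_size - len(indices)
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Rank r takes every num_replicas-th index, starting at position r.
    return indices[rank:total_size:num_replicas]
```

For example, `shard_indices([10, 11, 12, 13, 14], 2, 0)` returns `[10, 12, 14]`, and rank 1 gets `[11, 13, 10]`. The repeated `10` is padding that keeps both processes at the same number of steps.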

## Note

The initial version of this code was generated by [ChatGPT (GPT-4)](https://chat.openai.com/) from the pseudocode in the [UniMax](https://arxiv.org/abs/2304.09151) paper. The code was then revised by hand to support the PyTorch Distributed Data Parallel ([DDP](https://pytorch.org/docs/stable/notes/ddp.html)) framework. The `DistributedSamplerWrapper` implementation is derived from an earlier version in the [Catalyst](https://github.com/catalyst-team/catalyst) project.

## License

This project is licensed under the MIT License.