# UniMax Language Dataset Sampler with DDP support
This repository contains an unofficial PyTorch implementation of the UNIMAX sampling algorithm from ["UniMax: Fairer and more Effective Language Sampling for Large-Scale Multilingual Pretraining" by H. W. Chung et al. (ICLR 2023)](https://arxiv.org/abs/2304.09151). UNIMAX derives a sampling distribution over languages from their character counts, a total character budget, and a maximum number of epochs per language. This is useful when training language models on datasets with an imbalanced language distribution.
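To make the three inputs concrete, here is a minimal pure-Python sketch of the budget allocation as described in the paper's pseudocode. It is not the code in `unimax_sampler.py`; the function name `unimax_distribution` and its argument names are hypothetical and chosen only for illustration.

```python
def unimax_distribution(char_counts, total_budget, max_epochs):
    """Return per-language sampling probabilities, in the input order.

    Hypothetical helper for illustration; not part of this repository's API.
    """
    num_langs = len(char_counts)
    # Visit languages from the smallest corpus to the largest.
    order = sorted(range(num_langs), key=lambda i: char_counts[i])
    remaining = float(total_budget)
    allotted = [0.0] * num_langs
    for visited, lang in enumerate(order):
        # Split what is left of the budget evenly over the unvisited languages.
        fair_share = remaining / (num_langs - visited)
        # A language may not be repeated for more than max_epochs epochs.
        cap = char_counts[lang] * max_epochs
        allotted[lang] = min(fair_share, cap)
        remaining -= allotted[lang]
    # Normalize the per-language character allotments into probabilities.
    total = sum(allotted)
    return [a / total for a in allotted]


probs = unimax_distribution([10, 100, 1000], total_budget=300, max_epochs=2)
```

In this example the smallest language is capped at 2 epochs (20 characters), and the budget it cannot use is shared equally by the two larger languages (140 characters each). With the values from the Usage section the distribution is uniform, because every language's cap is at least the even share of 200 characters.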
## Contents
1. `unimax_sampler.py`: This Python file contains the `UnimaxSampler` class, a PyTorch `Sampler` that uses the UNIMAX algorithm.
2. `test_unimax_sampler.py`: This Python file contains a unit test for the `UnimaxSampler` class to ensure its correct functionality.
## Usage
```python
from torch.utils.data import Dataset, DataLoader
from unimax_sampler import UnimaxSampler
# Define your parameters
language_character_counts = [100, 200, 300, 400, 500]
total_character_budget = 1000
num_epochs = 2
# Create the UnimaxSampler
unimax_sampler = UnimaxSampler(language_character_counts, total_character_budget, num_epochs)
```
Then pass the sampler as the `sampler` argument when creating a `DataLoader`.
```python
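# Keep shuffle disabled: DataLoader raises an error if shuffle=True is combined with a custom sampler.
# `my_dataset` is your own Dataset instance.
data_loader = DataLoader(my_dataset, batch_size=2, shuffle=False, sampler=unimax_sampler)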
```
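For DDP, choose the sampler depending on whether a process group has been initialized:
```python
import torch

# Use the distributed variant only when a process group is active
if torch.distributed.is_initialized():
    sampler = DistributedUnimaxSampler(...)
else:
    sampler = UnimaxSampler(...)
```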
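As background on why a distributed variant is needed, the sketch below shows how sampled indices can be split so that each process sees a disjoint slice. It mirrors the pad-then-stride scheme used by PyTorch's `DistributedSampler`; whether `DistributedUnimaxSampler` does exactly this is an assumption, and the helper name `shard_indices` is hypothetical.

```python
def shard_indices(indices, rank, world_size):
    """Return the subset of `indices` that one rank would consume.

    Hypothetical helper for illustration; not part of this repository's API.
    """
    indices = list(indices)
    # Pad by repeating items from the start so every rank gets the same count.
    pad = (-len(indices)) % world_size
    indices += (indices * world_size)[:pad]
    # Strided slice: rank 0 takes items 0, W, 2W, ...; rank 1 takes 1, W+1, ...
    return indices[rank::world_size]


rank0 = shard_indices(range(5), rank=0, world_size=2)  # [0, 2, 4]
rank1 = shard_indices(range(5), rank=1, world_size=2)  # [1, 3, 0]
```

Padding keeps the number of batches identical across ranks, which matters under DDP because every process must take part in each gradient synchronization.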
## Note
The initial version of this code was generated by [ChatGPT (GPT-4)](https://chat.openai.com/), based on the pseudocode provided in the [UNIMAX](https://arxiv.org/abs/2304.09151) paper. The code was then manually revised to support the PyTorch Distributed Data Parallel ([DDP](https://pytorch.org/docs/stable/notes/ddp.html)) framework. The `DistributedSamplerWrapper` implementation is derived from an earlier version found in the [Catalyst](https://github.com/catalyst-team/catalyst) project.
## License
This project is licensed under the MIT License.