# Train
## Environment
```bash
cd scripts
python -m venv venv
source venv/bin/activate
pip install -U -r requirements.in
```
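The contents of `requirements.in` are not shown here. Assuming it provides the `torch`, `safetensors` and `litgpt` packages that the later steps rely on, a small stdlib-only check could confirm the environment before training. The helper name `missing_requirements` is hypothetical.

```python
import importlib.util
import shutil


def missing_requirements(modules, executables=()):
    """Return the Python modules and CLI executables that cannot be found.

    Pass top-level module names only: find_spec raises for a dotted name
    whose parent package is missing.
    """
    missing = [m for m in modules if importlib.util.find_spec(m) is None]
    missing += [exe for exe in executables if shutil.which(exe) is None]
    return missing
```

For example, `missing_requirements(["torch", "safetensors", "litgpt"], ["litgpt"])` should return an empty list inside the activated `venv`.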
## Tokenizer
```bash
python -B train_tokenizer.py
```
## Dataset
```bash
python -B prepare_pretrain_dataset.py
python -B prepare_contrain_dataset.py
```
## Model
### Pretraining
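```bash
# Pretrain from scratch using the settings in pretrain-model.yaml
litgpt pretrain --config ./pretrain-model.yaml
# Convert the final LitGPT checkpoint to a Hugging Face-style state dict (model.pth)
litgpt convert_from_litgpt out/pretrain/final/ out/converted_pretrain
# Place config.json next to both the LitGPT and the converted checkpoints
cp config.json out/pretrain/final/
cp config.json out/converted_pretrain/
```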
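```python
import torch
from safetensors.torch import save_file

# Load the converted weights on CPU so no GPU is needed for this step.
state_dict = torch.load('out/converted_pretrain/model.pth', map_location='cpu')
# safetensors refuses non-contiguous tensors, so make every tensor contiguous.
state_dict = {name: tensor.contiguous() for name, tensor in state_dict.items()}
save_file(state_dict, 'out/converted_pretrain/model.safetensors')
```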
### Continued Pretraining
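```bash
# Convert the pretraining checkpoint into a regular LitGPT model checkpoint
litgpt convert_pretrained_checkpoint out/pretrain/final/ out/pretrain_checkpoint/final/
cp config.json out/pretrain_checkpoint/final/
# Continue pretraining using the settings in contrain-model.yaml
litgpt pretrain --config ./contrain-model.yaml
# Convert the final continued-pretraining checkpoint to a Hugging Face-style state dict
litgpt convert_from_litgpt out/contrain/final/ out/converted_contrain
cp config.json out/converted_contrain/
```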
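By this point every output directory used in this README should exist. A minimal sketch of a sanity check, run from the `scripts` directory before the safetensors conversion, lists any expected file that is missing. It only checks the files this README itself creates (`config.json` copied by hand, `model.pth` written by `convert_from_litgpt`); the names `EXPECTED_FILES` and `missing_outputs` are hypothetical.

```python
from pathlib import Path

# Files each directory should hold after the commands above:
# config.json is copied in manually, model.pth comes from convert_from_litgpt.
EXPECTED_FILES = {
    "out/pretrain/final": ["config.json"],
    "out/converted_pretrain": ["config.json", "model.pth"],
    "out/pretrain_checkpoint/final": ["config.json"],
    "out/converted_contrain": ["config.json", "model.pth"],
}


def missing_outputs(root=".", expected=EXPECTED_FILES):
    """Return the expected file paths (relative to root) that do not exist."""
    root = Path(root)
    return [
        f"{directory}/{name}"
        for directory, names in expected.items()
        for name in names
        if not (root / directory / name).is_file()
    ]
```

An empty list from `missing_outputs()` means the conversion steps can proceed.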
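```python
import torch
from safetensors.torch import save_file

# Load the converted weights on CPU so no GPU is needed for this step.
state_dict = torch.load('out/converted_contrain/model.pth', map_location='cpu')
# safetensors refuses non-contiguous tensors, so make every tensor contiguous.
state_dict = {name: tensor.contiguous() for name, tensor in state_dict.items()}
save_file(state_dict, 'out/converted_contrain/model.safetensors')
```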
```bash
# Copy the final continued-pretraining weights into the current directory
cp out/converted_contrain/model.pth ./
cp out/converted_contrain/model.safetensors ./
```
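If a Python equivalent of the final `cp` step is preferred, a sketch that also confirms each copy is byte-identical to its source might look like this. The helper name `copy_and_verify` is hypothetical, and its defaults mirror the paths used above.

```python
import filecmp
import shutil
from pathlib import Path


def copy_and_verify(src_dir="out/converted_contrain", dst_dir=".",
                    names=("model.pth", "model.safetensors")):
    """Copy each named file from src_dir to dst_dir and check the bytes match."""
    copied = []
    for name in names:
        src = Path(src_dir) / name
        dst = Path(dst_dir) / name
        shutil.copyfile(src, dst)
        # shallow=False forces a byte-by-byte comparison instead of a stat() check.
        if not filecmp.cmp(src, dst, shallow=False):
            raise OSError(f"copy of {src} does not match the source")
        copied.append(dst)
    return copied
```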