import argparse

import torch.cuda
from datasets import load_dataset
import json
import os
import transformers
from torch.utils.data import Subset
import wandb
import numpy as np
import gc

from models.personality_clustering import PersonalityClustering
from util.bloom_trainer import BloomTrainer
from util.data import PersonaChatDataset
from util.metrics import perplexity

from petals.client.remote_model import DistributedBloomForCausalLM

"""Пример запуска
python -m scripts.train_bloom_personachat --persona-ids 6 --config scripts/config.json --prompt-path data/models/
"""

# Default path to the pre-trained persona clustering model
# (500 clusters, paraphrase-MiniLM-L6-v2 embeddings, k-means).
DEFAULT_CLUSTERING_MODEL = './data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl'
# Upper bound on validation examples per persona, to keep evaluation fast.
MAX_VAL_DATA_SIZE = 4


def load_config(path):
    """Read a JSON config file and expose its keys as attributes.

    Args:
        path: Path to a JSON file containing a flat configuration object.

    Returns:
        argparse.Namespace whose attributes mirror the JSON keys.
    """
    with open(path, 'r') as config_file:
        raw = json.load(config_file)
    return argparse.Namespace(**raw)


def main():
    """Train persona-conditioned soft prompts for a Bloom model on PersonaChat.

    For each requested persona cluster id: builds train/validation datasets,
    starts a wandb run, trains a soft prompt with BloomTrainer, evaluates
    perplexity, saves the prompt and logs the result.
    """
    args = parse_args()
    persona_clustering = PersonalityClustering()
    persona_clustering.load(args.clustering_model_path)

    config = load_config(args.config)

    tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
    tokenizer.padding_side = config.PADDING_SIDE
    tokenizer.model_max_length = config.MODEL_MAX_LENGTH

    dataset = load_dataset(config.PERSONACHAT_DATASET_NAME)
    personachat_train_dataset = PersonaChatDataset(persona_clustering,
                                                   dataset['train'],
                                                   tokenizer)
    personachat_val_dataset = PersonaChatDataset(persona_clustering,
                                                 dataset['validation'],
                                                 tokenizer)

    # `persona_id` (not `id`) so the builtin is not shadowed.
    for persona_id in args.persona_ids:
        prompt_path = os.path.join(args.prompt_path, f'{persona_id}_persona_prompt_embedding.pt')
        train_dataset = personachat_train_dataset[persona_id]
        val_dataset = personachat_val_dataset[persona_id]
        # Fall back to the train split when the validation split is too small
        # to be meaningful; record that the validation is no longer "honest".
        honest_validation = True
        if len(val_dataset) < 4:  # NOTE(review): magic minimum — presumably "too few for validation"; confirm
            val_dataset = personachat_train_dataset[persona_id]
            honest_validation = False
        # Cap the validation size to speed up evaluation.
        if len(val_dataset) > MAX_VAL_DATA_SIZE:
            subset_indexes = np.random.choice(len(val_dataset), MAX_VAL_DATA_SIZE, replace=False)
            val_dataset = Subset(val_dataset, subset_indexes)

        wandb_run = wandb.init(
            project=args.wandb_project,
            config={
                'lr': config.LR,
                'batch_size': config.BATCH_SIZE,
                'persona_id': persona_id,
                'device': config.DEVICE,
                'model_name': config.MODEL_NAME,
                'n_epoch': config.N_EPOCH,
                'honest_validation': honest_validation
            },
            name=f'id{persona_id}',
            reinit=True
        )
        # Build the model once, adding `initial_peers` only when configured,
        # instead of duplicating the whole constructor call.
        model_kwargs = {
            'pre_seq_len': config.NUM_PREFIX_TOKENS,
            'tuning_mode': config.TUNING_MODE,
        }
        if config.INITIAL_PEERS:
            model_kwargs['initial_peers'] = config.INITIAL_PEERS
        model = DistributedBloomForCausalLM.from_pretrained(
            config.MODEL_NAME,
            **model_kwargs
        ).to(config.DEVICE)

        trainer = BloomTrainer(model, config, train_dataset, val_dataset, wandb_run, prompt_path)
        trainer.train()
        eval_perplexity = trainer.evaluate(perplexity)
        trainer.save_model(prompt_path)
        wandb_run.log({'perplexity': eval_perplexity, 'model_path': prompt_path})
        # Close the run explicitly rather than relying on `reinit`/interpreter exit.
        wandb_run.finish()

        # Free GPU memory before the next persona's model is created.
        del model
        gc.collect()
        torch.cuda.empty_cache()


def parse_args(args=None):
    """Parse command-line arguments for the training script.

    Args:
        args: Optional list of argument strings (defaults to sys.argv[1:]).

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(add_help=True,
                                     description="bloom training script")
    # Required: main() iterates over these / opens these paths, and would
    # otherwise fail later with an obscure TypeError on None.
    parser.add_argument('--persona-ids', type=int, nargs='+', required=True,
                        help='Ids of persona')
    # The single-dash alias is kept for backward compatibility.
    parser.add_argument('-clustering-model-path', '--clustering-model-path', type=str,
                        default=DEFAULT_CLUSTERING_MODEL,
                        help='Path to clustering model')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to training config file')
    parser.add_argument('--prompt-path', type=str, required=True,
                        help='Path to dir with trained soft prompts')
    parser.add_argument('--wandb-project', type=str, default='test_bloom_personachat_176b_v3')
    args = parser.parse_args(args)
    return args


# Run training only when executed as a script, not on import.
if __name__ == '__main__':
    main()