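"""Training utilities for distributed fine-tuning with HuggingFace Transformers:
logging setup, an epoch-aware DistributedSampler, custom Trainer callbacks,
NaN-debugging forward hooks, and sequence-parallel process-group helpers."""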
import logging
import math
import os

import torch
import torch.distributed as dist
import transformers

# Global state for sequence parallelism.
_SEQUENCE_PARALLEL_GROUP = None
_SEQUENCE_PARALLEL_SIZE = 1

def init_logger(fpath='', local_rank=0):
    if transformers.trainer_utils.is_main_process(local_rank):
        if fpath:
            if os.path.dirname(fpath):
                os.makedirs(os.path.dirname(fpath), exist_ok=True)
            file_handler = logging.FileHandler(fpath, mode='a')  # also write logs to a file
            transformers.logging.add_handler(file_handler)
        transformers.logging.set_verbosity_info()
    else:
        transformers.logging.set_verbosity_error()  # non-main processes only log errors
    transformers.logging.enable_explicit_format()
    return transformers.logging.get_logger()
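# Usage sketch (assumes the usual distributed launcher, e.g. `torchrun`, which
# sets LOCAL_RANK; the log file path is illustrative):
#
#   logger = init_logger(fpath='logs/train.log',
#                        local_rank=int(os.environ.get('LOCAL_RANK', 0)))
#   logger.info('logger initialized')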

class DistributedSampler(torch.utils.data.distributed.DistributedSampler):
    def set_epoch(self, epoch):
        # Override set_epoch so that, after the dataset is refreshed each epoch,
        # the sampler recomputes its length for the updated dataset.
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        super().set_epoch(epoch)

def add_custom_callback(trainer, logger):
    if 'PrinterCallback' in trainer.callback_handler.callback_list:
        trainer.pop_callback(transformers.PrinterCallback)
    trainer.add_callback(LogCallback(logger))
    logger.info('Add custom LogCallback')
    trainer.add_callback(DatasetUpdateCallback(trainer))
    logger.info('Add custom DatasetUpdateCallback')
    trainer.add_callback(SaveDiskCallback())
    logger.info('Add custom SaveDiskCallback')
    logger.info(f"trainer's callbacks: {trainer.callback_handler.callback_list}")


class LogCallback(transformers.TrainerCallback):
    """
    A bare :class:`~transformers.TrainerCallback` that just prints with logger.
    """
    def __init__(self, logger, exclude=('total_flos', 'epoch')):
        self.logger = logger
        self.exclude = exclude

    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.is_world_process_zero:
            self.logger.info(''.join([
                f"[global_steps={state.global_step}]",
                f"[epochs={logs['epoch']}]",
                ','.join(f'{k}={v}' for k, v in logs.items()
                         if k not in self.exclude)
            ]))


class DatasetUpdateCallback(transformers.TrainerCallback):
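    """Refresh the training dataset at each epoch boundary.

    The training dataset is expected to expose an ``update(epoch)`` method; the
    dataloader's sampler (e.g. the ``DistributedSampler`` defined above) then has
    ``set_epoch`` called so its per-rank sample count matches the refreshed dataset.
    """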
    def __init__(self, trainer):
        self.trainer = trainer

    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        self.trainer.train_dataset.update(int(state.epoch))
        train_dataloader.sampler.set_epoch(int(state.epoch))


class SaveDiskCallback(transformers.TrainerCallback):
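    """Prune DeepSpeed state from out-of-date checkpoints to save disk space.

    After each save, ``global_step*`` directories and ``*.pth`` files inside
    ``checkpoint-*`` folders other than the newest one are removed (on each
    node's local rank 0 only).
    """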
    def on_save(self, args, state, control, **kwargs):
        if args.local_rank != 0:
            return

        for ckpt in os.listdir(args.output_dir):
            # remove out-of-date deepspeed checkpoints
            if ckpt.startswith('checkpoint-') and not ckpt.endswith(f'-{state.global_step}'):
                for pattern in ['global_step*', '*.pth']:
                    os.system("rm -rf " + os.path.join(args.output_dir, ckpt, pattern))

    def on_train_end(self, args, state, control, **kwargs):
        # NOTE: the `and False` below disables this cleanup, so the final
        # checkpoint's DeepSpeed state is kept on disk after training ends.
        if state.is_local_process_zero and False:
            for pattern in ['global_step*', '*.pth']:
                os.system("rm -rf " + os.path.join(args.output_dir, "checkpoint-*", pattern))


def register_nan_hook(model):
    torch.autograd.set_detect_anomaly(True)

    def add_module_name(module):
        for name, sub_module in module.named_modules():
            sub_module.name = name

    def add_check_nan_hook(module):
        def check_nan(module, inputs, outputs):
            # Forward hooks may receive a single tensor as `outputs`; normalize to a tuple.
            if not isinstance(outputs, (tuple, list)):
                outputs = (outputs,)
            any_nan = False
            for i, tensor in enumerate(inputs):
                if isinstance(tensor, torch.Tensor) and tensor.isnan().any():
                    print(f'module {module.name} contains nan in its {i}th input.')
                    any_nan = True
            for i, tensor in enumerate(outputs):
                if isinstance(tensor, torch.Tensor) and tensor.isnan().any():
                    print(f'module {module.name} contains nan in its {i}th output.')
                    any_nan = True
            if any_nan and torch.distributed.get_rank() == 0:
                # Dump the offending module's weights and tensors for offline inspection.
                torch.save({
                    'state_dict': module.state_dict(),
                    'inputs': inputs,
                    'outputs': outputs,
                    'type': module.__class__.__name__
                }, module.name + '.pth')
                # from ipdb import set_trace; set_trace()

        # Register once; the hook runs after every forward pass of this module.
        module.register_forward_hook(check_nan)
    
    model.apply(add_module_name)
    model.apply(add_check_nan_hook)
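# Usage sketch (assumes torch.distributed is initialized, since the hook only
# dumps tensors on global rank 0; the model/batch names are illustrative):
#
#   register_nan_hook(model)   # attach hooks once, before the training loop
#   outputs = model(**batch)   # any module seeing NaNs prints and saves `<name>.pth`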


def initialize_seq_parallel(
    sequence_parallel_size,
):
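    """Partition the world into contiguous groups of ``sequence_parallel_size`` ranks.

    Ranks [0, sp), [sp, 2*sp), ... each form one sequence-parallel group; the
    group containing the caller rank is cached in ``_SEQUENCE_PARALLEL_GROUP``.
    No-op when ``sequence_parallel_size <= 1``.
    """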
    if sequence_parallel_size <= 1:
        return None
    num_sequence_parallel_groups: int = dist.get_world_size() // sequence_parallel_size
    global _SEQUENCE_PARALLEL_GROUP
    global _SEQUENCE_PARALLEL_SIZE
    _SEQUENCE_PARALLEL_SIZE = sequence_parallel_size
    for i in range(num_sequence_parallel_groups):
        ranks = range(i * sequence_parallel_size,
                      (i + 1) * sequence_parallel_size)
        group = torch.distributed.new_group(ranks)
        if dist.get_rank() in ranks:
            _SEQUENCE_PARALLEL_GROUP = group

def get_sequence_parallel_group():
    """Get the sequence parallel group the caller rank belongs to."""
    return _SEQUENCE_PARALLEL_GROUP

def get_sequence_parallel_size():
    return _SEQUENCE_PARALLEL_SIZE

def get_sequence_parallel_rank():
    return torch.distributed.get_rank(group=get_sequence_parallel_group())

# Patch DeepSpeed's sequence-parallel world size so the optimizer averages gradients correctly.
from deepspeed.utils import groups
groups._get_sequence_parallel_world_size = get_sequence_parallel_size
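# Usage sketch (assumes torch.distributed is already initialized, e.g. via
# `dist.init_process_group(backend='nccl')`, and that the world size is
# divisible by the chosen sequence-parallel size):
#
#   initialize_seq_parallel(sequence_parallel_size=2)
#   sp_group = get_sequence_parallel_group()   # pass as `group=` to collectives
#   sp_rank = get_sequence_parallel_rank()     # this rank's position in its group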