import logging
import math
import os

import torch
import torch.distributed as dist
import transformers

# Module-level state for the (optional) sequence-parallel process group.
_SEQUENCE_PARALLEL_GROUP = None
_SEQUENCE_PARALLEL_SIZE = 1
def init_logger(fpath='', local_rank=0):
    """Configure the transformers logger; only the main process logs at INFO level."""
    if transformers.trainer_utils.is_main_process(local_rank):
        if fpath:
            if os.path.dirname(fpath):
                os.makedirs(os.path.dirname(fpath), exist_ok=True)
            # Mirror the logs into a file in addition to stderr.
            file_handler = logging.FileHandler(fpath, mode='a')
            transformers.logging.add_handler(file_handler)
        transformers.logging.set_verbosity_info()
    else:
        transformers.logging.set_verbosity_error()
    transformers.logging.enable_explicit_format()
    return transformers.logging.get_logger()
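# Usage sketch (the log path is illustrative; LOCAL_RANK is the environment
# variable set by launchers such as torchrun):
#
#     local_rank = int(os.environ.get('LOCAL_RANK', 0))
#     logger = init_logger(fpath='logs/train.log', local_rank=local_rank)
#     logger.info('logger initialized')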
class DistributedSampler(torch.utils.data.distributed.DistributedSampler):
    """DistributedSampler that re-reads the dataset length on every epoch.

    The dataset used with this sampler may change its size between epochs
    (see ``DatasetUpdateCallback``), so ``num_samples`` and ``total_size``
    are recomputed before delegating to the parent ``set_epoch``.
    """

    def set_epoch(self, epoch):
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Drop the tail so the data splits evenly across replicas.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        super().set_epoch(epoch)
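# Worked example for the recomputation above, assuming drop_last=True and
# num_replicas=8: if the dataset grows from 1000 to 1203 samples after an
# epoch-level update, num_samples goes from 1000 / 8 = 125 to
# ceil((1203 - 8) / 8) = 150 and total_size to 1200, i.e. the 3-sample tail
# is dropped instead of being padded.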
def add_custom_callback(trainer, logger):
    # Replace the default PrinterCallback with logger-based logging.
    if 'PrinterCallback' in trainer.callback_handler.callback_list:
        trainer.pop_callback(transformers.PrinterCallback)
    trainer.add_callback(LogCallback(logger))
    logger.info('Added custom LogCallback')
    trainer.add_callback(DatasetUpdateCallback(trainer))
    logger.info('Added custom DatasetUpdateCallback')
    trainer.add_callback(SaveDiskCallback())
    logger.info('Added custom SaveDiskCallback')
    logger.info(f"trainer's callbacks: {trainer.callback_handler.callback_list}")
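# Usage sketch (the trainer construction below is illustrative, not part of
# this module):
#
#     trainer = transformers.Trainer(model=model, args=training_args,
#                                    train_dataset=train_dataset)
#     logger = init_logger('logs/train.log', local_rank=training_args.local_rank)
#     add_custom_callback(trainer, logger)
#     trainer.train()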
class LogCallback(transformers.TrainerCallback):
    """A bare :class:`~transformers.TrainerCallback` that only writes the logs with the given logger."""

    def __init__(self, logger, exclude=('total_flos', 'epoch')):
        self.logger = logger
        self.exclude = exclude

    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.is_world_process_zero:
            self.logger.info(''.join([
                f"[global_steps={state.global_step}]",
                f"[epochs={logs.get('epoch')}]",
                ','.join(f'{k}={v}' for k, v in logs.items()
                         if k not in self.exclude),
            ]))
class DatasetUpdateCallback(transformers.TrainerCallback):
    """Refreshes the training dataset at the start of every epoch.

    The train dataset is expected to expose an ``update(epoch)`` method; the
    sampler's ``set_epoch`` is then called so that the possibly changed
    dataset length is picked up (see the custom ``DistributedSampler`` above).
    """

    def __init__(self, trainer):
        self.trainer = trainer

    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        self.trainer.train_dataset.update(int(state.epoch))
        train_dataloader.sampler.set_epoch(int(state.epoch))
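# Sketch of a dataset that fits DatasetUpdateCallback (hypothetical class, not
# part of this module; assumes ``import random``). It only needs an
# ``update(epoch)`` method that may resample or regrow ``self.samples``
# between epochs, which is why the sampler above re-reads ``len(dataset)``.
#
#     class ResampledDataset(torch.utils.data.Dataset):
#         def __init__(self, pool):
#             self.pool = list(pool)
#             self.samples = list(pool)
#
#         def update(self, epoch):
#             rng = random.Random(epoch)
#             self.samples = rng.sample(self.pool, k=len(self.pool) // 2)
#
#         def __len__(self):
#             return len(self.samples)
#
#         def __getitem__(self, idx):
#             return self.samples[idx]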
class SaveDiskCallback(transformers.TrainerCallback):
    """Frees disk space by deleting optimizer states of older checkpoints."""

    def on_save(self, args, state, control, **kwargs):
        if args.local_rank != 0:
            return
        # For every checkpoint except the one just saved, drop the DeepSpeed
        # optimizer states ('global_step*') and any dumped '*.pth' files.
        for ckpt in os.listdir(args.output_dir):
            if ckpt.startswith('checkpoint-') and not ckpt.endswith(f'-{state.global_step}'):
                for pattern in ['global_step*', '*.pth']:
                    os.system("rm -rf " + os.path.join(args.output_dir, ckpt, pattern))

    def on_train_end(self, args, state, control, **kwargs):
        # Currently disabled ('and False'): would also clean the final checkpoints.
        if state.is_local_process_zero and False:
            for pattern in ['global_step*', '*.pth']:
                os.system("rm -rf " + os.path.join(args.output_dir, "checkpoint-*", pattern))
def register_nan_hook(model):
    """Attach forward hooks that report (and dump) NaNs appearing in any sub-module."""
    torch.autograd.set_detect_anomaly(True)

    def add_module_name(module):
        # Remember each sub-module's qualified name so the hook can report it.
        for name, sub_module in module.named_modules():
            sub_module.name = name

    def add_check_nan_hook(module):
        def check_nan(module, inputs, outputs):
            # Forward hooks may receive a single tensor instead of a tuple.
            if not isinstance(outputs, tuple):
                outputs = (outputs,)
            any_nan = False
            for i, tensor in enumerate(inputs):
                if isinstance(tensor, torch.Tensor) and tensor.isnan().any():
                    print(f'module {module.name} contains nan in its {i}th input.')
                    any_nan = True
            for i, tensor in enumerate(outputs):
                if isinstance(tensor, torch.Tensor) and tensor.isnan().any():
                    print(f'module {module.name} contains nan in its {i}th output.')
                    any_nan = True
            if any_nan and dist.get_rank() == 0:
                # Dump the offending module so the failure can be inspected offline.
                torch.save({
                    'state_dict': module.state_dict(),
                    'inputs': inputs,
                    'outputs': outputs,
                    'type': module.__class__.__name__,
                }, module.name + '.pth')

        module.register_forward_hook(check_nan)

    model.apply(add_module_name)
    model.apply(add_check_nan_hook)
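# Usage sketch (call once after the model is built and before training; the
# checkpoint name is illustrative):
#
#     model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
#     register_nan_hook(model)
#
# Anomaly detection plus per-module hooks add noticeable overhead, so this is
# meant for debugging runs only.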
def initialize_seq_parallel(sequence_parallel_size):
    """Split the world into consecutive groups of ``sequence_parallel_size`` ranks."""
    if sequence_parallel_size <= 1:
        return None
    num_sequence_parallel_groups: int = dist.get_world_size() // sequence_parallel_size
    global _SEQUENCE_PARALLEL_GROUP
    global _SEQUENCE_PARALLEL_SIZE
    _SEQUENCE_PARALLEL_SIZE = sequence_parallel_size
    for i in range(num_sequence_parallel_groups):
        ranks = range(i * sequence_parallel_size,
                      (i + 1) * sequence_parallel_size)
        group = dist.new_group(ranks)
        if dist.get_rank() in ranks:
            # Remember the group this rank belongs to.
            _SEQUENCE_PARALLEL_GROUP = group


def get_sequence_parallel_group():
    """Get the sequence parallel group the caller rank belongs to."""
    return _SEQUENCE_PARALLEL_GROUP


def get_sequence_parallel_size():
    """Get the number of ranks in a sequence parallel group."""
    return _SEQUENCE_PARALLEL_SIZE


def get_sequence_parallel_rank():
    """Get the caller's rank within its sequence parallel group."""
    return dist.get_rank(group=get_sequence_parallel_group())
# Monkey-patch DeepSpeed so it reports this module's sequence-parallel world size.
from deepspeed.utils import groups

groups._get_sequence_parallel_world_size = get_sequence_parallel_size
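# Usage sketch (assumes torch.distributed is already initialized, e.g. by
# deepspeed.init_distributed() or torchrun; a group size of 2 is just an example):
#
#     initialize_seq_parallel(sequence_parallel_size=2)
#     group = get_sequence_parallel_group()
#     print(get_sequence_parallel_rank(), get_sequence_parallel_size())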