HTML-Pruner-Phi-3.8B / seq_para_utils.py
zstanjj's picture
Upload 13 files
3f61ae8 verified
raw
history blame
6.75 kB
import os
import torch
import logging
import transformers
import torch.distributed as dist
import torch
import math
# Module-level sequence-parallel state, written by initialize_seq_parallel()
# and read through the get_sequence_parallel_*() accessors below.
# Process group this rank belongs to; stays None until a group is created.
_SEQUENCE_PARALLEL_GROUP = None
# Number of ranks per sequence-parallel group; 1 until initialised.
_SEQUENCE_PARALLEL_SIZE = 1
def init_logger(fpath='', local_rank=0):
    """Configure the ``transformers`` logging facility and return its logger.

    On the main process (as decided by
    ``transformers.trainer_utils.is_main_process``) verbosity is set to INFO
    and, when ``fpath`` is non-empty, an appending file handler is attached
    (the parent directory is created if the path has one). Every other
    process is set to ERROR verbosity. The explicit log format is enabled in
    both cases.

    Args:
        fpath: Path of the log file; an empty string disables file logging.
        local_rank: Rank passed to ``is_main_process``.

    Returns:
        The logger returned by ``transformers.logging.get_logger()``.
    """
    hf_logging = transformers.logging

    if not transformers.trainer_utils.is_main_process(local_rank):
        hf_logging.set_verbosity_error()  # reduce
    else:
        if fpath:
            log_dir = os.path.dirname(fpath)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            hf_logging.add_handler(logging.FileHandler(fpath, mode='a'))  # to file
        hf_logging.set_verbosity_info()

    hf_logging.enable_explicit_format()
    return hf_logging.get_logger()
class DistributedSampler(torch.utils.data.distributed.DistributedSampler):
    """Distributed sampler that re-derives its sizes at every ``set_epoch``.

    ``num_samples`` and ``total_size`` are recomputed from the current
    ``len(self.dataset)`` each time ``set_epoch`` is called, so a dataset
    whose length changes between epochs is sampled with up-to-date sizes.
    """

    def set_epoch(self, epoch):
        """Refresh ``num_samples``/``total_size`` and then set the epoch.

        Args:
            epoch: Epoch number forwarded to the parent ``set_epoch``.
        """
        # Overridden so the sampler is refreshed after the dataset has been
        # updated for the new epoch.
        dataset_size = len(self.dataset)  # type: ignore[arg-type]
        replicas = self.num_replicas

        if self.drop_last and dataset_size % replicas != 0:
            # Uneven split with drop_last: shrink to the nearest length that
            # divides evenly so every rank receives the same amount of data.
            usable = dataset_size - replicas
        else:
            # Evenly divisible (or padding allowed): nothing is dropped.
            usable = dataset_size

        self.num_samples = math.ceil(usable / replicas)
        self.total_size = self.num_samples * replicas
        super().set_epoch(epoch)
def add_custom_callback(trainer, logger):
    """Swap the trainer's ``PrinterCallback`` for this module's callbacks.

    If ``PrinterCallback`` appears in the trainer's callback list it is
    popped. ``LogCallback``, ``DatasetUpdateCallback`` and
    ``SaveDiskCallback`` are then registered in that order, each followed by
    an info message, and finally the resulting callback list is logged.

    Args:
        trainer: Trainer exposing ``callback_handler``, ``pop_callback`` and
            ``add_callback``.
        logger: Logger used for the info messages (also given to
            ``LogCallback``).
    """
    if 'PrinterCallback' in trainer.callback_handler.callback_list:
        trainer.pop_callback(transformers.PrinterCallback)

    # Factories rather than instances: each callback is constructed right
    # before it is registered, one at a time.
    registrations = (
        (lambda: LogCallback(logger), 'Add custom LogCallback'),
        (lambda: DatasetUpdateCallback(trainer), 'Add custom DatasetUpdateCallback'),
        (lambda: SaveDiskCallback(), 'Add custom SaveDiskCallback'),
    )
    for build_callback, message in registrations:
        trainer.add_callback(build_callback())
        logger.info(message)

    logger.info(f"trainer's callbacks: {trainer.callback_handler.callback_list}")
class LogCallback(transformers.TrainerCallback):
    """
    A bare :class:`~transformers.TrainerCallback` that just prints with logger.

    Each log event is written as one line of the form
    ``[global_steps=N][epochs=E]k1=v1,k2=v2`` where keys listed in
    ``exclude`` are left out of the trailing ``k=v`` list.
    """

    def __init__(self, logger, exclude=('total_flos', 'epoch')):
        """Store the target logger and the keys to omit from the ``k=v`` list.

        Args:
            logger: Object with an ``info`` method that receives each line.
            exclude: Keys of ``logs`` that are not printed as ``k=v`` pairs.
        """
        self.logger = logger
        self.exclude = exclude

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Emit one formatted line per log event on the world-zero process.

        Args:
            args: Training arguments (unused).
            state: Trainer state; ``is_world_process_zero`` and
                ``global_step`` are read.
            control: Trainer control object (unused).
            logs: Mapping of logged values. May be ``None`` or lack the
                ``'epoch'`` key, in which case the epoch is printed as
                ``None`` instead of raising.
            **kwargs: Ignored.
        """
        if not state.is_world_process_zero:
            return

        # `logs` defaults to None and 'epoch' is not guaranteed to be present,
        # so neither may be subscripted unconditionally.
        logs = logs or {}
        metrics = ','.join(
            f'{k}={v}' for k, v in logs.items() if k not in self.exclude
        )
        self.logger.info(''.join([
            f"[global_steps={state.global_step}]",
            f"[epochs={logs.get('epoch')}]",
            metrics,
        ]))
class DatasetUpdateCallback(transformers.TrainerCallback):
    """Callback that refreshes the training dataset and sampler every epoch."""

    def __init__(self, trainer):
        """Keep a reference to the trainer whose ``train_dataset`` is updated.

        Args:
            trainer: Object exposing a ``train_dataset`` with an ``update``
                method.
        """
        self.trainer = trainer

    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        """Update the dataset for the new epoch, then re-seat the sampler.

        ``state.epoch`` is truncated to an integer and passed first to
        ``trainer.train_dataset.update`` and then to
        ``train_dataloader.sampler.set_epoch``.

        Args:
            args: Training arguments (unused).
            state: Trainer state; ``epoch`` is read.
            control: Trainer control object (unused).
            train_dataloader: Dataloader whose ``sampler`` has ``set_epoch``.
            **kwargs: Ignored.
        """
        epoch_index = int(state.epoch)
        dataset = self.trainer.train_dataset
        dataset.update(epoch_index)
        sampler = train_dataloader.sampler
        sampler.set_epoch(epoch_index)
class SaveDiskCallback(transformers.TrainerCallback):
    """Callback that frees disk space by pruning files from old checkpoints.

    After each save, every ``checkpoint-*`` directory in ``args.output_dir``
    other than the one for the current global step has its entries matching
    ``global_step*`` and ``*.pth`` removed. Deletion is done in-process
    (no shell), so paths containing spaces or shell metacharacters are
    handled correctly.
    """

    # Glob patterns, relative to a checkpoint directory, that get removed.
    _STALE_PATTERNS = ('global_step*', '*.pth')
    # End-of-training cleanup is switched off; set to True to enable it.
    _CLEAN_ON_TRAIN_END = False

    @staticmethod
    def _remove_matching(base_dir, relative_patterns):
        """Best-effort delete of everything under ``base_dir`` matching the patterns.

        Args:
            base_dir: Literal directory path (glob metacharacters in it are
                escaped, not expanded).
            relative_patterns: Glob patterns joined onto ``base_dir``.
        """
        import glob
        import shutil

        safe_base = glob.escape(base_dir)
        for pattern in relative_patterns:
            for path in glob.glob(os.path.join(safe_base, pattern)):
                try:
                    if os.path.isdir(path) and not os.path.islink(path):
                        shutil.rmtree(path)
                    else:
                        os.remove(path)
                except OSError as err:
                    # Cleanup must never abort training; report and move on.
                    logging.getLogger(__name__).warning(
                        'Could not remove %s: %s', path, err
                    )

    def on_save(self, args, state, control, **kwargs):
        """Prune stale checkpoint contents after a checkpoint is written.

        Args:
            args: Training arguments; ``local_rank`` and ``output_dir`` are read.
            state: Trainer state; ``global_step`` identifies the checkpoint
                to keep intact.
            control: Trainer control object (unused).
            **kwargs: Ignored.
        """
        # NOTE(review): only local rank 0 cleans up; a run whose local_rank
        # is not 0 (e.g. -1) never prunes anything - confirm this is intended.
        if args.local_rank != 0:
            return

        current_suffix = f'-{state.global_step}'
        for ckpt in os.listdir(args.output_dir):
            # remove out-of-date deepspeed checkpoints
            if ckpt.startswith('checkpoint-') and not ckpt.endswith(current_suffix):
                self._remove_matching(
                    os.path.join(args.output_dir, ckpt), self._STALE_PATTERNS
                )

    def on_train_end(self, args, state, control, **kwargs):
        """Optionally prune every checkpoint at the end of training.

        Currently a no-op because ``_CLEAN_ON_TRAIN_END`` is ``False``.

        Args:
            args: Training arguments; ``output_dir`` is read when enabled.
            state: Trainer state; ``is_local_process_zero`` is read when enabled.
            control: Trainer control object (unused).
            **kwargs: Ignored.
        """
        if self._CLEAN_ON_TRAIN_END and state.is_local_process_zero:
            self._remove_matching(
                args.output_dir,
                [os.path.join('checkpoint-*', p) for p in self._STALE_PATTERNS],
            )
def register_nan_hook(model):
torch.autograd.set_detect_anomaly(True)
def add_module_name(module):
for name, sub_module in module.named_modules():
sub_module.name = name
def add_check_nan_hook(module):
def check_nan(module, inputs, outputs):
any_nan = False
for i, tensor in enumerate(inputs):
if isinstance(tensor, torch.Tensor) and tensor.isnan().any():
print(f'module {module.name} contains nan in its {i}th input.')
any_nan = True
for i, tensor in enumerate(outputs):
if isinstance(tensor, torch.Tensor) and tensor.isnan().any():
print(f'module {module.name} contains nan in its {i}th output.')
any_nan = True
if any_nan:
if torch.distributed.get_rank() == 0:
torch.save({
'state_dict': module.state_dict(),
'inputs': inputs,
'outputs': outputs,
'type': module.__class__.__name__
}, module.name + '.pth')
# from ipdb import set_trace; set_trace()
# else:
# import time; time.sleep(10000)
module.register_forward_hook(lambda module, inputs, outputs: check_nan(module, inputs, outputs))
module.register_forward_hook(lambda module, inputs, outputs: check_nan(module, inputs, outputs))
model.apply(add_module_name)
model.apply(add_check_nan_hook)
def initialize_seq_parallel(
    sequence_parallel_size,
):
    """Create the sequence-parallel process groups and remember this rank's group.

    Consecutive global ranks are partitioned into groups of
    ``sequence_parallel_size``; the group containing the calling rank is
    stored in ``_SEQUENCE_PARALLEL_GROUP`` and the size in
    ``_SEQUENCE_PARALLEL_SIZE``. A size of 1 or less is a no-op.

    Args:
        sequence_parallel_size: Number of ranks per sequence-parallel group.

    Returns:
        None.

    Raises:
        ValueError: If the world size is not a multiple of
            ``sequence_parallel_size`` (the trailing ranks would otherwise be
            silently left without a group).
    """
    global _SEQUENCE_PARALLEL_GROUP
    global _SEQUENCE_PARALLEL_SIZE

    if sequence_parallel_size <= 1:
        return None

    world_size = dist.get_world_size()
    if world_size % sequence_parallel_size != 0:
        raise ValueError(
            f'world size ({world_size}) must be divisible by '
            f'sequence_parallel_size ({sequence_parallel_size})'
        )

    _SEQUENCE_PARALLEL_SIZE = sequence_parallel_size
    my_rank = dist.get_rank()
    for first_rank in range(0, world_size, sequence_parallel_size):
        ranks = list(range(first_rank, first_rank + sequence_parallel_size))
        # Every process creates every group, not only the one it belongs to
        # (see torch.distributed.new_group).
        group = dist.new_group(ranks)
        if my_rank in ranks:
            _SEQUENCE_PARALLEL_GROUP = group
def get_sequence_parallel_group():
    """Get the sequence parallel group the caller rank belongs to.

    Returns:
        The value of the module-level ``_SEQUENCE_PARALLEL_GROUP``; this is
        ``None`` unless ``initialize_seq_parallel`` has assigned a group.
    """
    return _SEQUENCE_PARALLEL_GROUP
def get_sequence_parallel_size():
    """Get the configured sequence-parallel group size.

    Returns:
        The value of the module-level ``_SEQUENCE_PARALLEL_SIZE``; this is
        ``1`` unless ``initialize_seq_parallel`` has set a larger size.
    """
    return _SEQUENCE_PARALLEL_SIZE
def get_sequence_parallel_rank():
    """Get the caller's rank as reported for its sequence-parallel group.

    Returns:
        The result of ``torch.distributed.get_rank`` for the group returned
        by ``get_sequence_parallel_group()``.
    """
    # NOTE(review): the group is None until initialize_seq_parallel() has
    # run; what get_rank(group=None) reports is torch's behaviour - confirm
    # it is what callers expect.
    sp_group = get_sequence_parallel_group()
    return dist.get_rank(group=sp_group)
# Set the sequence-parallel parameter so that the optimizer averages
# correctly: DeepSpeed's `groups._get_sequence_parallel_world_size` is
# replaced with this module's `get_sequence_parallel_size`.
from deepspeed.utils import groups
groups._get_sequence_parallel_world_size = get_sequence_parallel_size