"""Input/output checkpointing.""" |
|
|
|
import json
import math
import os
import re
import shutil
import time
import random
import sys
import numpy as np

try:
    import boto3
except ModuleNotFoundError:
    print(
        "For s3 checkpointing, please install boto3 either using requirements/requirements-s3.txt or https://github.com/boto/boto3"
    )
try:
    import hf_transfer
except ModuleNotFoundError:
    print(
        "For s3 checkpointing, please install hf_transfer either using requirements/requirements-s3.txt or https://github.com/huggingface/hf_transfer"
    )
import torch
from glob import glob

from megatron import mpu
from megatron import print_rank_0
from megatron.utils import natural_sort
from megatron.text_generation_utils import get_batch, forward_model
from pathlib import Path
from pprint import pformat
|
|
def check_checkpoint_args(neox_args, checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
    arguments and those retrieved from the checkpoint."""

    assert isinstance(
        checkpoint_args, dict
    ), "args stored in checkpoint must be a dict"
    for checkpoint_arg_name, checkpoint_arg_value in checkpoint_args.items():
        args_value = getattr(neox_args, checkpoint_arg_name)
        error_message = "{} value from checkpoint ({}) is not equal to the currently set argument value ({}).".format(
            checkpoint_arg_name, checkpoint_arg_value, args_value
        )
        assert checkpoint_arg_value == args_value, error_message
|
|
def do_forward_pass(neox_args, model, inference=False):

    # set to eval mode, remembering whether the model was training before
    model_was_in_train = model.training
    model.eval()

    # always forward a full dummy batch of sequential token ids
    context_tokens_tensor = (
        torch.arange(neox_args.seq_length + 1)
        .repeat((neox_args.train_micro_batch_size_per_gpu, 1))
        .cuda()
    )

    # forward
    if inference:
        tokens, attention_mask, position_ids = get_batch(
            neox_args, context_tokens_tensor[:, : neox_args.seq_length]
        )
        model_inputs = (
            tokens,
            position_ids,
            attention_mask,
            torch.Tensor(),
        )
        logits, _ = forward_model(neox_args, model, model_inputs)
    elif neox_args.is_pipe_parallel:
        data_iterator = iter([{"text": context_tokens_tensor}])
        _, logits = model.eval_batch(data_iter=data_iterator, return_logits=True)
    else:
        tokens, attention_mask, position_ids = get_batch(
            neox_args, context_tokens_tensor[:, : neox_args.seq_length]
        )
        logits = model((tokens, position_ids, attention_mask))

    # reset to train mode, if the model was training before
    if model_was_in_train:
        model.train()

    if logits is not None:
        # keep only the first batch item; all rows are identical dummy sequences
        logits = logits.detach().cpu()[0]

    return logits
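
# Note: with pipeline parallelism, logits are only produced on the final pipeline
# stage, so `do_forward_pass` may return None on other ranks; `check_forward_pass`
# below guards against that before comparing tensors.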
|
|
def check_forward_pass(neox_args, model, checkpoint_logits, inference):
    # do a forward pass with the loaded checkpoint
    logits = do_forward_pass(neox_args=neox_args, model=model, inference=inference)

    # compare (either side may be None on ranks without a final pipeline stage)
    if logits is not None and checkpoint_logits is not None:
        if not (logits == checkpoint_logits).all().item():
            if mpu.get_data_parallel_rank() == 0:
                print(
                    " > WARNING: validate_checkpoint_forward() forward after load of checkpoint does not yield exactly the same result"
                )
            assert (
                torch.isclose(logits, checkpoint_logits).all().item()
            ), "validate_checkpoint_forward() forward after load of checkpoint does not yield a close result"
|
|
def ensure_directory_exists(filename):
    """Build filename's path if it does not already exist."""
    dirname = os.path.dirname(filename)
    if not os.path.exists(dirname):
        os.makedirs(dirname)
|
|
def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
    """A unified checkpoint name."""
    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    return os.path.join(
        checkpoints_path,
        directory,
        "mp_rank_{:02d}".format(
            mpu.get_model_parallel_rank() if mp_rank is None else mp_rank
        ),
        "model_optim_rng.pt",
    )
|
|
def get_checkpoint_tag(iteration: int) -> str:
    return f"global_step{iteration}"
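
# Illustrative layout: for iteration 1000, DeepSpeed checkpoints saved via
# `save_ds_checkpoint` live under `<save>/global_step1000/`, while
# `get_checkpoint_name` builds the legacy Megatron-style path
# `<checkpoints_path>/iter_0001000/mp_rank_00/model_optim_rng.pt`
# (mp_rank 00 shown as an example).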
|
|
def delete_old_checkpoints(save_dir, n_to_keep):
    if torch.distributed.get_rank() == 0:
        ckpt_dir_regex = r"global_step[\d]*"
        if save_dir.endswith("/"):
            save_dir = save_dir.strip("/")
        all_ckpts = natural_sort(
            [
                i
                for i in glob(f"{save_dir}/*")
                if os.path.isdir(i) and re.search(ckpt_dir_regex, i)
            ]
        )
        n_to_delete = len(all_ckpts) - n_to_keep
        if n_to_delete > 0:
            to_delete = all_ckpts[:n_to_delete]
            print(f"WARNING: Deleting old checkpoints: \n\t{', '.join(to_delete)}")
            for ckpt in to_delete:
                try:
                    shutil.rmtree(ckpt)
                except FileNotFoundError:
                    pass
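
# Worked example of the pruning rule (hypothetical directories): with
# n_to_keep=2 and `global_step100`, `global_step200`, `global_step300` present,
# the natural sort orders them ascending and only `global_step100` is deleted.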
|
|
def save_ds_checkpoint(iteration, model, neox_args):
    """Save a model checkpoint."""
    sd = {
        "iteration": iteration,
        "args": {
            "num_layers": neox_args.num_layers,
            "hidden_size": neox_args.hidden_size,
            "num_attention_heads": neox_args.num_attention_heads,
            "max_position_embeddings": neox_args.max_position_embeddings,
            "make_vocab_size_divisible_by": neox_args.make_vocab_size_divisible_by,
            "padded_vocab_size": neox_args.padded_vocab_size,
            "tokenizer_type": neox_args.tokenizer_type,
            "model_parallel_size": neox_args.model_parallel_size,
        },
    }

    # rng states
    if not neox_args.no_save_rng:
        sd["random_rng_state"] = random.getstate()
        sd["np_rng_state"] = np.random.get_state()
        sd["torch_rng_state"] = torch.get_rng_state()
        sd["cuda_rng_state"] = torch.cuda.get_rng_state()
        sd["rng_tracker_states"] = mpu.get_cuda_rng_tracker().get_states()

    # reference logits for optional validation on load
    if neox_args.checkpoint_validation_with_forward_pass:
        logits = do_forward_pass(neox_args=neox_args, model=model)
        sd["checkpoint_validation_logits"] = logits

    # checkpoint folder name
    tag = get_checkpoint_tag(iteration)

    # save checkpoint
    model.save_checkpoint(neox_args.save, tag=tag, client_state=sd)

    # save config files alongside the checkpoint
    if torch.distributed.get_rank() == 0 and neox_args.config_files is not None:
        configs_directory = os.path.join(neox_args.save, tag, "configs")
        os.makedirs(configs_directory, exist_ok=True)
        for config_filename, config_data in neox_args.config_files.items():
            with open(os.path.join(configs_directory, config_filename), "w") as f:
                if isinstance(config_data, str):
                    f.write(config_data)
                else:
                    json.dump(config_data, f)
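
# What ends up on disk: DeepSpeed writes the model/optimizer shards under
# `<neox_args.save>/<tag>/`, stores `sd` as client state (iteration, architecture
# args, optional rng states and validation logits), and the launch configs are
# copied to `<neox_args.save>/<tag>/configs/`.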
|
|
def multiprocessing_starmap(func, args, num_processes=None):
    """Wrapper to allow for re-usable multiprocessing pools with `spawn` context handling

    Args:
        func (Callable): Function to call
        args (Iterable): Iterable of argument tuples to pass to `func`
        num_processes (int, optional): Number of processes to spawn. Defaults to `multiprocessing.cpu_count() - 1`
    """
    import multiprocessing

    num_processes = num_processes or (multiprocessing.cpu_count() - 1)
    with multiprocessing.get_context("spawn").Pool(
        processes=num_processes
    ) as process_pool:
        process_pool.starmap(func, args)
        process_pool.terminate()
        process_pool.join()
        del process_pool
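
# Minimal usage sketch (hypothetical paths; mirrors how `upload_checkpoint` below
# feeds `_upload` with (local_path, s3_key, chunk_size) tuples):
#
#   inputs = [
#       (
#           "/data/ckpts/global_step100/mp_rank_00_model_states.pt",
#           "s3://my-bucket/ckpts/global_step100/mp_rank_00_model_states.pt",
#           104_857_600,
#       ),
#   ]
#   multiprocessing_starmap(_upload, inputs)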
|
|
def _upload(
    file_path: str,
    s3_key: str,
    chunk_size: int = 104_857_600,
    max_files: int = 64,
    parallel_failures: int = 63,
    max_retries: int = 5,
):
    """Upload a local file to S3 using the `hf_transfer` library

    Args:
        file_path (str): Local filename to upload
        s3_key (str): S3 key to upload to. E.g. `s3://bucket-name/path/to/file`
        chunk_size (int, optional): Chunk size to use for multipart upload.
            Defaults to 100MiB = 104_857_600
        max_files (int, optional): Number of open file handles, which determines
            the maximum number of parallel uploads. Defaults to 64
        parallel_failures (int, optional): Maximum number of failures of different
            chunks in parallel (cannot exceed max_files). Defaults to 63
        max_retries (int, optional): Number of retries for each chunk. Defaults to 5
    """
    s3 = boto3.client("s3")
    bucket = s3_key.split("s3://")[1].split("/")[0]
    key = s3_key.split(bucket)[1].lstrip("/")

    # 1. Initiate the multipart upload
    upload = s3.create_multipart_upload(
        ACL="bucket-owner-full-control",
        Bucket=bucket,
        Key=key,
    )
    upload_id = upload["UploadId"]

    # 2. Generate a presigned URL for each part
    file_size = os.stat(file_path).st_size
    urls = []
    nb_parts = math.ceil(file_size / chunk_size)
    for part_number in range(1, nb_parts + 1):
        params = {
            "Bucket": bucket,
            "Key": key,
            "PartNumber": part_number,
            "UploadId": upload_id,
        }
        urls.append(
            s3.generate_presigned_url(
                ClientMethod="upload_part", Params=params, ExpiresIn=86400
            )
        )

    # 3. Upload the parts in parallel with hf_transfer
    responses = hf_transfer.multipart_upload(
        file_path=file_path,
        parts_urls=urls,
        chunk_size=chunk_size,
        max_files=max_files,
        parallel_failures=parallel_failures,
        max_retries=max_retries,
    )

    # 4. Complete the multipart upload with the ETag of each part
    etag_with_parts = []
    for part_number, header in enumerate(responses):
        etag = header.get("etag")
        etag_with_parts.append({"ETag": etag, "PartNumber": part_number + 1})
    parts = {"Parts": etag_with_parts}
    s3.complete_multipart_upload(
        Bucket=bucket, Key=key, MultipartUpload=parts, UploadId=upload_id
    )
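
# Flow recap for `_upload`: (1) create a multipart upload, (2) presign one
# `upload_part` URL per chunk (valid for 24 hours), (3) let `hf_transfer` push
# the chunks in parallel, (4) complete the upload by returning each part's ETag
# in order.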
|
|
def upload_checkpoint(iteration, neox_args):
    local_checkpoint_path = os.path.join(
        os.path.abspath(neox_args.save), get_checkpoint_tag(iteration)
    )
    local_checkpoint_list = sorted(
        filter(
            lambda x: os.path.isfile(x),
            [str(p) for p in Path(local_checkpoint_path).rglob("*")],
        )
    )
    remote_checkpoint_path = os.path.join(
        neox_args.s3_path,
        os.path.basename(neox_args.save),
        get_checkpoint_tag(iteration),
    )
    remote_checkpoint_list = [
        os.path.join(
            remote_checkpoint_path,
            os.path.relpath(local_checkpoint, local_checkpoint_path),
        )
        for local_checkpoint in local_checkpoint_list
    ]
    inputs = zip(
        local_checkpoint_list,
        remote_checkpoint_list,
        [neox_args.s3_chunk_size] * len(local_checkpoint_list),
    )

    print_rank_0(
        f"[RANK {torch.distributed.get_rank()}] Uploading checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}`..."
    )
    start = time.time()
    multiprocessing_starmap(_upload, inputs)
    total_time = time.time() - start
    print_rank_0(
        f"[RANK {torch.distributed.get_rank()}] Uploaded checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}` in {total_time:.2f}s"
    )
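
# Relevant NeoX config keys used above (values are illustrative, not defaults):
#
#   "s3_path": "s3://my-bucket/neox-checkpoints",
#   "s3_chunk_size": 104857600,  # 100 MiB multipart chunks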
|
|
def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""

    if neox_args.deepspeed:
        save_ds_checkpoint(iteration, model, neox_args)
    else:
        raise ValueError("Must be using deepspeed to use neox")

    torch.distributed.barrier()
    upload_to_s3 = torch.distributed.get_rank() == 0 and neox_args.s3_path is not None
    if upload_to_s3:
        upload_checkpoint(iteration, neox_args)

    torch.distributed.barrier()
    if neox_args.keep_last_n_checkpoints is not None:
        delete_old_checkpoints(neox_args.save, neox_args.keep_last_n_checkpoints)

    torch.distributed.barrier()
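
# Ordering note: the first barrier ensures every rank has finished writing its
# shard before rank 0 walks `neox_args.save` for the S3 upload; the later
# barriers keep pruning of old checkpoints ordered after the upload and make all
# ranks resume together.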
|
|
def load_checkpoint(
    neox_args, model, optimizer, lr_scheduler, inference=False, iteration=None
):
    """Load a model checkpoint and return the iteration."""
    if neox_args.deepspeed:
        load_optim_and_scheduler = not neox_args.no_load_optim
        if neox_args.finetune:
            load_optim_and_scheduler = False
        if iteration is not None:
            tag = get_checkpoint_tag(iteration)
        else:
            tag = None
        checkpoint_name, state_dict = model.load_checkpoint(
            neox_args.load,
            load_optimizer_states=load_optim_and_scheduler,
            load_lr_scheduler_states=load_optim_and_scheduler,
            load_module_only=not load_optim_and_scheduler,
            tag=tag,
            load_module_strict=neox_args.train_impl != "rm",
        )

        if checkpoint_name is None:
            # if an iteration was requested explicitly, failing to find it is an error
            if iteration is not None:
                available_checkpoints = sorted(
                    [
                        int(i.name.replace("global_step", ""))
                        for i in Path(neox_args.load).glob("global_step*")
                    ]
                )
                raise ValueError(
                    f"Unable to load checkpoint for iteration {iteration}. \nAvailable iterations: {pformat(available_checkpoints)}"
                )
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")

            return 0
    else:
        raise ValueError("Must be using deepspeed to use neox")

    # iteration: reset to 0 when finetuning, otherwise read from the checkpoint
    if neox_args.finetune:
        iteration = 0
    else:
        if "iteration" in state_dict:
            iteration = state_dict["iteration"]
        else:
            # fall back to the alternate key used by some checkpoints
            iteration = state_dict.get("total_iters")
        if iteration is None:
            raise ValueError(
                f"Unable to load iteration from checkpoint {checkpoint_name} with keys {state_dict.keys()}, exiting"
            )

    # check arguments stored in the checkpoint against the current run
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(neox_args=neox_args, checkpoint_args=checkpoint_args)
        print_rank_0(
            " > validated currently set args with arguments in the checkpoint ..."
        )
    else:
        print_rank_0(" > could not find arguments in the checkpoint for validation...")

    # optionally validate the loaded checkpoint with a forward pass
    if neox_args.checkpoint_validation_with_forward_pass:
        if "checkpoint_validation_logits" in state_dict:
            check_forward_pass(
                neox_args=neox_args,
                model=model,
                checkpoint_logits=state_dict["checkpoint_validation_logits"],
                inference=inference,
            )
            print_rank_0(" > validated loaded checkpoint with forward pass ...")
        else:
            if mpu.get_data_parallel_rank() == 0:
                print(
                    " > WARNING: checkpoint_validation_with_forward_pass is configured but no checkpoint validation data available in checkpoint {}".format(
                        checkpoint_name
                    )
                )

    # rng states
    if not neox_args.finetune and not neox_args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
        except KeyError:
            print_rank_0(
                "Unable to load rng state from checkpoint {}. "
                "Specify --no-load-rng or --finetune to prevent "
                "attempting to load the rng state, "
                "exiting ...".format(checkpoint_name)
            )
            sys.exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print(" successfully loaded {}".format(checkpoint_name))

    return iteration
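
# Minimal round-trip sketch (hedged): everything other than the functions in this
# module is assumed to come from the caller's training setup.
#
#   iteration = load_checkpoint(neox_args, model, optimizer, lr_scheduler)
#   ...  # training steps
#   save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler)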
|
|