# CHEMISTral7Bv0.3 / finetune / wrapped_model.py
import functools
import json
import logging
import math
from pathlib import Path
from typing import Callable, Union
import safetensors.torch
import torch
import torch.distributed.fsdp.wrap as torch_wrap
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from model.args import ModelArgs, MoeArgs
from model.transformer import Transformer, TransformerBlock
from .args import LoraArgs
from .checkpointing import Checkpointer
from .distributed import (
get_rank,
get_world_size,
)
logger = logging.getLogger(__name__)
def main_logger_info(message: str) -> None:
if get_rank() == 0:
logger.info(message)
def get_fsdp_policy(is_lora: bool) -> Callable[[torch.nn.Module], bool]:
"""
This function instantiates the FSDP wrap policy.
    - Each Transformer block becomes its own FSDP group so that only a single Transformer block is sharded at a time.
    - If LoRA is enabled, we additionally create separate FSDP sub-groups for every trainable and non-trainable parameter group,
      since this is a requirement for mixed requires_grad=True/False training. See: https://pytorch.org/docs/stable/fsdp.html
"""
    # Each transformer block becomes an FSDP group, each sharded separately
transformer_block_wrap_policy = functools.partial(
torch_wrap.transformer_auto_wrap_policy,
transformer_layer_cls=(TransformerBlock,),
)
if not is_lora:
return transformer_block_wrap_policy
def fsdp_lora_policy_fn(module):
return all(p.requires_grad for p in module.parameters())
# For LoRA training, trainable and non-trainable parameters need to be put into
# different FSDP groups
fsdp_lora_policy = functools.partial(
torch_wrap.lambda_auto_wrap_policy, lambda_fn=fsdp_lora_policy_fn
)
policies = [fsdp_lora_policy, transformer_block_wrap_policy]
return functools.partial(torch_wrap._or_policy, policies=policies)
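# Hedged usage sketch (not executed here; `model` is illustrative): the callable
# returned above is what `load_model` below passes to FSDP as `auto_wrap_policy`.
# With LoRA enabled, every module whose parameters are all trainable (the LoRA
# adapters) becomes its own FSDP unit, in addition to the per-TransformerBlock
# units:
#
#     policy = get_fsdp_policy(is_lora=True)
#     sharded = FullyShardedDataParallel(model, auto_wrap_policy=policy)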
def log_train_params(model: Union[torch.nn.Module, FullyShardedDataParallel]):
world_size = get_world_size()
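    # Under FULL_SHARD each rank holds roughly 1/world_size of the parameters,
    # so local counts are scaled by world_size to approximate the global totals
    # (exact up to padding that FSDP adds when flattening parameters).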
num_params = world_size * sum(p.numel() for p in model.parameters())
num_train_params = world_size * sum(
p.numel() for p in model.parameters() if p.requires_grad
)
main_logger_info(
f"{num_train_params:,.0f} out of {num_params:,.0f} parameter are finetuned ({num_train_params / num_params * 100:.2f}%)."
)
def initialize_lora_parameters(model: torch.nn.Module, param_dtype: torch.dtype):
"""
Initialize LoRA layers with Kaiming uniform and zeros.
See original paper for more info: https://arxiv.org/abs/2106.09685 and
original github repo: https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L122
"""
for m_name, module in model.named_modules():
if all(p.is_meta for p in module.parameters()):
for p_name, param in module.named_parameters():
module._parameters[p_name] = torch.nn.Parameter(
torch.empty_like(param, device="cpu", dtype=param_dtype)
)
param = module._parameters[p_name]
if m_name.split(".")[-1] == "lora_A":
torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
elif m_name.split(".")[-1] == "lora_B":
torch.nn.init.zeros_(param)
else:
raise ValueError(
"Only Lora layers should be randomely initialized."
)
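# Why this initialization (a sketch, assuming the standard LoRA layout where a
# layer adds `lora_B(lora_A(x))` to its frozen output): with lora_B zeroed, the
# low-rank update B @ A is zero, so at step 0 the adapted model matches the
# pretrained one exactly, while the Kaiming-uniform lora_A keeps gradients
# reasonably scaled once training starts:
#
#     delta_w = lora_B.weight @ lora_A.weight  # all zeros right after init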
def load_model(
folder: Path,
lora: LoraArgs,
checkpoint: bool,
param_dtype: torch.dtype,
) -> FullyShardedDataParallel:
with open(folder / "params.json", "r") as f:
args = json.loads(f.read())
model_args = ModelArgs(
lora=lora,
dim=args["dim"],
n_layers=args["n_layers"],
head_dim=args["head_dim"],
hidden_dim=args["hidden_dim"],
n_heads=args["n_heads"],
n_kv_heads=args["n_kv_heads"],
norm_eps=args["norm_eps"],
vocab_size=args["vocab_size"],
)
if model_args.vocab_size == 32000:
raise ValueError(
f"Fine-tuning is not supported for older model versions with vocab_size 32000. Make sure to extend your model to vocab_size=32768 using `python -m utils.extend_model_vocab --original_model_ckpt {folder} --extended_model_ckpt {folder}_extended`."
)
assert (
model_args.vocab_size >= 32768
), "Make sure to use a model with a vocab size of at least 32768"
if args.get("rope_theta") is not None:
model_args.rope_theta = args["rope_theta"]
if args.get("moe") is not None:
model_args.moe = MoeArgs(**args["moe"])
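    # Build the model on the meta device so no memory is allocated yet. Rank 0
    # materializes the real weights below; the other ranks rely on FSDP's
    # sync_module_states=True broadcast to receive them.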
with torch.device("meta"):
model = Transformer(args=model_args, checkpoint=checkpoint)
if get_rank() == 0:
state_dict = load_state_dict(folder, dtype=param_dtype)
model.load_state_dict(state_dict, assign=True) # type: ignore
logger.info("Loaded model on cpu!")
if lora.enable:
logger.info("Initializing lora layers ...")
# initialize LoRA layers
initialize_lora_parameters(model, param_dtype)
assert not any(
p.is_meta for p in model.parameters()
), "All parameters should be intialized by now"
assert all(
p.dtype == param_dtype for p in model.parameters()
), f"All parameters should be on {param_dtype}"
logger.info("Finished initialization!")
param_init_fn = None
else:
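        # Non-zero ranks keep every parameter on the meta device. FSDP calls
        # param_init_fn to materialize empty tensors on the GPU, and
        # sync_module_states=True then broadcasts rank 0's loaded weights.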
def param_init_fn(m):
m.to_empty(device=torch.cuda.current_device(), recurse=False)
m.to(param_dtype)
assert all(
p.is_meta for p in model.parameters()
), "All parameters should be on meta"
torch.distributed.barrier()
# only finetune LoRA parameters and freeze before wrapping
if lora.enable:
for name, param in model.named_parameters():
if "lora" in name:
param.requires_grad = True
else:
param.requires_grad = False
auto_wrap_policy = get_fsdp_policy(model_args.lora.enable)
main_logger_info(f"Sharding model over {get_world_size()} GPUs ...")
wrapped_model = FullyShardedDataParallel(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
limit_all_gathers=True,
device_id=torch.cuda.current_device(),
sync_module_states=True,
param_init_fn=param_init_fn,
)
main_logger_info("Model sharded!")
log_train_params(wrapped_model)
return wrapped_model
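# Hypothetical call (the checkpoint path is illustrative; only the `enable`
# field of LoraArgs is assumed here):
#
#     wrapped = load_model(
#         folder=Path("/checkpoints/my_model"),
#         lora=LoraArgs(enable=True),
#         checkpoint=True,
#         param_dtype=torch.bfloat16,
#     )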
@torch.no_grad()
def load_state_dict(path: Path, dtype: torch.dtype):
assert path.is_dir(), path
this_safetensors_path = Checkpointer.consolidated_path(path, use_safetensors=True)
this_torch_path = Checkpointer.consolidated_path(path, use_safetensors=False)
assert (
this_safetensors_path.exists() or this_torch_path.exists()
), f"Either {this_safetensors_path} or {this_torch_path} must exist."
assert not (
this_safetensors_path.exists() and this_torch_path.exists()
), f"Only one of {this_safetensors_path} or {this_torch_path} should exist."
if this_safetensors_path.exists():
logger.info(f"Reloading model from {this_safetensors_path} ...")
model_state_dict = safetensors.torch.load_file(this_safetensors_path)
else:
logger.info(f"Reloading model from {this_torch_path} ...")
model_state_dict = torch.load(this_torch_path)
logger.info(f"Converting model to dtype {dtype} ...")
for k, v in model_state_dict.items():
model_state_dict[k] = v.to(dtype)
return model_state_dict
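# Sketch of the expected result (the checkpoint directory is illustrative):
#
#     sd = load_state_dict(Path("/checkpoints/my_model"), dtype=torch.bfloat16)
#     assert all(t.dtype == torch.bfloat16 for t in sd.values())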