# NEOX/tools/ckpts/convert_hf_llama_to_neox.py
# Hosting-page residue kept for provenance: "akswelh's picture",
# "Upload 251 files", commit d90b3a8 (verified).
import torch
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import tqdm
def _shard_into(conv_state_dicts, neox_key, weight, dim):
    """Chunk ``weight`` along ``dim`` and store one detached copy per rank.

    The number of chunks requested equals ``len(conv_state_dicts)``; chunk
    ``i`` is written to ``conv_state_dicts[i][neox_key]``.
    """
    chunks = torch.chunk(weight, len(conv_state_dicts), dim=dim)
    for rank, chunk in enumerate(chunks):
        conv_state_dicts[rank][neox_key] = chunk.clone().detach()


def convert_model(hf_state_dict, hf_config, tp_ranks):
    """Re-key a Hugging Face Llama-style state dict into per-rank NeoX dicts.

    Args:
        hf_state_dict: Mapping from HF parameter names (``model.embed_tokens``,
            ``model.layers.N.*``, ``model.norm``, ``lm_head``) to tensors.
        hf_config: Object exposing ``num_attention_heads`` and
            ``num_hidden_layers``.
        tp_ranks: Number of tensor-parallel ranks to split the weights over.

    Returns:
        A list of ``tp_ranks`` dicts, one per rank, keyed by
        ``sequential.<index>.<param>`` names. HF layer ``N`` is written to
        ``sequential.<N+2>``, the final norm to
        ``sequential.<num_hidden_layers+3>`` and the output projection to the
        index after that.

    The layer count is read from ``hf_config`` rather than from any
    module-level model object, so the function can be called on its own.
    Shapes of source and converted tensors are printed as they are processed.
    """
    conv_state_dicts = [{} for _ in range(tp_ranks)]

    # Embeddings: split along dim 0.
    embed = hf_state_dict["model.embed_tokens.weight"]
    _shard_into(conv_state_dicts, "sequential.0.word_embeddings.weight", embed, dim=0)
    print(
        "model.embed_tokens.weight",
        embed.shape,
        "sequential.0.word_embeddings.weight",
        conv_state_dicts[0]["sequential.0.word_embeddings.weight"].shape,
    )

    num_q_heads = hf_config.num_attention_heads
    num_layers = hf_config.num_hidden_layers

    for layer_num in tqdm.tqdm(range(num_layers)):
        hf_prefix = f"model.layers.{layer_num}"
        neox_prefix = f"sequential.{layer_num+2}"

        # --- attention ---
        # Output projection first: it is only split along dim 1.
        o_proj = hf_state_dict[f"{hf_prefix}.self_attn.o_proj.weight"]
        _shard_into(
            conv_state_dicts, f"{neox_prefix}.attention.dense.weight", o_proj, dim=1
        )
        print(
            f"{hf_prefix}.self_attn.o_proj.weight",
            o_proj.shape,
            f"{neox_prefix}.attention.dense.weight",
            conv_state_dicts[0][f"{neox_prefix}.attention.dense.weight"].shape,
        )

        # Fused q/k/v. All three are viewed with num_q_heads groups (k and v
        # too, even when they have fewer rows than q) so that each group
        # carries its share of k/v rows; the original comment states this is
        # what the NeoX GQA code expects.
        q_proj = hf_state_dict[f"{hf_prefix}.self_attn.q_proj.weight"]
        k_proj = hf_state_dict[f"{hf_prefix}.self_attn.k_proj.weight"]
        v_proj = hf_state_dict[f"{hf_prefix}.self_attn.v_proj.weight"]
        in_features = q_proj.shape[-1]
        q = q_proj.view(num_q_heads, -1, in_features)
        k = k_proj.view(num_q_heads, -1, in_features)
        v = v_proj.view(num_q_heads, -1, in_features)
        qkv_key = f"{neox_prefix}.attention.query_key_value.weight"
        for rank, q_chunk, k_chunk, v_chunk in zip(
            range(tp_ranks),
            torch.chunk(q, tp_ranks, dim=0),
            torch.chunk(k, tp_ranks, dim=0),
            torch.chunk(v, tp_ranks, dim=0),
        ):
            # Per group the rows are laid out as [q rows, k rows, v rows].
            fused = torch.cat([q_chunk, k_chunk, v_chunk], dim=1)
            conv_state_dicts[rank][qkv_key] = (
                fused.view(-1, in_features).clone().detach()
            )
        print(
            f"{hf_prefix}.self_attn.(q/k/v)_proj.weight",
            q_proj.shape,
            k_proj.shape,
            v_proj.shape,
            qkv_key,
            conv_state_dicts[0][qkv_key].shape,
        )

        # --- mlp ---
        # SwiGLU input weights: each rank gets its up_proj chunk stacked on
        # top of its gate_proj chunk (up first, then gate).
        gate_proj = hf_state_dict[f"{hf_prefix}.mlp.gate_proj.weight"]
        up_proj = hf_state_dict[f"{hf_prefix}.mlp.up_proj.weight"]
        linear1_key = f"{neox_prefix}.mlp.linear1.weight"
        for rank, (w1, w3) in enumerate(
            zip(
                torch.chunk(gate_proj, tp_ranks, dim=0),
                torch.chunk(up_proj, tp_ranks, dim=0),
            )
        ):
            conv_state_dicts[rank][linear1_key] = torch.cat(
                [w3.clone().detach(), w1.clone().detach()], dim=0
            )
        print(
            f"{hf_prefix}.mlp.gate_proj.weight",
            gate_proj.shape,
            f"{hf_prefix}.mlp.up_proj.weight",
            up_proj.shape,
            # Label matches the key whose shape is printed next.
            linear1_key,
            conv_state_dicts[0][linear1_key].shape,
        )

        # MLP output projection: split along dim 1.
        down_proj = hf_state_dict[f"{hf_prefix}.mlp.down_proj.weight"]
        _shard_into(conv_state_dicts, f"{neox_prefix}.mlp.linear2.weight", down_proj, dim=1)
        print(
            f"{hf_prefix}.mlp.down_proj.weight",
            down_proj.shape,
            f"{neox_prefix}.mlp.linear2.weight",
            conv_state_dicts[0][f"{neox_prefix}.mlp.linear2.weight"].shape,
        )

        # --- norm ---
        # Norm scales are not split: every rank gets its own full copy.
        input_norm = hf_state_dict[f"{hf_prefix}.input_layernorm.weight"]
        post_attn_norm = hf_state_dict[f"{hf_prefix}.post_attention_layernorm.weight"]
        for rank_dict in conv_state_dicts:
            rank_dict[f"{neox_prefix}.input_layernorm.scale"] = (
                input_norm.clone().detach()
            )
            rank_dict[f"{neox_prefix}.post_attention_layernorm.scale"] = (
                post_attn_norm.clone().detach()
            )

    # Final norm: copied in full to every rank.
    index = num_layers + 3
    for rank_dict in conv_state_dicts:
        rank_dict[f"sequential.{index}.norm.scale"] = (
            hf_state_dict["model.norm.weight"].clone().detach()
        )
    index += 1

    # Output projection: split along dim 0.
    lm_head = hf_state_dict["lm_head.weight"]
    _shard_into(conv_state_dicts, f"sequential.{index}.final_linear.weight", lm_head, dim=0)
    print(
        "lm_head.weight",
        lm_head.shape,
        f"sequential.{index}.final_linear.weight",
        conv_state_dicts[0][f"sequential.{index}.final_linear.weight"].shape,
    )
    return conv_state_dicts
if __name__ == "__main__":
    # Command-line entry point: load an HF causal LM, convert its weights and
    # write one checkpoint file per tensor-parallel rank under
    # <model_path>/0/, plus a "latest" marker file and the tokenizer files.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tp", type=int, default=1, help="Number of tensor parallelism ranks"
    )
    parser.add_argument(
        "--pp", type=int, default=0, help="Number of pipeline parallelism stages"
    )
    # NOTE(review): convert_model reads Llama-style keys ("model.embed_tokens",
    # "model.layers.N..."), so the "gpt2" default presumably cannot be
    # converted -- confirm and consider changing the default.
    parser.add_argument("--model", type=str, default="gpt2", help="HF model name")
    parser.add_argument(
        "--model_path", type=str, default=None, help="Path to save model"
    )
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts are stripped
    # under `python -O`, and this gives a usage message instead of a traceback.
    if args.pp != 0:
        parser.error("Pipeline parallelism not supported yet")
    if args.tp < 1:
        parser.error("--tp must be at least 1")
    if args.model_path is None:
        # Every output path is derived from --model_path, so fail before
        # anything is downloaded or loaded.
        parser.error("--model_path is required")

    # Save the tokenizer files next to the converted weights.
    AutoTokenizer.from_pretrained(args.model).save_pretrained(
        args.model_path + "/tokenizer"
    )

    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype="auto")
    state_dict = model.state_dict()
    for key, tensor in state_dict.items():
        print(key, tensor.shape)

    # Setup model directory...
    os.makedirs(args.model_path, exist_ok=True)
    os.makedirs(f"{args.model_path}/0", exist_ok=True)
    # Save the latest file so neox can figure out where to grab the weights...
    with open(f"{args.model_path}/latest", "w") as f:
        f.write("0")

    # Convert the model...
    tp_state_dicts = convert_model(state_dict, model.config, args.tp)
    for rank, rank_state_dict in enumerate(tp_state_dicts):
        torch.save(
            {
                "dp_world_size": 1,
                "mp_world_size": args.tp,
                "optimizer": {},
                "global_steps": 1,
                "skipped_steps": 1,
                "iteration": 1,
                "module": rank_state_dict,
            },
            f"{args.model_path}/0/mp_rank_{rank:02d}_model_states.pt",
        )