# NEOX / tools / ckpts / convert_raw_llama_weights_to_neox.py
# (Hosting-page residue captured together with this file: "akswelh's picture",
# "Upload 251 files", "d90b3a8 verified".)
# Copyright (c) 2024, EleutherAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import torch
import json
import math
import tqdm.auto as tqdm
# Per-model integer sizes keyed by the --model_size name.
# NOTE(review): presumably the MLP intermediate width of each model; nothing
# in the visible part of this file reads this table - confirm before relying
# on it.
INTERMEDIATE_SIZE_MAP = {
    "7B": 11008,
    "13B": 13824,
    "30B": 17920,
    "34B": 22016,
    "65B": 22016,
    "70B": 28672,
    "mistral-7B-v0.1": 14336,
}

# Number of ``consolidated.XX.pth`` input files the converters load for each
# model size; membership in this table is also what the converters use to
# decide whether a model size is supported.
NUM_SHARDS = {
    "7B": 1,
    "13B": 2,
    "30B": 4,
    "34B": 4,
    "65B": 8,
    "70B": 8,
    "mistral-7B-v0.1": 1,
}
def compute_intermediate_size(n):
    """Return ``ceil(n * 8 / 3)`` rounded up to the next multiple of 256."""
    scaled = math.ceil(n * 8 / 3)
    # Adding 255 before the floor division turns it into a round-up.
    padded = int(scaled + 255)
    return padded // 256 * 256
def read_json(path):
    """Open ``path`` for reading and return the parsed JSON document."""
    with open(path, "r") as handle:
        parsed = json.load(handle)
    return parsed
def write_json(text, path):
    """Serialise ``text`` as JSON into ``path``, replacing any existing file.

    Despite its name, ``text`` may be any object ``json.dump`` accepts.
    """
    with open(path, "w") as handle:
        json.dump(obj=text, fp=handle)
def write_file(text, path):
    """Write the string ``text`` to ``path``, replacing any existing file."""
    with open(path, "w") as handle:
        handle.write(text)
def convert_model_pipeline(
    output_base_path, input_base_path, model_size: str, num_output_shards: int
):
    """Convert raw checkpoint shards into per-layer GPT-NeoX checkpoint files.

    Reads ``params.json`` and the ``consolidated.XX.pth`` files found in
    ``input_base_path``, re-partitions every weight over ``num_output_shards``
    ranks and writes, under ``<output_base_path>/global_step0``:

    * one ``layer_XX-model_YY-model_states.pt`` file per (layer, rank) pair;
    * one ``mp_rank_YY_model_states.pt`` file per rank whose ``"module"``
      entry is empty.

    A ``latest`` file containing ``global_step0`` is written to
    ``output_base_path``.

    Layer numbering used for the output files: input embedding is layer 0,
    transformer layer ``i`` is ``i + 2``, the final norm is ``num_layers + 3``
    and the output projection is ``num_layers + 4``.
    NOTE(review): presumably this matches GPT-NeoX's own layer numbering -
    confirm against the loader.

    Args:
        output_base_path: Directory that receives the converted checkpoint.
        input_base_path: Directory holding ``params.json`` and the raw shards.
        model_size: Key of ``NUM_SHARDS``; selects how many raw shards to read.
        num_output_shards: Number of ranks to split the weights over.

    Raises:
        ValueError: If ``model_size`` is not a key of ``NUM_SHARDS`` or a
            query/key projection has an unexpected shape.
    """
    if model_size not in NUM_SHARDS:
        raise ValueError(
            f"Unsupported model_size {model_size!r}; "
            f"expected one of {sorted(NUM_SHARDS)}"
        )

    model_path = os.path.join(output_base_path, "global_step0")
    os.makedirs(model_path, exist_ok=True)
    write_file("global_step0", os.path.join(output_base_path, "latest"))

    params = read_json(os.path.join(input_base_path, "params.json"))
    num_input_shards = NUM_SHARDS[model_size]
    num_layers = params["n_layers"]
    num_heads = params["n_heads"]
    # Without an explicit "n_kv_heads" entry, keys/values get one head per
    # query head.
    num_kv_heads = params.get("n_kv_heads", num_heads)
    num_kv_heads_per_input_shard = num_kv_heads // num_input_shards
    num_heads_per_input_shard = num_heads // num_input_shards
    num_heads_per_output_shard = num_heads // num_output_shards
    num_kv_heads_per_output_shard = num_kv_heads // num_output_shards
    hidden_size = params["dim"]
    dims_per_head = hidden_size // num_heads
    # Rows of K (or V) attributed to each *query* head once the K/V matrix is
    # re-viewed with ``num_heads`` leading entries. Integer arithmetic avoids
    # float rounding in the row count.
    kv_rows_per_head = dims_per_head * num_kv_heads // num_heads
    qkv_rows_per_output_shard = (
        num_heads_per_output_shard * dims_per_head
        + 2 * num_kv_heads_per_output_shard * dims_per_head
    )

    def permute_rotary(w):
        """Reorder each head's rows: even-indexed rows first, then odd ones."""
        if w.shape == (num_heads, dims_per_head, hidden_size):
            n_heads = num_heads
        elif w.shape == (num_kv_heads, dims_per_head, hidden_size):
            n_heads = num_kv_heads
        else:
            raise ValueError(
                f"Unexpected projection shape {tuple(w.shape)} in permute_rotary"
            )
        return (
            w.view(n_heads, dims_per_head // 2, 2, hidden_size)
            .transpose(1, 2)
            .reshape(n_heads, dims_per_head, hidden_size)
        )

    pbar = tqdm.tqdm(total=num_input_shards + num_layers + 3)
    pbar.set_description("Loading shard")
    loaded = []
    for i in range(num_input_shards):
        # NOTE(security): torch.load unpickles the file; only run this on
        # checkpoints obtained from a trusted source.
        loaded.append(
            torch.load(
                os.path.join(input_base_path, f"consolidated.{i:02d}.pth"),
                map_location="cpu",
            )
        )
        pbar.set_description(f"Loaded shard {i}/{num_input_shards}")
        pbar.update(1)

    def gather(key, dim):
        """Concatenate parameter ``key`` from every input shard along ``dim``."""
        return torch.cat([state[key] for state in loaded], dim=dim)

    def gather_heads(key, heads_per_input_shard):
        """Concatenate ``key`` over input shards as (heads, dims_per_head, hidden)."""
        return torch.cat(
            [
                state[key].view(heads_per_input_shard, dims_per_head, hidden_size)
                for state in loaded
            ],
            dim=0,
        )

    # This routine only uses the per-layer save helpers, so the helper is
    # flagged as pipeline-parallel; the flag merely disables the sequential
    # cache methods.
    helper = Helper(
        loaded=loaded,
        model_path=model_path,
        num_output_shards=num_output_shards,
        model_size=model_size,
        pipeline_parallel=True,
    )

    # Embedding in: joined along dim 1 across input shards, then split along
    # dim 0 across output ranks.
    embeddings_in = gather("tok_embeddings.weight", dim=1)
    helper.save_shards(
        {"word_embeddings.weight": helper.shard(embeddings_in, dim=0)}, layer_i=0
    )
    helper.del_loaded("tok_embeddings.weight")
    pbar.set_description("Saved embeddings")
    pbar.update(1)

    # Final norm: taken from input shard 0 and written identically to every rank.
    helper.save_duplicates(
        {"norm.scale": loaded[0]["norm.weight"]}, layer_i=num_layers + 3
    )
    helper.del_loaded("norm.weight")
    pbar.set_description("Saved final norm")
    pbar.update(1)

    # Embedding out
    embeddings_out = gather("output.weight", dim=0)
    helper.save_shards(
        {"final_linear.weight": helper.shard(embeddings_out, dim=0)},
        layer_i=num_layers + 4,
    )
    helper.del_loaded("output.weight")
    pbar.set_description("Saved out embeddings")
    pbar.update(1)

    # Transformer layers
    for layer_i in range(num_layers):
        prefix = f"layers.{layer_i}"

        # Linear projections: re-join along the axis the input shards were
        # concatenated on, then split the same axis over the output ranks.
        attn_wo = helper.shard(gather(f"{prefix}.attention.wo.weight", dim=1), dim=1)
        mlp_w1 = helper.shard(
            gather(f"{prefix}.feed_forward.w1.weight", dim=0), dim=0
        )
        mlp_w2 = helper.shard(
            gather(f"{prefix}.feed_forward.w2.weight", dim=1), dim=1
        )
        mlp_w3 = helper.shard(
            gather(f"{prefix}.feed_forward.w3.weight", dim=0), dim=0
        )
        helper.del_loaded(f"{prefix}.attention.wo.weight")
        helper.del_loaded(f"{prefix}.feed_forward.w1.weight")
        helper.del_loaded(f"{prefix}.feed_forward.w2.weight")
        helper.del_loaded(f"{prefix}.feed_forward.w3.weight")

        # Attention: Q and K get the rotary row permutation, V does not.
        w_q = permute_rotary(
            gather_heads(f"{prefix}.attention.wq.weight", num_heads_per_input_shard)
        )
        w_k = permute_rotary(
            gather_heads(
                f"{prefix}.attention.wk.weight", num_kv_heads_per_input_shard
            )
        ).view(num_heads, kv_rows_per_head, hidden_size)
        w_v = gather_heads(
            f"{prefix}.attention.wv.weight", num_kv_heads_per_input_shard
        ).view(num_heads, kv_rows_per_head, hidden_size)

        # Shape: (num_output_shards, num_heads_per_output_shard,
        #         dims_per_head + 2 * kv_rows_per_head, hidden_size)
        sharded_qkv = torch.cat(
            [
                helper.shard(w_q, dim=0),
                helper.shard(w_k, dim=0),
                helper.shard(w_v, dim=0),
            ],
            dim=2,
        )
        sharded_qkv = sharded_qkv.view(
            num_output_shards, qkv_rows_per_output_shard, hidden_size
        )
        helper.del_loaded(f"{prefix}.attention.wq.weight")
        helper.del_loaded(f"{prefix}.attention.wk.weight")
        helper.del_loaded(f"{prefix}.attention.wv.weight")

        # Norm scales are taken from input shard 0 and duplicated on every rank.
        input_layernorm = loaded[0][f"{prefix}.attention_norm.weight"]
        post_attention_layernorm = loaded[0][f"{prefix}.ffn_norm.weight"]
        helper.del_loaded(f"{prefix}.attention_norm.weight")
        helper.del_loaded(f"{prefix}.ffn_norm.weight")

        for out_rank in range(num_output_shards):
            helper.save(
                {
                    # Every slice is cloned before saving: torch.save writes
                    # the entire storage behind a view, so saving the slice
                    # directly would put all ranks' data into each file.
                    "attention.query_key_value.weight": sharded_qkv[
                        out_rank
                    ].clone(),
                    # Sharded layers
                    "attention.dense.weight": attn_wo[out_rank].clone(),
                    "mlp.w1.weight": mlp_w1[out_rank].clone(),
                    "mlp.w2.weight": mlp_w2[out_rank].clone(),
                    "mlp.w3.weight": mlp_w3[out_rank].clone(),
                    # Duplicated layers
                    "input_layernorm.scale": input_layernorm,
                    "post_attention_layernorm.scale": post_attention_layernorm,
                },
                layer_i=layer_i + 2,
                rank=out_rank,
            )

        pbar.set_description(f"Saved layer {layer_i} / {num_layers}")
        pbar.update(1)

    # Per-rank state files; the weights themselves live in the layer files.
    model_state = {
        "dp_world_size": 1,
        "mp_world_size": num_output_shards,
        "module": {},
        "optimizer": {},
        "global_steps": 1,
        "skipped_steps": 1,
        "iteration": 1,
    }
    for rank in range(num_output_shards):
        torch.save(
            model_state, os.path.join(model_path, f"mp_rank_{rank:02d}_model_states.pt")
        )
    pbar.set_description("Done.")
    pbar.close()
def convert_model_sequential(
    output_base_path, input_base_path, model_size: str, num_output_shards: int
):
    """Convert raw checkpoint shards into one GPT-NeoX state file per rank.

    Reads ``params.json`` and the ``consolidated.XX.pth`` files found in
    ``input_base_path``, re-partitions every weight over ``num_output_shards``
    ranks and writes one ``mp_rank_YY_model_states.pt`` file per rank under
    ``<output_base_path>/global_step0``. Each file's ``"module"`` entry maps
    ``sequential.<layer index>.<parameter name>`` to that rank's tensor.
    A ``latest`` file containing ``global_step0`` is written to
    ``output_base_path``.

    Layer indices used in the keys: input embedding is 0, transformer layer
    ``i`` is ``i + 2``, the final norm is ``num_layers + 3`` and the output
    projection is ``num_layers + 4``.
    NOTE(review): presumably this matches GPT-NeoX's own layer numbering -
    confirm against the loader.

    Args:
        output_base_path: Directory that receives the converted checkpoint.
        input_base_path: Directory holding ``params.json`` and the raw shards.
        model_size: Key of ``NUM_SHARDS``; selects how many raw shards to read.
        num_output_shards: Number of ranks to split the weights over.

    Raises:
        ValueError: If ``model_size`` is not a key of ``NUM_SHARDS`` or a
            query/key projection has an unexpected shape.
    """
    if model_size not in NUM_SHARDS:
        raise ValueError(
            f"Unsupported model_size {model_size!r}; "
            f"expected one of {sorted(NUM_SHARDS)}"
        )

    model_path = os.path.join(output_base_path, "global_step0")
    os.makedirs(model_path, exist_ok=True)
    write_file("global_step0", os.path.join(output_base_path, "latest"))

    params = read_json(os.path.join(input_base_path, "params.json"))
    num_input_shards = NUM_SHARDS[model_size]
    num_layers = params["n_layers"]
    num_heads = params["n_heads"]
    # Without an explicit "n_kv_heads" entry, keys/values get one head per
    # query head.
    num_kv_heads = params.get("n_kv_heads", num_heads)
    num_kv_heads_per_input_shard = num_kv_heads // num_input_shards
    num_heads_per_input_shard = num_heads // num_input_shards
    num_heads_per_output_shard = num_heads // num_output_shards
    num_kv_heads_per_output_shard = num_kv_heads // num_output_shards
    hidden_size = params["dim"]
    dims_per_head = hidden_size // num_heads
    # Rows of K (or V) attributed to each *query* head once the K/V matrix is
    # re-viewed with ``num_heads`` leading entries. Integer arithmetic avoids
    # float rounding in the row count.
    kv_rows_per_head = dims_per_head * num_kv_heads // num_heads
    qkv_rows_per_output_shard = (
        num_heads_per_output_shard * dims_per_head
        + 2 * num_kv_heads_per_output_shard * dims_per_head
    )

    def permute_rotary(w):
        """Reorder each head's rows: even-indexed rows first, then odd ones."""
        if w.shape == (num_heads, dims_per_head, hidden_size):
            n_heads = num_heads
        elif w.shape == (num_kv_heads, dims_per_head, hidden_size):
            n_heads = num_kv_heads
        else:
            raise ValueError(
                f"Unexpected projection shape {tuple(w.shape)} in permute_rotary"
            )
        return (
            w.view(n_heads, dims_per_head // 2, 2, hidden_size)
            .transpose(1, 2)
            .reshape(n_heads, dims_per_head, hidden_size)
        )

    pbar = tqdm.tqdm(total=num_input_shards + num_output_shards)
    pbar.set_description("Loading shard")
    loaded = []
    for i in range(num_input_shards):
        # NOTE(security): torch.load unpickles the file; only run this on
        # checkpoints obtained from a trusted source.
        loaded.append(
            torch.load(
                os.path.join(input_base_path, f"consolidated.{i:02d}.pth"),
                map_location="cpu",
            )
        )
        pbar.set_description(f"Loaded shard {i}/{num_input_shards}")
        pbar.update(1)

    def gather(key, dim):
        """Concatenate parameter ``key`` from every input shard along ``dim``."""
        return torch.cat([state[key] for state in loaded], dim=dim)

    def gather_heads(key, heads_per_input_shard):
        """Concatenate ``key`` over input shards as (heads, dims_per_head, hidden)."""
        return torch.cat(
            [
                state[key].view(heads_per_input_shard, dims_per_head, hidden_size)
                for state in loaded
            ],
            dim=0,
        )

    helper = Helper(
        loaded=loaded,
        model_path=model_path,
        num_output_shards=num_output_shards,
        model_size=model_size,
        pipeline_parallel=False,
    )

    # Embedding in: joined along dim 1 across input shards, then split along
    # dim 0 across output ranks.
    embeddings_in = gather("tok_embeddings.weight", dim=1)
    helper.add_sequential_shard(
        {"word_embeddings.weight": helper.shard(embeddings_in, dim=0)}, layer_i=0
    )
    helper.del_loaded("tok_embeddings.weight")

    # Final norm: taken from input shard 0 and duplicated on every rank.
    helper.add_sequential_duplicates(
        {"norm.scale": loaded[0]["norm.weight"]}, layer_i=num_layers + 3
    )
    helper.del_loaded("norm.weight")

    # Embedding out
    embeddings_out = gather("output.weight", dim=0)
    helper.add_sequential_shard(
        {"final_linear.weight": helper.shard(embeddings_out, dim=0)},
        layer_i=num_layers + 4,
    )
    helper.del_loaded("output.weight")

    # Transformer layers
    for layer_i in range(num_layers):
        prefix = f"layers.{layer_i}"

        # Linear projections: re-join along the axis the input shards were
        # concatenated on, then split the same axis over the output ranks.
        attn_wo = helper.shard(gather(f"{prefix}.attention.wo.weight", dim=1), dim=1)
        mlp_w1 = helper.shard(
            gather(f"{prefix}.feed_forward.w1.weight", dim=0), dim=0
        )
        mlp_w2 = helper.shard(
            gather(f"{prefix}.feed_forward.w2.weight", dim=1), dim=1
        )
        mlp_w3 = helper.shard(
            gather(f"{prefix}.feed_forward.w3.weight", dim=0), dim=0
        )
        helper.del_loaded(f"{prefix}.attention.wo.weight")
        helper.del_loaded(f"{prefix}.feed_forward.w1.weight")
        helper.del_loaded(f"{prefix}.feed_forward.w2.weight")
        helper.del_loaded(f"{prefix}.feed_forward.w3.weight")

        # Attention: Q and K get the rotary row permutation, V does not.
        w_q = permute_rotary(
            gather_heads(f"{prefix}.attention.wq.weight", num_heads_per_input_shard)
        )
        w_k = permute_rotary(
            gather_heads(
                f"{prefix}.attention.wk.weight", num_kv_heads_per_input_shard
            )
        ).view(num_heads, kv_rows_per_head, hidden_size)
        w_v = gather_heads(
            f"{prefix}.attention.wv.weight", num_kv_heads_per_input_shard
        ).view(num_heads, kv_rows_per_head, hidden_size)

        # Shape: (num_output_shards, num_heads_per_output_shard,
        #         dims_per_head + 2 * kv_rows_per_head, hidden_size)
        sharded_qkv = torch.cat(
            [
                helper.shard(w_q, dim=0),
                helper.shard(w_k, dim=0),
                helper.shard(w_v, dim=0),
            ],
            dim=2,
        )
        sharded_qkv = sharded_qkv.view(
            num_output_shards, qkv_rows_per_output_shard, hidden_size
        )
        helper.del_loaded(f"{prefix}.attention.wq.weight")
        helper.del_loaded(f"{prefix}.attention.wk.weight")
        helper.del_loaded(f"{prefix}.attention.wv.weight")

        # Norm scales are taken from input shard 0 and duplicated on every rank.
        input_layernorm = loaded[0][f"{prefix}.attention_norm.weight"]
        post_attention_layernorm = loaded[0][f"{prefix}.ffn_norm.weight"]
        helper.del_loaded(f"{prefix}.attention_norm.weight")
        helper.del_loaded(f"{prefix}.ffn_norm.weight")

        for out_rank in range(num_output_shards):
            # helper.add_sequential clones every value it stores, so the
            # slices are passed as views to avoid copying each tensor twice.
            helper.add_sequential(
                {
                    "attention.query_key_value.weight": sharded_qkv[out_rank],
                    # Sharded layers
                    "attention.dense.weight": attn_wo[out_rank],
                    "mlp.w1.weight": mlp_w1[out_rank],
                    "mlp.w2.weight": mlp_w2[out_rank],
                    "mlp.w3.weight": mlp_w3[out_rank],
                    # Duplicated layers
                    "input_layernorm.scale": input_layernorm,
                    "post_attention_layernorm.scale": post_attention_layernorm,
                },
                layer_i=layer_i + 2,
                rank=out_rank,
            )

    for rank in range(num_output_shards):
        model_state = {
            "dp_world_size": 1,
            "mp_world_size": num_output_shards,
            "module": helper.sequential_cache[rank],
            "optimizer": {},
            "global_steps": 1,
            "skipped_steps": 1,
            "iteration": 1,
        }
        torch.save(
            model_state, os.path.join(model_path, f"mp_rank_{rank:02d}_model_states.pt")
        )
        pbar.set_description(f"Saved shard {rank}")
        pbar.update(1)
    pbar.set_description("Done.")
    pbar.close()
class Helper:
    """Shared bookkeeping for the checkpoint converters.

    Holds the loaded input state dicts, knows how to split a tensor over the
    output ranks, and either writes per-layer files to ``model_path`` or
    accumulates per-rank ``sequential.<layer>.<name>`` entries in
    ``sequential_cache``.
    """

    def __init__(
        self, loaded, model_size, num_output_shards, model_path, pipeline_parallel
    ):
        """Store the conversion context.

        Args:
            loaded: List of input state dicts; entries are deleted in place by
                ``del_loaded``.
            model_size: Model size name (stored, not otherwise used here).
            num_output_shards: Number of output ranks.
            model_path: Directory the ``save*`` methods write into.
            pipeline_parallel: When true, the ``add_sequential*`` methods
                refuse to run.
        """
        self.loaded = loaded
        self.model_size = model_size
        self.num_output_shards = num_output_shards
        self.model_path = model_path
        self.pipeline_parallel = pipeline_parallel
        # One dict per output rank, filled by the add_sequential* methods.
        self.sequential_cache = [{} for _ in range(num_output_shards)]

    def del_loaded(self, key: str):
        """Delete ``key`` from every input state dict to free memory early."""
        for loaded_shard in self.loaded:
            del loaded_shard[key]

    def save_shards(self, dictionary, layer_i: int):
        """Write one file per rank holding ``value[rank]`` for every entry.

        Each value's leading dimension must equal ``num_output_shards``.
        """
        for name, stacked in dictionary.items():
            assert (
                stacked.shape[0] == self.num_output_shards
            ), f"{name}: leading dimension must equal num_output_shards"
        for rank in range(self.num_output_shards):
            # clone() so each file holds only this rank's slice rather than
            # the storage shared by all slices.
            torch.save(
                {name: stacked[rank].clone() for name, stacked in dictionary.items()},
                self.save_path(layer_i=layer_i, rank=rank),
            )

    def save_duplicates(self, dictionary, layer_i: int):
        """Write the same (cloned) tensors to every rank's file for ``layer_i``."""
        for rank in range(self.num_output_shards):
            torch.save(
                {name: tensor.clone() for name, tensor in dictionary.items()},
                self.save_path(layer_i=layer_i, rank=rank),
            )

    def save(self, obj, layer_i, rank):
        """Write ``obj`` as-is to the file for ``(layer_i, rank)``."""
        torch.save(obj, self.save_path(layer_i=layer_i, rank=rank))

    def shard(self, x, dim):
        """Split ``x`` into ``num_output_shards`` equal chunks along ``dim``.

        Returns a tensor with one extra leading dimension of size
        ``num_output_shards``; ``result[r]`` is the ``r``-th contiguous chunk
        of ``x`` along ``dim``.

        Raises:
            ValueError: If the size of ``dim`` is not divisible by
                ``num_output_shards``.
        """
        x_shape = list(x.shape)
        if dim < 0:
            # Normalise so the list slicing below addresses the intended axis.
            dim += len(x_shape)
        if x_shape[dim] % self.num_output_shards != 0:
            raise ValueError(
                f"Dimension {dim} of size {x_shape[dim]} is not divisible by "
                f"num_output_shards={self.num_output_shards}"
            )
        new_x_shape = (
            x_shape[:dim]
            + [self.num_output_shards, x_shape[dim] // self.num_output_shards]
            + x_shape[dim + 1 :]
        )
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(*new_x_shape)
        # The shard axis sits at position ``dim``; bring it to the front.
        # Moving axis 0 to ``dim`` instead only gives the same result for
        # dim <= 1.
        return torch.movedim(x, dim, 0)

    def save_path(self, layer_i, rank):
        """Return the per-layer, per-rank checkpoint path inside ``model_path``."""
        return os.path.join(
            self.model_path, f"layer_{layer_i:02d}-model_{rank:02d}-model_states.pt"
        )

    def add_sequential_shard(self, dictionary, layer_i):
        """Cache ``value[rank]`` (cloned) for every rank under ``layer_i``."""
        assert not self.pipeline_parallel
        for name, stacked in dictionary.items():
            for rank in range(self.num_output_shards):
                self.sequential_cache[rank][f"sequential.{layer_i}.{name}"] = stacked[
                    rank
                ].clone()

    def add_sequential_duplicates(self, dictionary, layer_i):
        """Cache a clone of every tensor on every rank under ``layer_i``."""
        assert not self.pipeline_parallel
        for name, tensor in dictionary.items():
            for rank in range(self.num_output_shards):
                self.sequential_cache[rank][
                    f"sequential.{layer_i}.{name}"
                ] = tensor.clone()

    def add_sequential(self, dictionary, layer_i, rank):
        """Cache a clone of every tensor for a single ``rank`` under ``layer_i``."""
        assert not self.pipeline_parallel
        for name, tensor in dictionary.items():
            self.sequential_cache[rank][f"sequential.{layer_i}.{name}"] = tensor.clone()
def main(argv=None):
    """Parse command-line arguments and run the requested conversion.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the process command line.

    Exits with an argparse usage error when a required option is missing,
    when ``--model_size tokenizer_only`` is requested, or when
    ``--num_output_shards`` is smaller than 1.
    """
    parser = argparse.ArgumentParser(
        description="Convert raw LLaMA or Mistral checkpoints to GPT-NeoX format."
    )
    parser.add_argument(
        "--input_dir",
        required=True,
        help="Location of parent directory, which contains tokenizer.model and model weights subfolders",
    )
    parser.add_argument(
        "--model_size",
        required=True,
        choices=[
            "7B",
            "mistral-7B-v0.1",
            "13B",
            "30B",
            "34B",
            "65B",
            "70B",
            "tokenizer_only",
        ],
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Location to write GPT-NeoX model",
    )
    parser.add_argument(
        "--num_output_shards",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--pipeline_parallel",
        action="store_true",
        help="Only use if PP>1",
    )
    args = parser.parse_args(argv)

    # "tokenizer_only" stays an accepted choice for compatibility, but neither
    # converter can handle it, so report that up front instead of failing
    # inside the conversion.
    if args.model_size == "tokenizer_only":
        parser.error("--model_size tokenizer_only is not supported by this script")
    if args.num_output_shards < 1:
        parser.error("--num_output_shards must be at least 1")

    if args.pipeline_parallel:
        print("parallel")
        convert = convert_model_pipeline
    else:
        print("sequential")
        convert = convert_model_sequential
    # The weights for a given size live in <input_dir>/<model_size>.
    convert(
        output_base_path=args.output_dir,
        input_base_path=os.path.join(args.input_dir, args.model_size),
        model_size=args.model_size,
        num_output_shards=args.num_output_shards,
    )


if __name__ == "__main__":
    main()