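"""Convert raw LLaMA or Mistral checkpoint shards to the GPT-NeoX checkpoint format.

Two output layouts are supported: per-layer files for pipeline-parallel runs
(`convert_model_pipeline`) and a single state dict per model-parallel rank
(`convert_model_sequential`).
"""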
import argparse
import os
import torch
import json
import math
import tqdm.auto as tqdm


INTERMEDIATE_SIZE_MAP = {
    "7B": 11008,
    "13B": 13824,
    "30B": 17920,
    "34B": 22016,
    "65B": 22016,
    "70B": 28672,
    "mistral-7B-v0.1": 14336,
}
NUM_SHARDS = {
    "7B": 1,
    "13B": 2,
    "30B": 4,
    "34B": 4,
    "65B": 8,
    "70B": 8,
    "mistral-7B-v0.1": 1,
}


def compute_intermediate_size(n):
    return int(math.ceil(n * 8 / 3) + 255) // 256 * 256


def read_json(path):
    with open(path, "r") as f:
        return json.load(f)


def write_json(text, path):
    with open(path, "w") as f:
        json.dump(text, f)


def write_file(text, path):
    with open(path, "w") as f:
        f.write(text)


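# Pipeline layout: one file per pipeline layer and model-parallel rank
# (layer_XX-model_YY-model_states.pt), plus a mostly empty
# mp_rank_XX_model_states.pt per rank.  Layer indices used below:
# 0 = word embeddings, i + 2 = transformer block i,
# num_layers + 3 = final norm, num_layers + 4 = final linear.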
def convert_model_pipeline(
    output_base_path, input_base_path, model_size: str, num_output_shards: int
):
    assert model_size in NUM_SHARDS

    model_path = os.path.join(output_base_path, "global_step0")
    os.makedirs(model_path, exist_ok=True)
    write_file("global_step0", os.path.join(output_base_path, "latest"))

    params = read_json(os.path.join(input_base_path, "params.json"))
    num_input_shards = NUM_SHARDS[model_size]
    num_layers = params["n_layers"]
    num_heads = params["n_heads"]
    if "n_kv_heads" in params:
        num_kv_heads = params["n_kv_heads"]
    else:
        num_kv_heads = num_heads
    num_kv_heads_per_input_shard = num_kv_heads // num_input_shards
    num_heads_per_input_shard = num_heads // num_input_shards
    num_heads_per_output_shard = num_heads // num_output_shards
    num_kv_heads_per_output_shard = num_kv_heads // num_output_shards
    hidden_size = params["dim"]
    dims_per_head = hidden_size // num_heads

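    # Regroup each head's rows from interleaved (even, odd) rotary pairs into
    # two contiguous halves; presumably needed because the Meta checkpoints
    # and GPT-NeoX order the rotary dimensions differently.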
    def permute_rotary(w):
        if w.shape == (num_heads, dims_per_head, hidden_size):
            N_HEADS = num_heads
        elif w.shape == (num_kv_heads, dims_per_head, hidden_size):
            N_HEADS = num_kv_heads
        else:
            assert False
        return (
            w.view(N_HEADS, dims_per_head // 2, 2, hidden_size)
            .transpose(1, 2)
            .reshape(N_HEADS, dims_per_head, hidden_size)
        )

    pbar = tqdm.tqdm(total=num_input_shards + num_layers + 3)

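    # Load every consolidated.*.pth input shard onto CPU.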
    pbar.set_description("Loading shard")
    loaded = []
    for i in range(num_input_shards):
        loaded.append(
            torch.load(
                os.path.join(input_base_path, f"consolidated.{i:02d}.pth"),
                map_location="cpu",
            )
        )
        pbar.set_description(f"Loaded shard {i}/{num_input_shards}")
        pbar.update(1)
    helper = Helper(
        loaded=loaded,
        model_path=model_path,
        num_output_shards=num_output_shards,
        model_size=model_size,
        pipeline_parallel=False,
    )

    sequential_cache = [{} for _ in range(num_output_shards)]

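    # Token embeddings: the input shards split tok_embeddings along the hidden
    # dimension (dim=1); reassemble, then re-shard along the vocab dimension
    # (dim=0) as word_embeddings.weight for pipeline layer 0.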
    embeddings_in = torch.cat(
        [
            loaded[rank]["tok_embeddings.weight"].cpu()
            for rank in range(num_input_shards)
        ],
        dim=1,
    )
    print(embeddings_in.shape)
    helper.save_shards(
        {"word_embeddings.weight": helper.shard(embeddings_in, dim=0)}, layer_i=0
    )
    helper.del_loaded("tok_embeddings.weight")
    pbar.set_description("Saved embeddings")
    pbar.update(1)

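    # Final RMSNorm weight is identical on every input shard; replicate it to
    # every output shard as pipeline layer num_layers + 3.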
    helper.save_duplicates(
        {"norm.scale": loaded[0]["norm.weight"]}, layer_i=num_layers + 3
    )
    helper.del_loaded("norm.weight")
    pbar.set_description("Saved final norm")
    pbar.update(1)

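    # Output projection (LM head): the input shards split output.weight along
    # the vocab dimension (dim=0); reassemble, then re-shard along dim=0 as
    # final_linear.weight (pipeline layer num_layers + 4).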
    embeddings_out = torch.cat(
        [loaded[rank]["output.weight"].cpu() for rank in range(num_input_shards)], dim=0
    )
    helper.save_shards(
        {"final_linear.weight": helper.shard(embeddings_out, dim=0)},
        layer_i=num_layers + 4,
    )
    helper.del_loaded("output.weight")
    pbar.set_description("Saved out embeddings")
    pbar.update(1)

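    # Transformer blocks: block i is written as pipeline layer i + 2.
    # attention.wo and feed_forward.w2 are concatenated and re-sharded along
    # dim=1; feed_forward.w1 and w3 along dim=0.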
    for layer_i in range(num_layers):

        attn_wo = helper.shard(
            torch.cat(
                [
                    loaded[rank][f"layers.{layer_i}.attention.wo.weight"]
                    for rank in range(num_input_shards)
                ],
                dim=1,
            ),
            dim=1,
        )
        mlp_w1 = helper.shard(
            torch.cat(
                [
                    loaded[rank][f"layers.{layer_i}.feed_forward.w1.weight"]
                    for rank in range(num_input_shards)
                ],
                dim=0,
            ),
            dim=0,
        )
        mlp_w2 = helper.shard(
            torch.cat(
                [
                    loaded[rank][f"layers.{layer_i}.feed_forward.w2.weight"]
                    for rank in range(num_input_shards)
                ],
                dim=1,
            ),
            dim=1,
        )
        mlp_w3 = helper.shard(
            torch.cat(
                [
                    loaded[rank][f"layers.{layer_i}.feed_forward.w3.weight"]
                    for rank in range(num_input_shards)
                ],
                dim=0,
            ),
            dim=0,
        )
        helper.del_loaded(f"layers.{layer_i}.attention.wo.weight")
        helper.del_loaded(f"layers.{layer_i}.feed_forward.w1.weight")
        helper.del_loaded(f"layers.{layer_i}.feed_forward.w2.weight")
        helper.del_loaded(f"layers.{layer_i}.feed_forward.w3.weight")

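        # Rebuild per-head Q/K/V from the input shards, apply the rotary
        # permutation to Q and K, and pack them into the fused
        # query_key_value weight expected by NeoX.  With grouped-query
        # attention (n_kv_heads < n_heads), K and V carry fewer heads than Q.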
        w_q = permute_rotary(
            torch.cat(
                [
                    loaded[rank][f"layers.{layer_i}.attention.wq.weight"].view(
                        num_heads_per_input_shard, dims_per_head, hidden_size
                    )
                    for rank in range(num_input_shards)
                ],
                dim=0,
            )
        )
        w_k = permute_rotary(
            torch.cat(
                [
                    loaded[rank][f"layers.{layer_i}.attention.wk.weight"].view(
                        num_kv_heads_per_input_shard, dims_per_head, hidden_size
                    )
                    for rank in range(num_input_shards)
                ],
                dim=0,
            )
        ).view(num_heads, int(dims_per_head * (num_kv_heads / num_heads)), hidden_size)

        w_v = torch.cat(
            [
                loaded[rank][f"layers.{layer_i}.attention.wv.weight"].view(
                    num_kv_heads_per_input_shard, dims_per_head, hidden_size
                )
                for rank in range(num_input_shards)
            ],
            dim=0,
        ).view(num_heads, int(dims_per_head * (num_kv_heads / num_heads)), hidden_size)

        sharded_qkv = torch.cat(
            [
                helper.shard(w_q, dim=0),
                helper.shard(w_k, dim=0),
                helper.shard(w_v, dim=0),
            ],
            dim=2,
        )

        sharded_qkv = sharded_qkv.view(
            num_output_shards,
            num_heads_per_output_shard * dims_per_head
            + 2 * num_kv_heads_per_output_shard * dims_per_head,
            hidden_size,
        )
        helper.del_loaded(f"layers.{layer_i}.attention.wq.weight")
        helper.del_loaded(f"layers.{layer_i}.attention.wk.weight")
        helper.del_loaded(f"layers.{layer_i}.attention.wv.weight")

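        # Per-layer norm weights are not sharded in the input checkpoint; take
        # them from shard 0.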
        input_layernorm = loaded[0][f"layers.{layer_i}.attention_norm.weight"]
        post_attention_layernorm = loaded[0][f"layers.{layer_i}.ffn_norm.weight"]
        helper.del_loaded(f"layers.{layer_i}.attention_norm.weight")
        helper.del_loaded(f"layers.{layer_i}.ffn_norm.weight")

        for out_rank in range(num_output_shards):
            helper.save(
                {
                    "attention.query_key_value.weight": sharded_qkv[out_rank],
                    "attention.dense.weight": attn_wo[out_rank].clone(),
                    "mlp.w1.weight": mlp_w1[out_rank].clone(),
                    "mlp.w2.weight": mlp_w2[out_rank].clone(),
                    "mlp.w3.weight": mlp_w3[out_rank].clone(),
                    "input_layernorm.scale": input_layernorm,
                    "post_attention_layernorm.scale": post_attention_layernorm,
                },
                layer_i=layer_i + 2,
                rank=out_rank,
            )

        pbar.set_description(f"Saved layer {layer_i} / {num_layers}")
        pbar.update(1)

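    # Write one (mostly empty) model-states file per model-parallel rank; the
    # actual weights live in the per-layer files written above.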
    model_state = {
        "dp_world_size": 1,
        "mp_world_size": num_output_shards,
        "module": {},
        "optimizer": {},
        "global_steps": 1,
        "skipped_steps": 1,
        "iteration": 1,
    }
    for rank in range(num_output_shards):
        torch.save(
            model_state, os.path.join(model_path, f"mp_rank_{rank:02d}_model_states.pt")
        )
    pbar.set_description("Done.")


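# Same conversion as convert_model_pipeline, but instead of per-layer files the
# weights for each model-parallel rank are accumulated in memory and written as
# a single sequential state dict inside mp_rank_XX_model_states.pt.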
def convert_model_sequential(
    output_base_path, input_base_path, model_size: str, num_output_shards: int
):
    assert model_size in NUM_SHARDS

    model_path = os.path.join(output_base_path, "global_step0")
    os.makedirs(model_path, exist_ok=True)
    write_file("global_step0", os.path.join(output_base_path, "latest"))

    params = read_json(os.path.join(input_base_path, "params.json"))
    num_input_shards = NUM_SHARDS[model_size]
    num_layers = params["n_layers"]
    num_heads = params["n_heads"]
    if "n_kv_heads" in params:
        num_kv_heads = params["n_kv_heads"]
    else:
        num_kv_heads = num_heads
    num_kv_heads_per_input_shard = num_kv_heads // num_input_shards
    num_heads_per_input_shard = num_heads // num_input_shards
    num_heads_per_output_shard = num_heads // num_output_shards
    num_kv_heads_per_output_shard = num_kv_heads // num_output_shards
    hidden_size = params["dim"]
    dims_per_head = hidden_size // num_heads

    def permute_rotary(w):
        if w.shape == (num_heads, dims_per_head, hidden_size):
            N_HEADS = num_heads
        elif w.shape == (num_kv_heads, dims_per_head, hidden_size):
            N_HEADS = num_kv_heads
        else:
            assert False
        return (
            w.view(N_HEADS, dims_per_head // 2, 2, hidden_size)
            .transpose(1, 2)
            .reshape(N_HEADS, dims_per_head, hidden_size)
        )

    pbar = tqdm.tqdm(total=num_input_shards + num_output_shards)

    pbar.set_description("Loading shard")
    loaded = []
    for i in range(num_input_shards):
        loaded.append(
            torch.load(
                os.path.join(input_base_path, f"consolidated.{i:02d}.pth"),
                map_location="cpu",
            )
        )
        pbar.set_description(f"Loaded shard {i}/{num_input_shards}")
        pbar.update(1)
    helper = Helper(
        loaded=loaded,
        model_path=model_path,
        num_output_shards=num_output_shards,
        model_size=model_size,
        pipeline_parallel=False,
    )

    embeddings_in = torch.cat(
        [
            loaded[rank]["tok_embeddings.weight"].cpu()
            for rank in range(num_input_shards)
        ],
        dim=1,
    )
    helper.add_sequential_shard(
        {"word_embeddings.weight": helper.shard(embeddings_in, dim=0)}, layer_i=0
    )
    helper.del_loaded("tok_embeddings.weight")

    helper.add_sequential_duplicates(
        {"norm.scale": loaded[0]["norm.weight"]}, layer_i=num_layers + 3
    )
    helper.del_loaded("norm.weight")

    embeddings_out = torch.cat(
        [loaded[rank]["output.weight"].cpu() for rank in range(num_input_shards)], dim=0
    )
    helper.add_sequential_shard(
        {"final_linear.weight": helper.shard(embeddings_out, dim=0)},
        layer_i=num_layers + 4,
    )
    helper.del_loaded("output.weight")

    for layer_i in range(num_layers):

        attn_wo = helper.shard(
            torch.cat(
                [
                    loaded[rank][f"layers.{layer_i}.attention.wo.weight"]
                    for rank in range(num_input_shards)
                ],
                dim=1,
            ),
            dim=1,
        )
        mlp_w1 = helper.shard(
            torch.cat(
                [
                    loaded[rank][f"layers.{layer_i}.feed_forward.w1.weight"]
                    for rank in range(num_input_shards)
                ],
                dim=0,
            ),
            dim=0,
        )
        mlp_w2 = helper.shard(
            torch.cat(
                [
                    loaded[rank][f"layers.{layer_i}.feed_forward.w2.weight"]
                    for rank in range(num_input_shards)
                ],
                dim=1,
            ),
            dim=1,
        )
        mlp_w3 = helper.shard(
            torch.cat(
                [
                    loaded[rank][f"layers.{layer_i}.feed_forward.w3.weight"]
                    for rank in range(num_input_shards)
                ],
                dim=0,
            ),
            dim=0,
        )
        helper.del_loaded(f"layers.{layer_i}.attention.wo.weight")
        helper.del_loaded(f"layers.{layer_i}.feed_forward.w1.weight")
        helper.del_loaded(f"layers.{layer_i}.feed_forward.w2.weight")
        helper.del_loaded(f"layers.{layer_i}.feed_forward.w3.weight")

        w_q = permute_rotary(
            torch.cat(
                [
                    loaded[rank][f"layers.{layer_i}.attention.wq.weight"].view(
                        num_heads_per_input_shard, dims_per_head, hidden_size
                    )
                    for rank in range(num_input_shards)
                ],
                dim=0,
            )
        )
        w_k = permute_rotary(
            torch.cat(
                [
                    loaded[rank][f"layers.{layer_i}.attention.wk.weight"].view(
                        num_kv_heads_per_input_shard, dims_per_head, hidden_size
                    )
                    for rank in range(num_input_shards)
                ],
                dim=0,
            )
        ).view(num_heads, int(dims_per_head * (num_kv_heads / num_heads)), hidden_size)

        w_v = torch.cat(
            [
                loaded[rank][f"layers.{layer_i}.attention.wv.weight"].view(
                    num_kv_heads_per_input_shard, dims_per_head, hidden_size
                )
                for rank in range(num_input_shards)
            ],
            dim=0,
        ).view(num_heads, int(dims_per_head * (num_kv_heads / num_heads)), hidden_size)

        sharded_qkv = torch.cat(
            [
                helper.shard(w_q, dim=0),
                helper.shard(w_k, dim=0),
                helper.shard(w_v, dim=0),
            ],
            dim=2,
        )

        sharded_qkv = sharded_qkv.view(
            num_output_shards,
            num_heads_per_output_shard * dims_per_head
            + 2 * num_kv_heads_per_output_shard * dims_per_head,
            hidden_size,
        )

        helper.del_loaded(f"layers.{layer_i}.attention.wq.weight")
        helper.del_loaded(f"layers.{layer_i}.attention.wk.weight")
        helper.del_loaded(f"layers.{layer_i}.attention.wv.weight")

        input_layernorm = loaded[0][f"layers.{layer_i}.attention_norm.weight"]
        post_attention_layernorm = loaded[0][f"layers.{layer_i}.ffn_norm.weight"]
        helper.del_loaded(f"layers.{layer_i}.attention_norm.weight")
        helper.del_loaded(f"layers.{layer_i}.ffn_norm.weight")

        for out_rank in range(num_output_shards):
            helper.add_sequential(
                {
                    "attention.query_key_value.weight": sharded_qkv[out_rank],
                    "attention.dense.weight": attn_wo[out_rank].clone(),
                    "mlp.w1.weight": mlp_w1[out_rank].clone(),
                    "mlp.w2.weight": mlp_w2[out_rank].clone(),
                    "mlp.w3.weight": mlp_w3[out_rank].clone(),
                    "input_layernorm.scale": input_layernorm,
                    "post_attention_layernorm.scale": post_attention_layernorm,
                },
                layer_i=layer_i + 2,
                rank=out_rank,
            )

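    # One model-states file per output shard, with the full sequential module
    # for that rank stored under "module".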
    for rank in range(num_output_shards):
        model_state = {
            "dp_world_size": 1,
            "mp_world_size": num_output_shards,
            "module": helper.sequential_cache[rank],
            "optimizer": {},
            "global_steps": 1,
            "skipped_steps": 1,
            "iteration": 1,
        }
        torch.save(
            model_state, os.path.join(model_path, f"mp_rank_{rank:02d}_model_states.pt")
        )
        pbar.set_description(f"Saved shard {rank}")
        pbar.update(1)
    pbar.set_description("Done.")


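# Shared utilities for re-sharding tensors and writing checkpoint files; used by
# both conversion paths above.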
class Helper:
    def __init__(
        self, loaded, model_size, num_output_shards, model_path, pipeline_parallel
    ):
        self.loaded = loaded
        self.model_size = model_size
        self.num_output_shards = num_output_shards
        self.model_path = model_path

        self.pipeline_parallel = pipeline_parallel
        self.sequential_cache = [{} for _ in range(num_output_shards)]

    def del_loaded(self, key: str):
        for loaded_shared in self.loaded:
            del loaded_shared[key]

    def save_shards(self, dictionary, layer_i: int):
        for k, v in dictionary.items():
            assert v.shape[0] == self.num_output_shards
        for rank in range(self.num_output_shards):
            torch.save(
                {k: v[rank].clone() for k, v in dictionary.items()},
                self.save_path(layer_i=layer_i, rank=rank),
            )

    def save_duplicates(self, dictionary, layer_i: int):
        for rank in range(self.num_output_shards):
            torch.save(
                {k: v.clone() for k, v in dictionary.items()},
                self.save_path(layer_i=layer_i, rank=rank),
            )

    def save(self, obj, layer_i, rank):
        torch.save(obj, self.save_path(layer_i=layer_i, rank=rank))

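    # shard() splits dimension `dim` into num_output_shards equal chunks; for
    # the dim=0 and dim=1 calls used in this script, the result is indexed by
    # output rank along its leading axis.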
    def shard(self, x, dim):
        x_shape = list(x.shape)
        assert x_shape[dim] % self.num_output_shards == 0
        new_x_shape = (
            x_shape[:dim]
            + [self.num_output_shards, x_shape[dim] // self.num_output_shards]
            + x_shape[dim + 1 :]
        )
        x = x.view(*new_x_shape)
        return torch.movedim(x, 0, dim)

    def save_path(self, layer_i, rank):
        return os.path.join(
            self.model_path, f"layer_{layer_i:02d}-model_{rank:02d}-model_states.pt"
        )

    def add_sequential_shard(self, dictionary, layer_i):
        assert not self.pipeline_parallel
        for k, v in dictionary.items():
            for rank in range(self.num_output_shards):
                self.sequential_cache[rank][f"sequential.{layer_i}.{k}"] = v[
                    rank
                ].clone()

    def add_sequential_duplicates(self, dictionary, layer_i):
        assert not self.pipeline_parallel
        for k, v in dictionary.items():
            for rank in range(self.num_output_shards):
                self.sequential_cache[rank][f"sequential.{layer_i}.{k}"] = v.clone()

    def add_sequential(self, dictionary, layer_i, rank):
        assert not self.pipeline_parallel
        for k, v in dictionary.items():
            self.sequential_cache[rank][f"sequential.{layer_i}.{k}"] = v.clone()


def main():
    parser = argparse.ArgumentParser(
        description="Convert raw LLaMA or Mistral checkpoints to GPT-NeoX format."
    )
    parser.add_argument(
        "--input_dir",
        help="Location of parent directory, which contains tokenizer.model and model weights subfolders",
    )
    parser.add_argument(
        "--model_size",
        choices=[
            "7B",
            "mistral-7B-v0.1",
            "13B",
            "30B",
            "34B",
            "65B",
            "70B",
            "tokenizer_only",
        ],
    )
    parser.add_argument(
        "--output_dir",
        help="Location to write GPT-NeoX model",
    )
    parser.add_argument(
        "--num_output_shards",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--pipeline_parallel",
        action="store_true",
        help="Only use if PP>1",
    )
    args = parser.parse_args()
    if args.pipeline_parallel:
        print("pipeline")
        convert_model_pipeline(
            output_base_path=args.output_dir,
            input_base_path=os.path.join(args.input_dir, args.model_size),
            model_size=args.model_size,
            num_output_shards=args.num_output_shards,
        )
    else:
        print("sequential")
        convert_model_sequential(
            output_base_path=args.output_dir,
            input_base_path=os.path.join(args.input_dir, args.model_size),
            model_size=args.model_size,
            num_output_shards=args.num_output_shards,
        )


if __name__ == "__main__":
    main()