# Copyright (c) 2024, EleutherAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import torch
import json
import math
import tqdm.auto as tqdm


INTERMEDIATE_SIZE_MAP = {
    "7B": 11008,
    "13B": 13824,
    "30B": 17920,
    "34B": 22016,
    "65B": 22016,
    "70B": 28672,
    "mistral-7B-v0.1": 14336,
}
NUM_SHARDS = {
    "7B": 1,
    "13B": 2,
    "30B": 4,
    "34B": 4,
    "65B": 8,
    "70B": 8,
    "mistral-7B-v0.1": 1,
}


def compute_intermediate_size(n):
    return int(math.ceil(n * 8 / 3) + 255) // 256 * 256


def read_json(path):
    with open(path, "r") as f:
        return json.load(f)


def write_json(text, path):
    with open(path, "w") as f:
        json.dump(text, f)


def write_file(text, path):
    with open(path, "w") as f:
        f.write(text)


def convert_model_pipeline(
    output_base_path, input_base_path, model_size: str, num_output_shards: int
):
    assert model_size in NUM_SHARDS

    model_path = os.path.join(output_base_path, "global_step0")
    os.makedirs(model_path, exist_ok=True)
    write_file("global_step0", os.path.join(output_base_path, "latest"))

    params = read_json(os.path.join(input_base_path, "params.json"))
    num_input_shards = NUM_SHARDS[model_size]
    num_layers = params["n_layers"]
    num_heads = params["n_heads"]
    if "n_kv_heads" in params:
        num_kv_heads = params["n_kv_heads"]
    else:
        num_kv_heads = num_heads
    num_kv_heads_per_input_shard = num_kv_heads // num_input_shards
    num_heads_per_input_shard = num_heads // num_input_shards
    num_heads_per_output_shard = num_heads // num_output_shards
    num_kv_heads_per_output_shard = num_kv_heads // num_output_shards
    hidden_size = params["dim"]
    dims_per_head = hidden_size // num_heads
    # base = 10000.0
    # inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))

    def permute_rotary(w):
        if w.shape == (num_heads, dims_per_head, hidden_size):
            N_HEADS = num_heads
        elif w.shape == (num_kv_heads, dims_per_head, hidden_size):
            N_HEADS = num_kv_heads
        else:
            assert False
        return (
            w.view(N_HEADS, dims_per_head // 2, 2, hidden_size)
            .transpose(1, 2)
            .reshape(N_HEADS, dims_per_head, hidden_size)
        )

    pbar = tqdm.tqdm(total=num_input_shards + num_layers + 3)
    pbar.set_description(f"Loading shard")
    loaded = []
    for i in range(num_input_shards):
        loaded.append(
            torch.load(
                os.path.join(input_base_path, f"consolidated.{i:02d}.pth"),
                map_location="cpu",
            )
        )
        pbar.set_description(f"Loaded shard {i}/{num_input_shards}")
        pbar.update(1)

    helper = Helper(
        loaded=loaded,
        model_path=model_path,
        num_output_shards=num_output_shards,
        model_size=model_size,
        pipeline_parallel=False,
    )

    sequential_cache = [{} for _ in range(num_output_shards)]

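    # NeoX pipeline-parallel layer numbering used below: the word embedding is
    # saved as layer 0, transformer block i as layer i + 2, the final norm as
    # layer num_layers + 3, and the output projection ("final_linear") as
    # layer num_layers + 4; the indices in between are left for parameter-free
    # stages of the NeoX pipeline.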
pbar.set_description(f"Saved final norm") pbar.update(1) # Embedding out embeddings_out = torch.cat( [loaded[rank]["output.weight"].cpu() for rank in range(num_input_shards)], dim=0 ) helper.save_shards( {"final_linear.weight": helper.shard(embeddings_out, dim=0)}, layer_i=num_layers + 4, ) helper.del_loaded("output.weight") pbar.set_description(f"Saved out embeddings") pbar.update(1) # Layers for layer_i in range(num_layers): # Linear attn_wo = helper.shard( torch.cat( [ loaded[rank][f"layers.{layer_i}.attention.wo.weight"] for rank in range(num_input_shards) ], dim=1, ), dim=1, ) mlp_w1 = helper.shard( torch.cat( [ loaded[rank][f"layers.{layer_i}.feed_forward.w1.weight"] for rank in range(num_input_shards) ], dim=0, ), dim=0, ) mlp_w2 = helper.shard( torch.cat( [ loaded[rank][f"layers.{layer_i}.feed_forward.w2.weight"] for rank in range(num_input_shards) ], dim=1, ), dim=1, ) mlp_w3 = helper.shard( torch.cat( [ loaded[rank][f"layers.{layer_i}.feed_forward.w3.weight"] for rank in range(num_input_shards) ], dim=0, ), dim=0, ) helper.del_loaded(f"layers.{layer_i}.attention.wo.weight") helper.del_loaded(f"layers.{layer_i}.feed_forward.w1.weight") helper.del_loaded(f"layers.{layer_i}.feed_forward.w2.weight") helper.del_loaded(f"layers.{layer_i}.feed_forward.w3.weight") # Attention w_q = permute_rotary( torch.cat( [ loaded[rank][f"layers.{layer_i}.attention.wq.weight"].view( num_heads_per_input_shard, dims_per_head, hidden_size ) for rank in range(num_input_shards) ], dim=0, ) ) w_k = permute_rotary( torch.cat( [ loaded[rank][f"layers.{layer_i}.attention.wk.weight"].view( num_kv_heads_per_input_shard, dims_per_head, hidden_size ) for rank in range(num_input_shards) ], dim=0, ) ).view(num_heads, int(dims_per_head * (num_kv_heads / num_heads)), hidden_size) w_v = torch.cat( [ loaded[rank][f"layers.{layer_i}.attention.wv.weight"].view( num_kv_heads_per_input_shard, dims_per_head, hidden_size ) for rank in range(num_input_shards) ], dim=0, ).view(num_heads, int(dims_per_head * (num_kv_heads / num_heads)), hidden_size) sharded_qkv = torch.cat( [ helper.shard( w_q, dim=0 ), # num_output_shards, num_heads_per_output_shard, dims_per_head, hidden_size helper.shard(w_k, dim=0), helper.shard(w_v, dim=0), ], dim=2, ) # num_output_shards, num_heads_per_output_shard, QKV=3, dims_per_head, hidden_size sharded_qkv = sharded_qkv.view( num_output_shards, num_heads_per_output_shard * dims_per_head + 2 * num_kv_heads_per_output_shard * dims_per_head, hidden_size, ) helper.del_loaded(f"layers.{layer_i}.attention.wq.weight") helper.del_loaded(f"layers.{layer_i}.attention.wk.weight") helper.del_loaded(f"layers.{layer_i}.attention.wv.weight") # Duplicated input_layernorm = loaded[0][f"layers.{layer_i}.attention_norm.weight"] post_attention_layernorm = loaded[0][f"layers.{layer_i}.ffn_norm.weight"] helper.del_loaded(f"layers.{layer_i}.attention_norm.weight") helper.del_loaded(f"layers.{layer_i}.ffn_norm.weight") for out_rank in range(num_output_shards): helper.save( { "attention.query_key_value.weight": sharded_qkv[out_rank], # Sharded layers "attention.dense.weight": attn_wo[out_rank].clone(), "mlp.w1.weight": mlp_w1[out_rank].clone(), "mlp.w2.weight": mlp_w2[out_rank].clone(), "mlp.w3.weight": mlp_w3[out_rank].clone(), # Duplicated layers "input_layernorm.scale": input_layernorm, "post_attention_layernorm.scale": post_attention_layernorm, }, layer_i=layer_i + 2, rank=out_rank, ) pbar.set_description(f"Saved layer {layer_i} / {num_layers}") pbar.update(1) model_state = { "dp_world_size": 1, "mp_world_size": 
num_output_shards, "module": {}, "optimizer": {}, "global_steps": 1, "skipped_steps": 1, "iteration": 1, } for rank in range(num_output_shards): torch.save( model_state, os.path.join(model_path, f"mp_rank_{rank:02d}_model_states.pt") ) pbar.set_description("Done.") def convert_model_sequential( output_base_path, input_base_path, model_size: str, num_output_shards: int ): assert model_size in NUM_SHARDS model_path = os.path.join(output_base_path, "global_step0") os.makedirs(model_path, exist_ok=True) write_file("global_step0", os.path.join(output_base_path, "latest")) params = read_json(os.path.join(input_base_path, "params.json")) num_input_shards = NUM_SHARDS[model_size] num_layers = params["n_layers"] num_heads = params["n_heads"] if "n_kv_heads" in params: num_kv_heads = params["n_kv_heads"] else: num_kv_heads = num_heads num_kv_heads_per_input_shard = num_kv_heads // num_input_shards num_heads_per_input_shard = num_heads // num_input_shards num_heads_per_output_shard = num_heads // num_output_shards num_kv_heads_per_output_shard = num_kv_heads // num_output_shards hidden_size = params["dim"] dims_per_head = hidden_size // num_heads # base = 10000.0 # inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) def permute_rotary(w): if w.shape == (num_heads, dims_per_head, hidden_size): N_HEADS = num_heads elif w.shape == (num_kv_heads, dims_per_head, hidden_size): N_HEADS = num_kv_heads else: assert False return ( w.view(N_HEADS, dims_per_head // 2, 2, hidden_size) .transpose(1, 2) .reshape(N_HEADS, dims_per_head, hidden_size) ) pbar = tqdm.tqdm(total=num_input_shards + num_output_shards) pbar.set_description(f"Loading shard") loaded = [] for i in range(num_input_shards): loaded.append( torch.load( os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu", ) ) pbar.set_description(f"Loaded shard {i}/{num_input_shards}") pbar.update(1) helper = Helper( loaded=loaded, model_path=model_path, num_output_shards=num_output_shards, model_size=model_size, pipeline_parallel=False, ) # Embedding in embeddings_in = torch.cat( [ loaded[rank]["tok_embeddings.weight"].cpu() for rank in range(num_input_shards) ], dim=1, ) helper.add_sequential_shard( {"word_embeddings.weight": helper.shard(embeddings_in, dim=0)}, layer_i=0 ) helper.del_loaded("tok_embeddings.weight") # Norms helper.add_sequential_duplicates( {"norm.scale": loaded[0]["norm.weight"]}, layer_i=num_layers + 3 ) helper.del_loaded("norm.weight") # Embedding out embeddings_out = torch.cat( [loaded[rank]["output.weight"].cpu() for rank in range(num_input_shards)], dim=0 ) helper.add_sequential_shard( {"final_linear.weight": helper.shard(embeddings_out, dim=0)}, layer_i=num_layers + 4, ) helper.del_loaded("output.weight") # Layers for layer_i in range(num_layers): # Linear attn_wo = helper.shard( torch.cat( [ loaded[rank][f"layers.{layer_i}.attention.wo.weight"] for rank in range(num_input_shards) ], dim=1, ), dim=1, ) mlp_w1 = helper.shard( torch.cat( [ loaded[rank][f"layers.{layer_i}.feed_forward.w1.weight"] for rank in range(num_input_shards) ], dim=0, ), dim=0, ) mlp_w2 = helper.shard( torch.cat( [ loaded[rank][f"layers.{layer_i}.feed_forward.w2.weight"] for rank in range(num_input_shards) ], dim=1, ), dim=1, ) mlp_w3 = helper.shard( torch.cat( [ loaded[rank][f"layers.{layer_i}.feed_forward.w3.weight"] for rank in range(num_input_shards) ], dim=0, ), dim=0, ) helper.del_loaded(f"layers.{layer_i}.attention.wo.weight") helper.del_loaded(f"layers.{layer_i}.feed_forward.w1.weight") 
helper.del_loaded(f"layers.{layer_i}.feed_forward.w2.weight") helper.del_loaded(f"layers.{layer_i}.feed_forward.w3.weight") # Attention w_q = permute_rotary( torch.cat( [ loaded[rank][f"layers.{layer_i}.attention.wq.weight"].view( num_heads_per_input_shard, dims_per_head, hidden_size ) for rank in range(num_input_shards) ], dim=0, ) ) w_k = permute_rotary( torch.cat( [ loaded[rank][f"layers.{layer_i}.attention.wk.weight"].view( num_kv_heads_per_input_shard, dims_per_head, hidden_size ) for rank in range(num_input_shards) ], dim=0, ) ).view(num_heads, int(dims_per_head * (num_kv_heads / num_heads)), hidden_size) w_v = torch.cat( [ loaded[rank][f"layers.{layer_i}.attention.wv.weight"].view( num_kv_heads_per_input_shard, dims_per_head, hidden_size ) for rank in range(num_input_shards) ], dim=0, ).view(num_heads, int(dims_per_head * (num_kv_heads / num_heads)), hidden_size) sharded_qkv = torch.cat( [ helper.shard( w_q, dim=0 ), # num_output_shards, num_heads_per_output_shard, dims_per_head, hidden_size helper.shard(w_k, dim=0), helper.shard(w_v, dim=0), ], dim=2, ) # num_output_shards, num_heads_per_output_shard, QKV=3, dims_per_head, hidden_size sharded_qkv = sharded_qkv.view( num_output_shards, num_heads_per_output_shard * dims_per_head + 2 * num_kv_heads_per_output_shard * dims_per_head, hidden_size, ) helper.del_loaded(f"layers.{layer_i}.attention.wq.weight") helper.del_loaded(f"layers.{layer_i}.attention.wk.weight") helper.del_loaded(f"layers.{layer_i}.attention.wv.weight") # Duplicated input_layernorm = loaded[0][f"layers.{layer_i}.attention_norm.weight"] post_attention_layernorm = loaded[0][f"layers.{layer_i}.ffn_norm.weight"] helper.del_loaded(f"layers.{layer_i}.attention_norm.weight") helper.del_loaded(f"layers.{layer_i}.ffn_norm.weight") for out_rank in range(num_output_shards): helper.add_sequential( { "attention.query_key_value.weight": sharded_qkv[out_rank], # Sharded layers "attention.dense.weight": attn_wo[out_rank].clone(), "mlp.w1.weight": mlp_w1[out_rank].clone(), "mlp.w2.weight": mlp_w2[out_rank].clone(), "mlp.w3.weight": mlp_w3[out_rank].clone(), # Duplicated layers "input_layernorm.scale": input_layernorm, "post_attention_layernorm.scale": post_attention_layernorm, }, layer_i=layer_i + 2, rank=out_rank, ) for rank in range(num_output_shards): model_state = { "dp_world_size": 1, "mp_world_size": num_output_shards, "module": helper.sequential_cache[rank], "optimizer": {}, "global_steps": 1, "skipped_steps": 1, "iteration": 1, } torch.save( model_state, os.path.join(model_path, f"mp_rank_{rank:02d}_model_states.pt") ) pbar.set_description(f"Saved shard {rank}") pbar.update(1) pbar.set_description("Done.") class Helper: def __init__( self, loaded, model_size, num_output_shards, model_path, pipeline_parallel ): self.loaded = loaded self.model_size = model_size self.num_output_shards = num_output_shards self.model_path = model_path self.pipeline_parallel = pipeline_parallel self.sequential_cache = [{} for _ in range(num_output_shards)] def del_loaded(self, key: str): # Remove from memory as we go along for loaded_shared in self.loaded: del loaded_shared[key] def save_shards(self, dictionary, layer_i: int): for k, v in dictionary.items(): assert v.shape[0] == self.num_output_shards for rank in range(self.num_output_shards): torch.save( {k: v[rank].clone() for k, v in dictionary.items()}, self.save_path(layer_i=layer_i, rank=rank), ) def save_duplicates(self, dictionary, layer_i: int): for rank in range(self.num_output_shards): torch.save( {k: v.clone() for k, v in 
    def shard(self, x, dim):
        x_shape = list(x.shape)
        assert x_shape[dim] % self.num_output_shards == 0
        new_x_shape = (
            x_shape[:dim]
            + [self.num_output_shards, x_shape[dim] // self.num_output_shards]
            + x_shape[dim + 1 :]
        )
        x = x.view(*new_x_shape)
        return torch.movedim(x, 0, dim)

    def save_path(self, layer_i, rank):
        return os.path.join(
            self.model_path, f"layer_{layer_i:02d}-model_{rank:02d}-model_states.pt"
        )

    def add_sequential_shard(self, dictionary, layer_i):
        assert not self.pipeline_parallel
        for k, v in dictionary.items():
            for rank in range(self.num_output_shards):
                self.sequential_cache[rank][f"sequential.{layer_i}.{k}"] = v[
                    rank
                ].clone()

    def add_sequential_duplicates(self, dictionary, layer_i):
        assert not self.pipeline_parallel
        for k, v in dictionary.items():
            for rank in range(self.num_output_shards):
                self.sequential_cache[rank][f"sequential.{layer_i}.{k}"] = v.clone()

    def add_sequential(self, dictionary, layer_i, rank):
        assert not self.pipeline_parallel
        for k, v in dictionary.items():
            self.sequential_cache[rank][f"sequential.{layer_i}.{k}"] = v.clone()


def main():
    parser = argparse.ArgumentParser(
        description="Convert raw LLaMA or Mistral checkpoints to GPT-NeoX format."
    )
    parser.add_argument(
        "--input_dir",
        help="Location of parent directory, which contains tokenizer.model and model weights subfolders",
    )
    parser.add_argument(
        "--model_size",
        choices=[
            "7B",
            "mistral-7B-v0.1",
            "13B",
            "30B",
            "34B",
            "65B",
            "70B",
            "tokenizer_only",
        ],
    )
    parser.add_argument(
        "--output_dir",
        help="Location to write GPT-NeoX model",
    )
    parser.add_argument(
        "--num_output_shards",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--pipeline_parallel",
        action="store_true",
        help="Only use if PP>1",
    )
    args = parser.parse_args()
    if args.pipeline_parallel:
        print("parallel")
        convert_model_pipeline(
            output_base_path=args.output_dir,
            input_base_path=os.path.join(args.input_dir, args.model_size),
            model_size=args.model_size,
            num_output_shards=args.num_output_shards,
        )
    else:
        print("sequential")
        convert_model_sequential(
            output_base_path=args.output_dir,
            input_base_path=os.path.join(args.input_dir, args.model_size),
            model_size=args.model_size,
            num_output_shards=args.num_output_shards,
        )


if __name__ == "__main__":
    main()
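# Example invocation (a sketch; the script path and directories are
# illustrative). The --input_dir parent must contain a "<model_size>"
# subfolder holding params.json and the consolidated.*.pth shards:
#
#   python convert_raw_llama_weights_to_neox.py \
#       --input_dir /path/to/llama_raw \
#       --model_size 7B \
#       --output_dir /path/to/neox_llama_7B \
#       --num_output_shards 2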