# Copyright (c) 2024, EleutherAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import torch
import yaml
import shutil
from tqdm import auto as tqdm_lib

VOCAB_SIZE = 50432
IGNORED_MODEL_STATE_KEYS = [
    "optimizer",
    "random_rng_state",
    "np_rng_state",
    "torch_rng_state",
    "cuda_rng_state",
    "rng_tracker_states",
]


def modify_config(input_config_path, output_config_path, output_dir):
    with open(input_config_path) as f:
        loaded_config = yaml.full_load(f)

    # replace model/pipeline parallel
    loaded_config["model_parallel_size"] = 1
    loaded_config["pipe_parallel_size"] = 1

    # replace load / save directories:
    loaded_config["load"] = output_dir
    loaded_config["save"] = output_dir

    # replace some other paths
    loaded_config["vocab_file"] = os.path.join(output_dir, "20B_tokenizer.json")
    loaded_config["log_dir"] = "./logs"

    # we need to make sure the resulting vocab size is correct
    # do this by modifying the 'make_vocab_size_divisible_by' argument to be
    # orig * (orig_mp / mp_out)
    loaded_config["make_vocab_size_divisible_by"] = VOCAB_SIZE

    # remove zero optimizer
    loaded_config["zero_optimization"]["stage"] = 0

    with open(output_config_path, "w") as f:
        yaml.dump(loaded_config, f)


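# NOTE: the per-rank model-state files (mp_rank_*_model_states.pt) bundle
# optimizer and RNG state alongside checkpoint bookkeeping such as "args".
# Only the bookkeeping is needed in the merged checkpoint, so
# modify_model_states() below drops the keys in IGNORED_MODEL_STATE_KEYS and
# rewrites the parallelism-related fields to 1.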
model_state["args"]["model_parallel_size"] = 1 model_state["args"]["make_vocab_size_divisible_by"] = VOCAB_SIZE torch.save(model_state, output_model_state_path) def merge_model_weights(input_checkpoint_path, output_checkpoint_path): pbar = tqdm_lib.tqdm(total=47) # Load transformer layers for layer_i in range(44): pbar.set_description(f"Merging layer {layer_i}") filename_tp1 = f"layer_{layer_i + 2:02d}-model_00-model_states.pt" filename_tp2 = f"layer_{layer_i + 2:02d}-model_01-model_states.pt" loaded_tp1 = torch.load(os.path.join(input_checkpoint_path, filename_tp1)) loaded_tp2 = torch.load(os.path.join(input_checkpoint_path, filename_tp2)) # noinspection PyDictCreation merged = {} # RowParallelLinear merged["mlp.dense_4h_to_h.weight"] = torch.cat( [ loaded_tp1["mlp.dense_4h_to_h.weight"], loaded_tp2["mlp.dense_4h_to_h.weight"], ], dim=1, ) merged["attention.dense.weight"] = torch.cat( [ loaded_tp1["attention.dense.weight"], loaded_tp2["attention.dense.weight"], ], dim=1, ) merged["mlp.dense_4h_to_h.bias"] = ( loaded_tp1["mlp.dense_4h_to_h.bias"] + loaded_tp2["mlp.dense_4h_to_h.bias"] ) merged["attention.dense.bias"] = ( loaded_tp1["attention.dense.bias"] + loaded_tp2["attention.dense.bias"] ) # Layer Norms merged["input_layernorm.weight"] = ( loaded_tp1["input_layernorm.weight"] + loaded_tp2["input_layernorm.weight"] ) / 2 merged["input_layernorm.bias"] = ( loaded_tp1["input_layernorm.bias"] + loaded_tp2["input_layernorm.bias"] ) / 2 merged["post_attention_layernorm.weight"] = ( loaded_tp1["post_attention_layernorm.weight"] + loaded_tp2["post_attention_layernorm.weight"] ) / 2 merged["post_attention_layernorm.bias"] = ( loaded_tp1["post_attention_layernorm.bias"] + loaded_tp2["post_attention_layernorm.bias"] ) / 2 # ColumnParallelLinear merged["mlp.dense_h_to_4h.weight"] = torch.cat( [ loaded_tp1["mlp.dense_h_to_4h.weight"], loaded_tp2["mlp.dense_h_to_4h.weight"], ], dim=0, ) merged["mlp.dense_h_to_4h.bias"] = torch.cat( [ loaded_tp1["mlp.dense_h_to_4h.bias"], loaded_tp2["mlp.dense_h_to_4h.bias"], ], dim=0, ) merged["attention.query_key_value.weight"] = torch.cat( [ loaded_tp1["attention.query_key_value.weight"], loaded_tp2["attention.query_key_value.weight"], ], dim=0, ) merged["attention.query_key_value.bias"] = torch.cat( [ loaded_tp1["attention.query_key_value.bias"], loaded_tp2["attention.query_key_value.bias"], ], dim=0, ) # Just take one merged["attention.rotary_emb.inv_freq"] = loaded_tp1[ "attention.rotary_emb.inv_freq" ] torch.save(merged, os.path.join(output_checkpoint_path, filename_tp1)) del loaded_tp1 del loaded_tp2 pbar.update(1) # Load input embedding pbar.set_description(f"Merging input embedding") loaded_tp1 = torch.load( os.path.join(input_checkpoint_path, "layer_00-model_00-model_states.pt") ) loaded_tp2 = torch.load( os.path.join(input_checkpoint_path, "layer_00-model_01-model_states.pt") ) merged = { "word_embeddings.weight": torch.cat( [ loaded_tp1["word_embeddings.weight"], loaded_tp2["word_embeddings.weight"], ], dim=0, ) } torch.save( merged, os.path.join(output_checkpoint_path, "layer_00-model_00-model_states.pt"), ) del loaded_tp1 del loaded_tp2 pbar.update(1) # Load final layer norm pbar.set_description(f"Merging final layer norm") loaded_tp1 = torch.load( os.path.join(input_checkpoint_path, "layer_47-model_00-model_states.pt") ) loaded_tp2 = torch.load( os.path.join(input_checkpoint_path, "layer_47-model_01-model_states.pt") ) merged = { "norm.weight": (loaded_tp1["norm.weight"] + loaded_tp2["norm.weight"]) / 2, "norm.bias": (loaded_tp1["norm.bias"] + 
loaded_tp2["norm.bias"]) / 2, } torch.save( merged, os.path.join(output_checkpoint_path, "layer_47-model_00-model_states.pt"), ) del loaded_tp1 del loaded_tp2 pbar.update(1) # Load output embedding pbar.set_description(f"Merging output embedding") loaded_tp1 = torch.load( os.path.join(input_checkpoint_path, "layer_48-model_00-model_states.pt") ) loaded_tp2 = torch.load( os.path.join(input_checkpoint_path, "layer_48-model_01-model_states.pt") ) merged = { "final_linear.weight": torch.cat( [ loaded_tp1["final_linear.weight"], loaded_tp2["final_linear.weight"], ], dim=0, ), } torch.save( merged, os.path.join(output_checkpoint_path, "layer_48-model_00-model_states.pt"), ) del loaded_tp1 del loaded_tp2 pbar.update(1) pbar.set_description("Done.") def merge(input_dir, output_dir): input_checkpoint_path = os.path.join(input_dir, "global_step150000") output_checkpoint_path = os.path.join(output_dir, "global_step150000") os.makedirs(output_checkpoint_path, exist_ok=True) os.makedirs(os.path.join(output_dir, "configs"), exist_ok=True) for i in range(8): modify_model_states( input_model_state_path=os.path.join( input_checkpoint_path, f"mp_rank_{i:02d}_model_states.pt" ), output_model_state_path=os.path.join( output_checkpoint_path, f"mp_rank_{i:02d}_model_states.pt" ), ) modify_config( input_config_path=os.path.join(input_dir, "configs", "20B.yml"), output_config_path=os.path.join(output_dir, "configs", "20B.yml"), output_dir=output_dir, ) merge_model_weights( input_checkpoint_path=input_checkpoint_path, output_checkpoint_path=output_checkpoint_path, ) shutil.copyfile( os.path.join(input_dir, "20B_tokenizer.json"), os.path.join(output_dir, "20B_tokenizer.json"), ) with open(os.path.join(output_dir, "latest"), "w") as f: f.write("global_step150000") def main(): parser = argparse.ArgumentParser(description="Merge 20B checkpoint.") parser.add_argument( "--input_dir", type=str, help='Checkpoint dir, which should contain (e.g. a folder named "global_step150000")', ) parser.add_argument( "--output_dir", type=str, help="Output dir, to save the 1-GPU weights configs" ) args = parser.parse_args() merge(args.input_dir, args.output_dir) if __name__ == "__main__": main()