Spaces:
Runtime error
Runtime error
import torch, os, argparse | |
from safetensors.torch import save_file | |
def load_pl_state_dict(file_path):
    """Load a checkpoint file and return only its trainable parameters.

    The file is read with ``torch.load`` onto the CPU. The set of names to keep
    comes from the top-level ``"trainable_param_names"`` entry. The weights are
    then unwrapped from a ``"module"`` entry (if present) and afterwards from a
    ``"state_dict"`` entry (if present). Each key has a leading
    ``"_forward_module."`` and then a leading ``"unet."`` stripped, and is kept
    only if the stripped name is listed as trainable.

    Args:
        file_path: Path of the checkpoint file to load.

    Returns:
        dict mapping stripped parameter names to the loaded values.

    Raises:
        KeyError: if the file has no top-level ``"trainable_param_names"``.
    """
    print(f"loading {file_path}")
    # NOTE(review): torch.load unpickles the file, so only use it on trusted
    # checkpoints; newer PyTorch releases changed the ``weights_only`` default,
    # so confirm this still loads on the installed version.
    checkpoint = torch.load(file_path, map_location="cpu")
    keep = set(checkpoint["trainable_param_names"])

    # Peel optional wrapper levels in a fixed order: "module" first, then "state_dict".
    weights = checkpoint
    for wrapper_key in ("module", "state_dict"):
        if wrapper_key in weights:
            weights = weights[wrapper_key]

    result = {}
    for raw_name, value in weights.items():
        name = raw_name
        # Strip the prefixes one after another, in this order.
        for prefix in ("_forward_module.", "unet."):
            if name.startswith(prefix):
                name = name[len(prefix):]
        if name in keep:
            result[name] = value
    return result
def ckpt_to_epochs(ckpt_name):
    """Extract the integer epoch number from a checkpoint name.

    Takes the text between the first and second ``=`` and keeps only the part
    before the first ``-``; e.g. ``"epoch=3-step=1000.ckpt"`` -> ``3``.

    Raises:
        IndexError: if the name contains no ``=``.
        ValueError: if the extracted text is not an integer.
    """
    after_equals = ckpt_name.split("=")[1]
    epoch_text, _, _ = after_equals.partition("-")
    return int(epoch_text)
def parse_args(argv=None):
    """Parse command-line options for the EMA merge script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``output_path`` (str) and ``gamma`` (float).

    Exits:
        Via ``parser.error`` (SystemExit, status 2) when ``--gamma`` is not
        within [0, 1].
    """
    parser = argparse.ArgumentParser(
        description="Merge sequential checkpoints into a single EMA safetensors file."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./",
        help="Folder containing the checkpoint directories; the EMA file is saved here too.",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.9,
        help="EMA decay in [0, 1]: ema = ema * gamma + new * (1 - gamma).",
    )
    args = parser.parse_args(argv)
    # Weights outside [0, 1] (or NaN) would not produce a meaningful average.
    if not 0.0 <= args.gamma <= 1.0:
        parser.error(f"--gamma must be in [0, 1], got {args.gamma}")
    return args
if __name__ == '__main__':
    # args
    args = parse_args()
    folder = args.output_path
    gamma = args.gamma

    # Collect checkpoint directories as (epoch, name) pairs, oldest first.
    # Subdirectories whose names carry no parsable "=<epoch>" field (e.g. log
    # folders) are skipped instead of aborting the whole run.
    ckpt_list = []
    for ckpt_name in os.listdir(folder):
        if not os.path.isdir(os.path.join(folder, ckpt_name)):
            continue
        try:
            epochs = ckpt_to_epochs(ckpt_name)
        except (IndexError, ValueError):
            print(f"skip {ckpt_name}: cannot parse epoch from name")
            continue
        ckpt_list.append((epochs, ckpt_name))
    ckpt_list.sort()
    if not ckpt_list:
        # Without this guard the save step below would hit an undefined name.
        raise SystemExit(f"no checkpoint directories found in {folder}")

    # EMA: ema = ema * gamma + new * (1 - gamma), accumulated in float32 and
    # applied from the oldest checkpoint to the newest.
    # Assumes every checkpoint has the same trainable parameter names; a name
    # absent from the first checkpoint raises KeyError.
    state_dict_ema = None
    for epochs, ckpt_name in ckpt_list:
        state_dict = load_pl_state_dict(
            os.path.join(folder, ckpt_name, "checkpoint", "mp_rank_00_model_states.pt")
        )
        if state_dict_ema is None:
            state_dict_ema = {name: param.float() for name, param in state_dict.items()}
        else:
            for name, param in state_dict.items():
                state_dict_ema[name] = state_dict_ema[name] * gamma + param.float() * (1 - gamma)

    # Name the output after the newest checkpoint. Strip a trailing ".ckpt" if
    # present; always append the suffix so the target never coincides with the
    # checkpoint directory itself.
    last_name = ckpt_list[-1][1]
    stem = last_name[: -len(".ckpt")] if last_name.endswith(".ckpt") else last_name
    save_path = os.path.join(folder, f"{stem}-ema.safetensors")
    print(f"save to {save_path}")
    save_file(state_dict_ema, save_path)