"""Split a CFW-enabled VQGAN checkpoint into a VAE-only and a CFW-only file.

The "CFW" weights are the ``state_dict`` entries whose key contains
``decoder.fusion_layer``. They are moved out of the checkpoint's
``state_dict`` into a separate plain dict. Everything else in the loaded
checkpoint (other ``state_dict`` entries and any top-level entries) is saved
unchanged to the VAE-only file.
"""
import warnings

import torch

VAE_PATH = 'models/vqgan_cfw_00011.ckpt'
VAE_ONLY_PATH = 'models/vqgan_cfw_00011_vae_only.ckpt'
CFW_ONLY_PATH = 'models/vqgan_cfw_00011_cfw_only.ckpt'
CFW_KEY_PATTERN = 'decoder.fusion_layer'


def split_state_dict(state_dict, pattern=CFW_KEY_PATTERN, *, verbose=True):
    """Move every entry whose key contains *pattern* out of *state_dict*.

    Args:
        state_dict: Mapping to split. It is modified in place: matching
            entries are deleted from it.
        pattern: Substring that selects the entries to move. It may appear
            anywhere in the key.
        verbose: If true, print each matching key.

    Returns:
        A new dict with the removed entries, in their original order.
    """
    # Collect the keys first and delete afterwards: popping from a dict
    # while iterating over it raises RuntimeError.
    matched = [key for key in state_dict if pattern in key]
    if verbose:
        for key in matched:
            print(key)
    return {key: state_dict.pop(key) for key in matched}


def split_checkpoint(src_path=VAE_PATH, vae_out_path=VAE_ONLY_PATH,
                     cfw_out_path=CFW_ONLY_PATH, pattern=CFW_KEY_PATTERN):
    """Load *src_path*, split out the *pattern* weights, and save both parts.

    The checkpoint must be a dict with a ``'state_dict'`` entry; otherwise
    ``KeyError`` is raised.

    Args:
        src_path: Checkpoint to read.
        vae_out_path: Destination for the checkpoint with the matching
            entries removed. All other top-level entries are kept.
        cfw_out_path: Destination for a plain dict holding only the removed
            entries. It is not wrapped in a ``'state_dict'`` key.
        pattern: Substring that selects the entries to move.

    Returns:
        ``(vae_ckpt, cfw_weights)``: the pruned checkpoint and the removed
        entries, exactly as written to disk.
    """
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True.
    # If the checkpoint pickles non-tensor objects (e.g. hyper-parameter
    # containers), loading may fail there. Pass weights_only=False only for
    # trusted files, because it allows arbitrary unpickling.
    vae_ckpt = torch.load(src_path, map_location='cpu')
    cfw_weights = split_state_dict(vae_ckpt['state_dict'], pattern)
    if not cfw_weights:
        # Not fatal (the original script still wrote both files), but an
        # empty CFW file almost certainly means the wrong input or pattern.
        warnings.warn(
            f"no state_dict keys containing {pattern!r} found in {src_path!r}; "
            f"{cfw_out_path!r} will be empty",
            stacklevel=2,
        )
    torch.save(vae_ckpt, vae_out_path)
    torch.save(cfw_weights, cfw_out_path)
    return vae_ckpt, cfw_weights


if __name__ == '__main__':
    # Keep the original script's result names bound for interactive use.
    vae_ckpt, vae_cfw = split_checkpoint()