File size: 514 Bytes
34097e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import torch
# Split one checkpoint into two files: every 'state_dict' entry whose key
# contains 'decoder.fusion_layer' is moved into its own file, and the
# checkpoint with those entries removed is saved separately.
vae_path = 'models/vqgan_cfw_00011.ckpt'

# NOTE: torch.load unpickles the file, so only run this on a checkpoint from
# a trusted source. map_location='cpu' keeps all tensors on the CPU.
with open(vae_path, 'rb') as f:
    vae_ckpt = torch.load(f, map_location='cpu')

state_dict = vae_ckpt['state_dict']

# Collect the matching keys up front so the dict is never mutated while it
# is being iterated; only the keys are needed for the match.
prune_keys = [k for k in state_dict if 'decoder.fusion_layer' in k]
for k in prune_keys:
    print(k)

# pop() removes each selected entry from the checkpoint and hands back its
# value in one step, so every entry ends up in exactly one of the two files.
vae_cfw = {k: state_dict.pop(k) for k in prune_keys}

# The remaining checkpoint keeps its original top-level structure; the
# extracted entries are saved as a flat {key: value} dict.
torch.save(vae_ckpt, 'models/vqgan_cfw_00011_vae_only.ckpt')
torch.save(vae_cfw, 'models/vqgan_cfw_00011_cfw_only.ckpt')
|