import torch

# Combined VQGAN checkpoint that also contains the CFW (decoder fusion-layer) weights.
vae_path = 'models/vqgan_cfw_00011.ckpt'

with open(vae_path, 'rb') as f:
    vae_ckpt = torch.load(f, map_location='cpu')

# Collect the keys of all decoder fusion-layer (CFW) weights and log them.
prune_keys = []
for k in vae_ckpt['state_dict']:
    if 'decoder.fusion_layer' in k:
        prune_keys.append(k)
        print(k)

# Move the fusion-layer weights into their own dict, leaving a VAE-only state_dict behind.
vae_cfw = {}
for k in prune_keys:
    vae_cfw[k] = vae_ckpt['state_dict'][k]
    del vae_ckpt['state_dict'][k]

# Save the pruned checkpoint (VAE only) and the extracted CFW weights separately.
torch.save(vae_ckpt, 'models/vqgan_cfw_00011_vae_only.ckpt')
torch.save(vae_cfw, 'models/vqgan_cfw_00011_cfw_only.ckpt')
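
# Optional sanity check (a minimal sketch, not part of the original split step):
# reload both outputs and confirm the fusion-layer keys ended up only in the
# CFW-only file, while the VAE-only checkpoint keeps the rest of the state_dict.
vae_only = torch.load('models/vqgan_cfw_00011_vae_only.ckpt', map_location='cpu')
cfw_only = torch.load('models/vqgan_cfw_00011_cfw_only.ckpt', map_location='cpu')
assert not any('decoder.fusion_layer' in k for k in vae_only['state_dict'])
assert all('decoder.fusion_layer' in k for k in cfw_only)
print(f"VAE-only keys: {len(vae_only['state_dict'])}, CFW-only keys: {len(cfw_only)}")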