Chaerin5 commited on
Commit
385c0f2
·
1 Parent(s): 6097648

fix vae nan bug

Browse files
Files changed (1) hide show
  1. app.py +15 -17
app.py CHANGED
@@ -217,7 +217,21 @@ if NEW_MODEL:
217
  model.eval()
218
  print(missing_keys, extra_keys)
219
  assert len(missing_keys) == 0
220
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  # else:
222
  # opts = HandDiffOpts()
223
  # model_path = './finetune_epoch=5-step=130000.ckpt'
@@ -261,24 +275,8 @@ def get_ref_anno(ref):
261
  None,
262
  None,
263
  )
264
-
265
- vae_state_dict = torch.load(vae_path, map_location='cpu')['state_dict']
266
- print(f"vae_state_dict encoder dtype: {vae_state_dict['encoder.conv_in.weight'].dtype}")
267
- autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False)
268
- print(f"autoencoder encoder dtype: {next(autoencoder.encoder.parameters()).dtype}")
269
- print(f"encoder before load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
270
- print(f"encoder before load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
271
  missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
272
- print(f"encoder after load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
273
- print(f"encoder after load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
274
- autoencoder = autoencoder.to(device)
275
- autoencoder.eval()
276
- print(f"encoder after eval() min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
277
- print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
278
- print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}")
279
- assert len(missing_keys) == 0
280
 
281
-
282
  img = ref["composite"][..., :3]
283
  img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
284
  keypts = np.zeros((42, 2))
 
217
  model.eval()
218
  print(missing_keys, extra_keys)
219
  assert len(missing_keys) == 0
220
+ vae_state_dict = torch.load(vae_path, map_location='cpu')['state_dict']
221
+ print(f"vae_state_dict encoder dtype: {vae_state_dict['encoder.conv_in.weight'].dtype}")
222
+ autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False)
223
+ print(f"autoencoder encoder dtype: {next(autoencoder.encoder.parameters()).dtype}")
224
+ print(f"encoder before load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
225
+ print(f"encoder before load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
226
+ missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
227
+ print(f"encoder after load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
228
+ print(f"encoder after load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
229
+ autoencoder = autoencoder.to(device)
230
+ autoencoder.eval()
231
+ print(f"encoder after eval() min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
232
+ print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
233
+ print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}")
234
+ assert len(missing_keys) == 0
235
  # else:
236
  # opts = HandDiffOpts()
237
  # model_path = './finetune_epoch=5-step=130000.ckpt'
 
275
  None,
276
  None,
277
  )
 
 
 
 
 
 
 
278
  missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
 
 
 
 
 
 
 
 
279
 
 
280
  img = ref["composite"][..., :3]
281
  img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
282
  keypts = np.zeros((42, 2))