Chaerin5 committed
Commit 6097648 · 1 Parent(s): 32fa016

fix vae nan bug

Files changed (1)
  1. app.py +19 -136
app.py CHANGED
@@ -217,21 +217,7 @@ if NEW_MODEL:
     model.eval()
     print(missing_keys, extra_keys)
     assert len(missing_keys) == 0
-    vae_state_dict = torch.load(vae_path, map_location='cpu')['state_dict']
-    print(f"vae_state_dict encoder dtype: {vae_state_dict['encoder.conv_in.weight'].dtype}")
-    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False)
-    print(f"autoencoder encoder dtype: {next(autoencoder.encoder.parameters()).dtype}")
-    print(f"encoder before load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
-    print(f"encoder before load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
-    missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
-    print(f"encoder after load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
-    print(f"encoder after load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
-    autoencoder = autoencoder.to(device)
-    autoencoder.eval()
-    print(f"encoder after eval() min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
-    print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
-    print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}")
-    assert len(missing_keys) == 0
+
 # else:
 # opts = HandDiffOpts()
 # model_path = './finetune_epoch=5-step=130000.ckpt'
@@ -266,127 +252,6 @@ hands = mp_hands.Hands(
     min_detection_confidence=0.1,
 )
 
-# def make_ref_cond(
-# image
-# ):
-# print("ready to run autoencoder")
-# # print(f"image.device: {image.device}, type(image): {type(image)}")
-# # image = image.to("cuda")
-# print(f"autoencoder device: {next(autoencoder.parameters()).device}")
-# latent = opts.latent_scaling_factor * autoencoder.encode(image[None, ...]).sample()
-# return image[None, ...], latent
-
-
-# def get_ref_anno(ref):
-# print("inside get_ref_anno")
-# if ref is None:
-# return (
-# None,
-# None,
-# None,
-# None,
-# None,
-# )
-# img = ref["composite"][..., :3]
-# img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
-# keypts = np.zeros((42, 2))
-# print("ready to run mediapipe")
-# if REF_POSE_MASK:
-# print(f"type(img): {type(img)}, img.shape: {img.shape}, img.dtype: {img.dtype}")
-# mp_pose = hands.process(img)
-# print("processed mediapipe")
-# detected = np.array([0, 0])
-# start_idx = 0
-# if mp_pose.multi_hand_landmarks:
-# # handedness is flipped assuming the input image is mirrored in MediaPipe
-# for hand_landmarks, handedness in zip(
-# mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
-# ):
-# # actually right hand
-# if handedness.classification[0].label == "Left":
-# start_idx = 0
-# detected[0] = 1
-# # actually left hand
-# elif handedness.classification[0].label == "Right":
-# start_idx = 21
-# detected[1] = 1
-# for i, landmark in enumerate(hand_landmarks.landmark):
-# keypts[start_idx + i] = [
-# landmark.x * opts.image_size[1],
-# landmark.y * opts.image_size[0],
-# ]
-
-# sam_predictor.set_image(img)
-# l = keypts[:21].shape[0]
-# if keypts[0].sum() != 0 and keypts[21].sum() != 0:
-# input_point = np.array([keypts[0], keypts[21]])
-# input_label = np.array([1, 1])
-# elif keypts[0].sum() != 0:
-# input_point = np.array(keypts[:1])
-# input_label = np.array([1])
-# elif keypts[21].sum() != 0:
-# input_point = np.array(keypts[21:22])
-# input_label = np.array([1])
-# print("ready to run SAM")
-# masks, _, _ = sam_predictor.predict(
-# point_coords=input_point,
-# point_labels=input_label,
-# multimask_output=False,
-# )
-# print("finished SAM")
-# hand_mask = masks[0]
-# masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
-# ref_pose = visualize_hand(keypts, masked_img)
-# else:
-# raise gr.Error("No hands detected in the reference image.")
-# else:
-# hand_mask = np.zeros_like(img[:,:, 0])
-# ref_pose = np.zeros_like(img)
-
-# image_transform = Compose(
-# [
-# ToTensor(),
-# Resize(opts.image_size),
-# Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
-# ]
-# )
-# image = image_transform(img)
-# kpts_valid = check_keypoints_validity(keypts, opts.image_size)
-# heatmaps = torch.tensor(
-# keypoint_heatmap(
-# scale_keypoint(keypts, opts.image_size, opts.latent_size), opts.latent_size, var=1.0
-# )
-# * kpts_valid[:, None, None],
-# dtype=torch.float,
-# # device=device,
-# )[None, ...]
-# mask = torch.tensor(
-# cv2.resize(
-# hand_mask.astype(int),
-# dsize=opts.latent_size,
-# interpolation=cv2.INTER_NEAREST,
-# ),
-# dtype=torch.float,
-# # device=device,
-# ).unsqueeze(0)[None, ...]
-# image, latent = make_ref_cond(
-# image,
-# # keypts,
-# # hand_mask,
-# # device=device,
-# # target_size=opts.image_size,
-# # latent_size=opts.latent_size,
-# )
-# print("finished autoencoder")
-
-# if not REF_POSE_MASK:
-# heatmaps = torch.zeros_like(heatmaps)
-# mask = torch.zeros_like(mask)
-# ref_cond = torch.cat([latent, heatmaps, mask], 1)
-
-# return img, ref_pose, ref_cond
-
-
 def get_ref_anno(ref):
     if ref is None:
         return (
@@ -396,6 +261,24 @@ def get_ref_anno(ref):
             None,
             None,
         )
+
+    vae_state_dict = torch.load(vae_path, map_location='cpu')['state_dict']
+    print(f"vae_state_dict encoder dtype: {vae_state_dict['encoder.conv_in.weight'].dtype}")
+    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False)
+    print(f"autoencoder encoder dtype: {next(autoencoder.encoder.parameters()).dtype}")
+    print(f"encoder before load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
+    print(f"encoder before load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
+    missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
+    print(f"encoder after load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
+    print(f"encoder after load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
+    autoencoder = autoencoder.to(device)
+    autoencoder.eval()
+    print(f"encoder after eval() min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
+    print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
+    print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}")
+    assert len(missing_keys) == 0
+
+
     img = ref["composite"][..., :3]
     img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
    keypts = np.zeros((42, 2))
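
Note (not part of this commit): a minimal sketch of the kind of NaN sanity check the "vae nan bug" hunt implies, assuming the `autoencoder`, `opts.image_size`, `opts.latent_scaling_factor`, and `device` names used in app.py; `check_vae_finite` is a hypothetical helper, and the `encode(...).sample()` call pattern follows the removed commented-out code.

import torch

def check_vae_finite(autoencoder, opts, device):
    # Hypothetical helper (not in app.py): fail fast if the reloaded VAE
    # carries non-finite weights or produces non-finite latents.
    for name, p in autoencoder.named_parameters():
        assert torch.isfinite(p).all(), f"non-finite weights in {name}"
    # Probe a dummy encode at the configured image size (1, 3, H, W).
    dummy = torch.zeros(1, 3, *opts.image_size, device=device)
    with torch.no_grad():
        latent = opts.latent_scaling_factor * autoencoder.encode(dummy).sample()
    assert torch.isfinite(latent).all(), "VAE encode produced NaN/Inf latents"

Calling such a check right after the `load_state_dict` / `.to(device)` sequence would localize where non-finite values first appear.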