Chaerin5 committed on
Commit
08b1d2f
·
1 Parent(s): e1163fb

enable zerogpu

Browse files
Files changed (1) hide show
  1. app.py +39 -41
app.py CHANGED
@@ -312,6 +312,7 @@ def get_ref_anno(ref):
312
  point_labels=input_label,
313
  multimask_output=False,
314
  )
 
315
  hand_mask = masks[0]
316
  masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
317
  ref_pose = visualize_hand(keypts, masked_img)
@@ -323,51 +324,48 @@ def get_ref_anno(ref):
323
 
324
  @spaces.GPU(duration=120)
325
  def make_ref_cond(
326
- img,
327
- keypts,
328
- hand_mask,
329
- device=device,
330
- target_size=(256, 256),
331
- latent_size=(32, 32),
332
  ):
333
- image_transform = Compose(
334
- [
335
- ToTensor(),
336
- Resize(target_size),
337
- Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
338
- ]
339
- )
340
- image = image_transform(img).to(device)
341
- kpts_valid = check_keypoints_validity(keypts, target_size)
342
- heatmaps = torch.tensor(
343
- keypoint_heatmap(
344
- scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
345
- )
346
- * kpts_valid[:, None, None],
347
- dtype=torch.float,
348
- device=device,
349
- )[None, ...]
350
- mask = torch.tensor(
351
- cv2.resize(
352
- hand_mask.astype(int),
353
- dsize=latent_size,
354
- interpolation=cv2.INTER_NEAREST,
355
- ),
356
- dtype=torch.float,
357
- device=device,
358
- ).unsqueeze(0)[None, ...]
359
  latent = opts.latent_scaling_factor * autoencoder.encode(image[None, ...]).sample()
360
- return image[None, ...], heatmaps, mask, latent
361
-
362
- image, heatmaps, mask, latent = make_ref_cond(
363
- img,
364
- keypts,
365
- hand_mask,
 
 
 
 
 
 
 
 
 
 
 
366
  device=device,
367
- target_size=opts.image_size,
368
- latent_size=opts.latent_size,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  )
370
-
 
371
  if not REF_POSE_MASK:
372
  heatmaps = torch.zeros_like(heatmaps)
373
  mask = torch.zeros_like(mask)
 
312
  point_labels=input_label,
313
  multimask_output=False,
314
  )
315
+ print("finished SAM")
316
  hand_mask = masks[0]
317
  masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
318
  ref_pose = visualize_hand(keypts, masked_img)
 
324
 
325
  @spaces.GPU(duration=120)
326
  def make_ref_cond(
327
+ image
 
 
 
 
 
328
  ):
329
+ print("ready to run autoencoder")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  latent = opts.latent_scaling_factor * autoencoder.encode(image[None, ...]).sample()
331
+ return image[None, ...], latent
332
+
333
+ image_transform = Compose(
334
+ [
335
+ ToTensor(),
336
+ Resize(opts.image_size),
337
+ Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
338
+ ]
339
+ )
340
+ image = image_transform(img).to(device)
341
+ kpts_valid = check_keypoints_validity(keypts, opts.image_size)
342
+ heatmaps = torch.tensor(
343
+ keypoint_heatmap(
344
+ scale_keypoint(keypts, opts.image_size, opts.latent_size), opts.latent_size, var=1.0
345
+ )
346
+ * kpts_valid[:, None, None],
347
+ dtype=torch.float,
348
  device=device,
349
+ )[None, ...]
350
+ mask = torch.tensor(
351
+ cv2.resize(
352
+ hand_mask.astype(int),
353
+ dsize=opts.latent_size,
354
+ interpolation=cv2.INTER_NEAREST,
355
+ ),
356
+ dtype=torch.float,
357
+ device=device,
358
+ ).unsqueeze(0)[None, ...]
359
+ image, latent = make_ref_cond(
360
+ image,
361
+ # keypts,
362
+ # hand_mask,
363
+ # device=device,
364
+ # target_size=opts.image_size,
365
+ # latent_size=opts.latent_size,
366
  )
367
+ print("finished autoencoder")
368
+
369
  if not REF_POSE_MASK:
370
  heatmaps = torch.zeros_like(heatmaps)
371
  mask = torch.zeros_like(mask)