wwen1997 commited on
Commit
771c278
1 Parent(s): 15c0212

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -352,7 +352,6 @@ def validate_and_convert_image(image, target_size=(512 , 512)):
352
 
353
  class Drag:
354
 
355
- @spaces.GPU
356
  def __init__(self, device, args, height, width, model_length, dtype=torch.float16, use_sift=False):
357
  self.device = device
358
  self.dtype = dtype
@@ -363,12 +362,12 @@ class Drag:
363
  low_cpu_mem_usage=True,
364
  custom_resume=True,
365
  )
366
- unet = unet.to(device, dtype)
367
 
368
  controlnet = ControlNetSVDModel.from_pretrained(
369
  os.path.join(args.model, "controlnet"),
370
  )
371
- controlnet = controlnet.to(device, dtype)
372
 
373
  if is_xformers_available():
374
  import xformers
@@ -386,7 +385,6 @@ class Drag:
386
  low_cpu_mem_usage=False,
387
  torch_dtype=torch.float16, variant="fp16", local_files_only=True,
388
  )
389
- pipe.to(device)
390
 
391
  self.pipeline = pipe
392
  # self.pipeline.enable_model_cpu_offload()
@@ -399,6 +397,9 @@ class Drag:
399
 
400
  @spaces.GPU
401
  def run(self, first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id):
 
 
 
402
  original_width, original_height = 512, 320 # TODO
403
 
404
  # load_image
 
352
 
353
  class Drag:
354
 
 
355
  def __init__(self, device, args, height, width, model_length, dtype=torch.float16, use_sift=False):
356
  self.device = device
357
  self.dtype = dtype
 
362
  low_cpu_mem_usage=True,
363
  custom_resume=True,
364
  )
365
+ unet = unet.to(dtype)
366
 
367
  controlnet = ControlNetSVDModel.from_pretrained(
368
  os.path.join(args.model, "controlnet"),
369
  )
370
+ controlnet = controlnet.to(dtype)
371
 
372
  if is_xformers_available():
373
  import xformers
 
385
  low_cpu_mem_usage=False,
386
  torch_dtype=torch.float16, variant="fp16", local_files_only=True,
387
  )
 
388
 
389
  self.pipeline = pipe
390
  # self.pipeline.enable_model_cpu_offload()
 
397
 
398
  @spaces.GPU
399
  def run(self, first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id):
400
+
401
+ self.pipeline.to(self.device)
402
+
403
  original_width, original_height = 512, 320 # TODO
404
 
405
  # load_image