wwen1997 commited on
Commit
5b24944
1 Parent(s): 5e58ae1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -362,6 +362,7 @@ def validate_and_convert_image(image, target_size=(512 , 512)):
362
 
363
  class Drag:
364
 
 
365
  def __init__(self, device, args, height, width, model_length, dtype=torch.float16, use_sift=False):
366
  self.device = device
367
  self.dtype = dtype
@@ -379,14 +380,14 @@ class Drag:
379
  )
380
  controlnet = controlnet.to(device, dtype)
381
 
382
- # if is_xformers_available():
383
- # import xformers
384
- # xformers_version = version.parse(xformers.__version__)
385
- # unet.enable_xformers_memory_efficient_attention()
386
- # # controlnet.enable_xformers_memory_efficient_attention()
387
- # else:
388
- # raise ValueError(
389
- # "xformers is not available. Make sure it is installed correctly")
390
 
391
  pipe = StableVideoDiffusionInterpControlPipeline.from_pretrained(
392
  "checkpoints/stable-video-diffusion-img2vid-xt",
@@ -406,6 +407,7 @@ class Drag:
406
  self.model_length = model_length
407
  self.use_sift = use_sift
408
 
 
409
  def run(self, first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id):
410
  original_width, original_height = 512, 320 # TODO
411
 
 
362
 
363
  class Drag:
364
 
365
+ @spaces.GPU
366
  def __init__(self, device, args, height, width, model_length, dtype=torch.float16, use_sift=False):
367
  self.device = device
368
  self.dtype = dtype
 
380
  )
381
  controlnet = controlnet.to(device, dtype)
382
 
383
+ if is_xformers_available():
384
+ import xformers
385
+ xformers_version = version.parse(xformers.__version__)
386
+ unet.enable_xformers_memory_efficient_attention()
387
+ # controlnet.enable_xformers_memory_efficient_attention()
388
+ else:
389
+ raise ValueError(
390
+ "xformers is not available. Make sure it is installed correctly")
391
 
392
  pipe = StableVideoDiffusionInterpControlPipeline.from_pretrained(
393
  "checkpoints/stable-video-diffusion-img2vid-xt",
 
407
  self.model_length = model_length
408
  self.use_sift = use_sift
409
 
410
+ @spaces.GPU
411
  def run(self, first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id):
412
  original_width, original_height = 512, 320 # TODO
413