wchai commited on
Commit
d515943
·
1 Parent(s): 3f6c15a

convert to CPU

Browse files
annotator/midas/__init__.py CHANGED
@@ -8,13 +8,13 @@ from .api import MiDaSInference
8
 
9
  class MidasDetector:
10
  def __init__(self):
11
- self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
12
 
13
  def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
14
  assert input_image.ndim == 3
15
  image_depth = input_image
16
  with torch.no_grad():
17
- image_depth = torch.from_numpy(image_depth).float().cuda()
18
  image_depth = image_depth / 127.5 - 1.0
19
  image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
20
  depth = self.model(image_depth)[0]
 
8
 
9
  class MidasDetector:
10
  def __init__(self):
11
+ self.model = MiDaSInference(model_type="dpt_hybrid")
12
 
13
  def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
14
  assert input_image.ndim == 3
15
  image_depth = input_image
16
  with torch.no_grad():
17
+ image_depth = torch.from_numpy(image_depth).float()
18
  image_depth = image_depth / 127.5 - 1.0
19
  image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
20
  depth = self.model(image_depth)[0]
app.py CHANGED
@@ -48,7 +48,7 @@ class StableVideo:
48
  ):
49
  self.apply_canny = CannyDetector()
50
  canny_model = create_model(base_cfg).cpu()
51
- canny_model.load_state_dict(load_state_dict(canny_model_cfg, location='cuda'), strict=False)
52
  self.canny_ddim_sampler = DDIMSampler(canny_model)
53
  self.canny_model = canny_model
54
 
@@ -59,7 +59,7 @@ class StableVideo:
59
  ):
60
  self.apply_midas = MidasDetector()
61
  depth_model = create_model(base_cfg).cpu()
62
- depth_model.load_state_dict(load_state_dict(depth_model_cfg, location='cuda'), strict=False)
63
  self.depth_ddim_sampler = DDIMSampler(depth_model)
64
  self.depth_model = depth_model
65
 
@@ -101,7 +101,7 @@ class StableVideo:
101
 
102
  detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
103
 
104
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
105
  control = torch.stack([control for _ in range(1)], dim=0)
106
  control = einops.rearrange(control, 'b h w c -> b c h w').clone()
107
 
@@ -128,7 +128,7 @@ class StableVideo:
128
 
129
  @torch.no_grad()
130
  def edit_background(self, *args, **kwargs):
131
- self.depth_model = self.depth_model.cuda()
132
 
133
  input_image = self.b_atlas_origin
134
  self.depth_edit(input_image, *args, **kwargs)
@@ -155,7 +155,7 @@ class StableVideo:
155
  if_net=False,
156
  num_samples=1):
157
 
158
- self.canny_model = self.canny_model.cuda()
159
 
160
  keyframes = [int(x) for x in keyframes.split(",")]
161
  if self.data is None:
@@ -186,7 +186,7 @@ class StableVideo:
186
  # get canny control
187
  detected_map = self.apply_canny(img, low_threshold, high_threshold)
188
  detected_map = HWC3(detected_map)
189
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
190
  control = einops.rearrange(control.unsqueeze(0), 'b h w c -> b c h w').clone()
191
 
192
  cond = {"c_concat": [control], "c_crossattn": c_crossattn}
@@ -195,7 +195,7 @@ class StableVideo:
195
 
196
  # if not the key frame, calculate the mapping from last atlas
197
  if i == 0:
198
- latent = torch.randn((1, 4, H // 8, W // 8)).cuda()
199
  samples, _ = self.canny_ddim_sampler.sample(ddim_steps, num_samples,
200
  shape, cond, verbose=False, eta=eta,
201
  unconditional_guidance_scale=scale,
@@ -209,7 +209,7 @@ class StableVideo:
209
  mapped_img = mapped_img.resize((W, H))
210
  mapped_img = np.array(mapped_img).astype(np.float32) / 255.0
211
  mapped_img = mapped_img[None].transpose(0, 3, 1, 2)
212
- mapped_img = torch.from_numpy(mapped_img).cuda()
213
  mapped_img = 2. * mapped_img - 1.
214
  latent = self.canny_model.get_first_stage_encoding(self.canny_model.encode_first_stage(mapped_img))
215
 
@@ -232,7 +232,7 @@ class StableVideo:
232
  result = alpha * result
233
 
234
  # buffer for training
235
- result_copy = result.clone().cuda()
236
  result_copy.requires_grad = True
237
  result_list.append(result_copy)
238
 
@@ -259,7 +259,7 @@ class StableVideo:
259
  # aggregate net #
260
  #####################################
261
  lr, n_epoch = 1e-3, 500
262
- agg_net = AGGNet().cuda()
263
  loss_fn = nn.L1Loss()
264
  optimizer = optim.SGD(agg_net.parameters(), lr=lr, momentum=0.9)
265
  for _ in range(n_epoch):
@@ -291,12 +291,12 @@ class StableVideo:
291
  def render(self, f_atlas, b_atlas):
292
  # foreground
293
  if f_atlas == None:
294
- f_atlas = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0).cuda()
295
  else:
296
  f_atlas, mask = f_atlas["image"], f_atlas["mask"]
297
- f_atlas_origin = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0).cuda()
298
- f_atlas = transforms.ToTensor()(f_atlas).unsqueeze(0).cuda()
299
- mask = transforms.ToTensor()(mask).unsqueeze(0).cuda()
300
  if f_atlas.shape != mask.shape:
301
  print("Warning: truncating mask to atlas shape {}".format(f_atlas.shape))
302
  mask = mask[:f_atlas.shape[0], :f_atlas.shape[1], :f_atlas.shape[2], :f_atlas.shape[3]]
@@ -326,7 +326,7 @@ class StableVideo:
326
  if b_atlas == None:
327
  b_atlas = self.b_atlas_origin
328
 
329
- b_atlas = transforms.ToTensor()(b_atlas).unsqueeze(0).cuda()
330
  background_edit = F.grid_sample(
331
  b_atlas, self.data.scaled_background_uvs, mode="bilinear", align_corners=self.data.config["align_corners"]
332
  ).clamp(min=0.0, max=1.0)
@@ -349,99 +349,98 @@ class StableVideo:
349
  return save_name
350
 
351
  if __name__ == '__main__':
352
- with torch.cuda.amp.autocast():
353
- stablevideo = StableVideo(base_cfg="ckpt/cldm_v15.yaml",
354
- canny_model_cfg="ckpt/control_sd15_canny.pth",
355
- depth_model_cfg="ckpt/control_sd15_depth.pth",
356
- save_memory=True)
357
- stablevideo.load_canny_model()
358
- stablevideo.load_depth_model()
359
-
360
- block = gr.Blocks().queue()
361
- with block:
362
- with gr.Row():
363
- gr.Markdown("## StableVideo")
364
- with gr.Row():
365
- with gr.Column():
366
- original_video = gr.Video(label="Original Video", interactive=False)
367
- with gr.Row():
368
- foreground_atlas = gr.Image(label="Foreground Atlas", type="pil")
369
- background_atlas = gr.Image(label="Background Atlas", type="pil")
370
- gr.Markdown("### Step 1. select one example video and click **Load Video** buttom and wait for 10 sec.")
371
- avail_video = [f.name for f in os.scandir("data") if f.is_dir()]
372
- video_name = gr.Radio(choices=avail_video,
373
- label="Select Example Videos",
374
- value="car-turn")
375
- load_video_button = gr.Button("Load Video")
376
- gr.Markdown("### Step 2. write text prompt and advanced options for background and foreground.")
377
- with gr.Row():
378
- f_prompt = gr.Textbox(label="Foreground Prompt", value="a picture of an orange suv")
379
- b_prompt = gr.Textbox(label="Background Prompt", value="winter scene, snowy scene, beautiful snow")
380
- with gr.Row():
381
- with gr.Accordion("Advanced Foreground Options", open=False):
382
- adv_keyframes = gr.Textbox(label="keyframe", value="20, 40, 60")
383
- adv_atlas_resolution = gr.Slider(label="Atlas Resolution", minimum=1000, maximum=3000, value=2000, step=100)
384
- adv_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
385
- adv_low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
386
- adv_high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
387
- adv_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
388
- adv_s = gr.Slider(label="Noise Scale", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
389
- adv_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=15.0, value=9.0, step=0.1)
390
- adv_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
391
- adv_eta = gr.Number(label="eta (DDIM)", value=0.0)
392
- adv_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, no background')
393
- adv_n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
394
- adv_if_net = gr.gradio.Checkbox(label="if use agg net", value=False)
395
-
396
- with gr.Accordion("Background Options", open=False):
397
- b_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
398
- b_detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=512, step=1)
399
- b_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
400
- b_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
401
- b_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
402
- b_eta = gr.Number(label="eta (DDIM)", value=0.0)
403
- b_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
404
- b_n_prompt = gr.Textbox(label="Negative Prompt",
405
- value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
406
- gr.Markdown("### Step 3. edit each one and render.")
407
- with gr.Row():
408
- f_advance_run_button = gr.Button("Advanced Edit Foreground (slower, better)")
409
- b_run_button = gr.Button("Edit Background")
410
- run_button = gr.Button("Render")
411
- with gr.Column():
412
- output_video = gr.Video(label="Output Video", interactive=False)
413
- # output_foreground_atlas = gr.Image(label="Output Foreground Atlas", type="pil", interactive=False)
414
- output_foreground_atlas = gr.ImageMask(label="Editable Output Foreground Atlas", type="pil", tool="sketch", interactive=True)
415
- output_background_atlas = gr.Image(label="Output Background Atlas", type="pil", interactive=False)
416
-
417
- # edit param
418
- f_adv_edit_param = [adv_keyframes,
419
- adv_atlas_resolution,
420
- f_prompt,
421
- adv_a_prompt,
422
- adv_n_prompt,
423
- adv_image_resolution,
424
- adv_low_threshold,
425
- adv_high_threshold,
426
- adv_ddim_steps,
427
- adv_s,
428
- adv_scale,
429
- adv_seed,
430
- adv_eta,
431
- adv_if_net]
432
- b_edit_param = [b_prompt,
433
- b_a_prompt,
434
- b_n_prompt,
435
- b_image_resolution,
436
- b_detect_resolution,
437
- b_ddim_steps,
438
- b_scale,
439
- b_seed,
440
- b_eta]
441
- # action
442
- load_video_button.click(fn=stablevideo.load_video, inputs=video_name, outputs=[original_video, foreground_atlas, background_atlas])
443
- f_advance_run_button.click(fn=stablevideo.advanced_edit_foreground, inputs=f_adv_edit_param, outputs=[output_foreground_atlas])
444
- b_run_button.click(fn=stablevideo.edit_background, inputs=b_edit_param, outputs=[output_background_atlas])
445
- run_button.click(fn=stablevideo.render, inputs=[output_foreground_atlas, output_background_atlas], outputs=[output_video])
446
 
447
- block.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  ):
49
  self.apply_canny = CannyDetector()
50
  canny_model = create_model(base_cfg).cpu()
51
+ canny_model.load_state_dict(load_state_dict(canny_model_cfg, location='cpu'), strict=False)
52
  self.canny_ddim_sampler = DDIMSampler(canny_model)
53
  self.canny_model = canny_model
54
 
 
59
  ):
60
  self.apply_midas = MidasDetector()
61
  depth_model = create_model(base_cfg).cpu()
62
+ depth_model.load_state_dict(load_state_dict(depth_model_cfg, location='cpu'), strict=False)
63
  self.depth_ddim_sampler = DDIMSampler(depth_model)
64
  self.depth_model = depth_model
65
 
 
101
 
102
  detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
103
 
104
+ control = torch.from_numpy(detected_map.copy()).float() / 255.0
105
  control = torch.stack([control for _ in range(1)], dim=0)
106
  control = einops.rearrange(control, 'b h w c -> b c h w').clone()
107
 
 
128
 
129
  @torch.no_grad()
130
  def edit_background(self, *args, **kwargs):
131
+ self.depth_model = self.depth_model
132
 
133
  input_image = self.b_atlas_origin
134
  self.depth_edit(input_image, *args, **kwargs)
 
155
  if_net=False,
156
  num_samples=1):
157
 
158
+ self.canny_model = self.canny_model
159
 
160
  keyframes = [int(x) for x in keyframes.split(",")]
161
  if self.data is None:
 
186
  # get canny control
187
  detected_map = self.apply_canny(img, low_threshold, high_threshold)
188
  detected_map = HWC3(detected_map)
189
+ control = torch.from_numpy(detected_map.copy()).float() / 255.0
190
  control = einops.rearrange(control.unsqueeze(0), 'b h w c -> b c h w').clone()
191
 
192
  cond = {"c_concat": [control], "c_crossattn": c_crossattn}
 
195
 
196
  # if not the key frame, calculate the mapping from last atlas
197
  if i == 0:
198
+ latent = torch.randn((1, 4, H // 8, W // 8))
199
  samples, _ = self.canny_ddim_sampler.sample(ddim_steps, num_samples,
200
  shape, cond, verbose=False, eta=eta,
201
  unconditional_guidance_scale=scale,
 
209
  mapped_img = mapped_img.resize((W, H))
210
  mapped_img = np.array(mapped_img).astype(np.float32) / 255.0
211
  mapped_img = mapped_img[None].transpose(0, 3, 1, 2)
212
+ mapped_img = torch.from_numpy(mapped_img)
213
  mapped_img = 2. * mapped_img - 1.
214
  latent = self.canny_model.get_first_stage_encoding(self.canny_model.encode_first_stage(mapped_img))
215
 
 
232
  result = alpha * result
233
 
234
  # buffer for training
235
+ result_copy = result.clone()
236
  result_copy.requires_grad = True
237
  result_list.append(result_copy)
238
 
 
259
  # aggregate net #
260
  #####################################
261
  lr, n_epoch = 1e-3, 500
262
+ agg_net = AGGNet()
263
  loss_fn = nn.L1Loss()
264
  optimizer = optim.SGD(agg_net.parameters(), lr=lr, momentum=0.9)
265
  for _ in range(n_epoch):
 
291
  def render(self, f_atlas, b_atlas):
292
  # foreground
293
  if f_atlas == None:
294
+ f_atlas = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0)
295
  else:
296
  f_atlas, mask = f_atlas["image"], f_atlas["mask"]
297
+ f_atlas_origin = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0)
298
+ f_atlas = transforms.ToTensor()(f_atlas).unsqueeze(0)
299
+ mask = transforms.ToTensor()(mask).unsqueeze(0)
300
  if f_atlas.shape != mask.shape:
301
  print("Warning: truncating mask to atlas shape {}".format(f_atlas.shape))
302
  mask = mask[:f_atlas.shape[0], :f_atlas.shape[1], :f_atlas.shape[2], :f_atlas.shape[3]]
 
326
  if b_atlas == None:
327
  b_atlas = self.b_atlas_origin
328
 
329
+ b_atlas = transforms.ToTensor()(b_atlas).unsqueeze(0)
330
  background_edit = F.grid_sample(
331
  b_atlas, self.data.scaled_background_uvs, mode="bilinear", align_corners=self.data.config["align_corners"]
332
  ).clamp(min=0.0, max=1.0)
 
349
  return save_name
350
 
351
  if __name__ == '__main__':
352
+ stablevideo = StableVideo(base_cfg="ckpt/cldm_v15.yaml",
353
+ canny_model_cfg="ckpt/control_sd15_canny.pth",
354
+ depth_model_cfg="ckpt/control_sd15_depth.pth",
355
+ save_memory=True)
356
+ stablevideo.load_canny_model()
357
+ stablevideo.load_depth_model()
358
+
359
+ block = gr.Blocks().queue()
360
+ with block:
361
+ with gr.Row():
362
+ gr.Markdown("## StableVideo")
363
+ with gr.Row():
364
+ with gr.Column():
365
+ original_video = gr.Video(label="Original Video", interactive=False)
366
+ with gr.Row():
367
+ foreground_atlas = gr.Image(label="Foreground Atlas", type="pil")
368
+ background_atlas = gr.Image(label="Background Atlas", type="pil")
369
+ gr.Markdown("### Step 1. select one example video and click **Load Video** buttom and wait for 10 sec.")
370
+ avail_video = [f.name for f in os.scandir("data") if f.is_dir()]
371
+ video_name = gr.Radio(choices=avail_video,
372
+ label="Select Example Videos",
373
+ value="car-turn")
374
+ load_video_button = gr.Button("Load Video")
375
+ gr.Markdown("### Step 2. write text prompt and advanced options for background and foreground.")
376
+ with gr.Row():
377
+ f_prompt = gr.Textbox(label="Foreground Prompt", value="a picture of an orange suv")
378
+ b_prompt = gr.Textbox(label="Background Prompt", value="winter scene, snowy scene, beautiful snow")
379
+ with gr.Row():
380
+ with gr.Accordion("Advanced Foreground Options", open=False):
381
+ adv_keyframes = gr.Textbox(label="keyframe", value="20, 40, 60")
382
+ adv_atlas_resolution = gr.Slider(label="Atlas Resolution", minimum=1000, maximum=3000, value=2000, step=100)
383
+ adv_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
384
+ adv_low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
385
+ adv_high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
386
+ adv_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
387
+ adv_s = gr.Slider(label="Noise Scale", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
388
+ adv_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=15.0, value=9.0, step=0.1)
389
+ adv_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
390
+ adv_eta = gr.Number(label="eta (DDIM)", value=0.0)
391
+ adv_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, no background')
392
+ adv_n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
393
+ adv_if_net = gr.gradio.Checkbox(label="if use agg net", value=False)
394
+
395
+ with gr.Accordion("Background Options", open=False):
396
+ b_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
397
+ b_detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=512, step=1)
398
+ b_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
399
+ b_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
400
+ b_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
401
+ b_eta = gr.Number(label="eta (DDIM)", value=0.0)
402
+ b_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
403
+ b_n_prompt = gr.Textbox(label="Negative Prompt",
404
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
405
+ gr.Markdown("### Step 3. edit each one and render.")
406
+ with gr.Row():
407
+ f_advance_run_button = gr.Button("Advanced Edit Foreground (slower, better)")
408
+ b_run_button = gr.Button("Edit Background")
409
+ run_button = gr.Button("Render")
410
+ with gr.Column():
411
+ output_video = gr.Video(label="Output Video", interactive=False)
412
+ # output_foreground_atlas = gr.Image(label="Output Foreground Atlas", type="pil", interactive=False)
413
+ output_foreground_atlas = gr.ImageMask(label="Editable Output Foreground Atlas", type="pil", tool="sketch", interactive=True)
414
+ output_background_atlas = gr.Image(label="Output Background Atlas", type="pil", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
 
416
+ # edit param
417
+ f_adv_edit_param = [adv_keyframes,
418
+ adv_atlas_resolution,
419
+ f_prompt,
420
+ adv_a_prompt,
421
+ adv_n_prompt,
422
+ adv_image_resolution,
423
+ adv_low_threshold,
424
+ adv_high_threshold,
425
+ adv_ddim_steps,
426
+ adv_s,
427
+ adv_scale,
428
+ adv_seed,
429
+ adv_eta,
430
+ adv_if_net]
431
+ b_edit_param = [b_prompt,
432
+ b_a_prompt,
433
+ b_n_prompt,
434
+ b_image_resolution,
435
+ b_detect_resolution,
436
+ b_ddim_steps,
437
+ b_scale,
438
+ b_seed,
439
+ b_eta]
440
+ # action
441
+ load_video_button.click(fn=stablevideo.load_video, inputs=video_name, outputs=[original_video, foreground_atlas, background_atlas])
442
+ f_advance_run_button.click(fn=stablevideo.advanced_edit_foreground, inputs=f_adv_edit_param, outputs=[output_foreground_atlas])
443
+ b_run_button.click(fn=stablevideo.edit_background, inputs=b_edit_param, outputs=[output_background_atlas])
444
+ run_button.click(fn=stablevideo.render, inputs=[output_foreground_atlas, output_background_atlas], outputs=[output_video])
445
+
446
+ block.launch()
ckpt/cldm_v15.yaml CHANGED
@@ -77,3 +77,5 @@ model:
77
 
78
  cond_stage_config:
79
  target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
 
 
 
77
 
78
  cond_stage_config:
79
  target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
80
+ params:
81
+ device: "cpu"
requirements.txt CHANGED
@@ -120,3 +120,4 @@ wcwidth==0.2.6
120
  websockets==11.0.3
121
  Werkzeug==2.3.7
122
  yarl==1.9.2
 
 
120
  websockets==11.0.3
121
  Werkzeug==2.3.7
122
  yarl==1.9.2
123
+ xformers
stablevideo/atlas_data.py CHANGED
@@ -30,7 +30,7 @@ class AtlasData():
30
  maximum_number_of_frames = json_dict["maximum_number_of_frames"]
31
 
32
  config = {
33
- "device": "cuda",
34
  "checkpoint_path": f"data/{video_name}/checkpoint.ckpt",
35
  "resx": json_dict["resx"],
36
  "resy": json_dict["resy"],
 
30
  maximum_number_of_frames = json_dict["maximum_number_of_frames"]
31
 
32
  config = {
33
+ "device": "cpu",
34
  "checkpoint_path": f"data/{video_name}/checkpoint.ckpt",
35
  "resx": json_dict["resx"],
36
  "resy": json_dict["resy"],
stablevideo/atlas_utils.py CHANGED
@@ -72,7 +72,7 @@ def load_neural_atlases_models(config):
72
  skip_layers=[],
73
  ).to(config["device"])
74
 
75
- checkpoint = torch.load(config["checkpoint_path"])
76
  foreground_mapping.load_state_dict(checkpoint["model_F_mapping1_state_dict"])
77
  background_mapping.load_state_dict(checkpoint["model_F_mapping2_state_dict"])
78
  foreground_atlas_model.load_state_dict(checkpoint["F_atlas_state_dict"])
 
72
  skip_layers=[],
73
  ).to(config["device"])
74
 
75
+ checkpoint = torch.load(config["checkpoint_path"], map_location=torch.device('cpu'))
76
  foreground_mapping.load_state_dict(checkpoint["model_F_mapping1_state_dict"])
77
  background_mapping.load_state_dict(checkpoint["model_F_mapping2_state_dict"])
78
  foreground_atlas_model.load_state_dict(checkpoint["F_atlas_state_dict"])