OpenSound commited on
Commit
c90e89d
1 Parent(s): eb39bd3
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -31,7 +31,7 @@ def load_models(config_name, ckpt_path, vae_path, device):
31
 
32
  # Load main U-Net model
33
  unet = MaskDiT(**params['model']).to(device)
34
- unet.load_state_dict(torch.load(ckpt_path)['model'])
35
  unet.eval()
36
 
37
  accelerator = Accelerator(mixed_precision="fp16")
 
31
 
32
  # Load main U-Net model
33
  unet = MaskDiT(**params['model']).to(device)
34
+ unet.load_state_dict(torch.load(ckpt_path, map_location='cpu')['model'])
35
  unet.eval()
36
 
37
  accelerator = Accelerator(mixed_precision="fp16")