xierui.0097 commited on
Commit
48c7a9e
·
1 Parent(s): e654179
video_to_video/video_to_video_model.py CHANGED
@@ -47,13 +47,19 @@ class VideoToVideo_sr():
47
  generator = generator.to(self.device)
48
  generator.eval()
49
 
 
50
  cfg.model_path = opt.model_path
51
  # download weight
52
  model_url = 'https://huggingface.co/SherryX/STAR/resolve/main/I2VGen-XL-based/heavy_deg.pt'
53
  download_model(model_url, cfg.model_path)
54
 
55
- # point to the weight
56
- load_dict = torch.load(cfg.model_path + '/heavy_deg.pt' , map_location='cpu')
 
 
 
 
 
57
  if 'state_dict' in load_dict:
58
  load_dict = load_dict['state_dict']
59
  ret = generator.load_state_dict(load_dict, strict=False)
 
47
  generator = generator.to(self.device)
48
  generator.eval()
49
 
50
+ # 确保 cfg.model_path 是文件夹路径,不要加上文件名
51
  cfg.model_path = opt.model_path
52
  # download weight
53
  model_url = 'https://huggingface.co/SherryX/STAR/resolve/main/I2VGen-XL-based/heavy_deg.pt'
54
  download_model(model_url, cfg.model_path)
55
 
56
+ # 拼接完整路径
57
+ model_file_path = os.path.join(cfg.model_path, 'heavy_deg.pt')
58
+ print('model_file_path:', model_file_path)
59
+
60
+ # 加载模型
61
+ load_dict = torch.load(model_file_path, map_location='cpu')
62
+
63
  if 'state_dict' in load_dict:
64
  load_dict = load_dict['state_dict']
65
  ret = generator.load_state_dict(load_dict, strict=False)