jhj0517 commited on
Commit
bd8868d
1 Parent(s): a0d251a

fix download_model

Browse files
Files changed (3) hide show
  1. app.py +0 -22
  2. musepose_inference.py +3 -2
  3. pose_align.py +3 -4
app.py CHANGED
@@ -11,26 +11,6 @@ class App:
11
  self.pose_alignment_infer = PoseAlignmentInference()
12
  self.musepose_infer = MusePoseInference()
13
 
14
- @staticmethod
15
- def download_models():
16
- repo_id = 'jhj0517/MusePose'
17
- model_paths = {
18
- "det_ckpt": os.path.join("pretrained_weights", "dwpose", "yolox_l_8x8_300e_coco.pth"),
19
- "pose_ckpt": os.path.join("pretrained_weights", "dwpose", "dw-ll_ucoco_384.pth")
20
- }
21
- for name, file_path in model_paths.items():
22
-
23
- local_dir, filename = os.path.dirname(file_path), os.path.basename(file_path)
24
- if not os.path.exists(local_dir):
25
- os.makedirs(local_dir)
26
-
27
- remote_filepath = f"dwpose/{filename}"
28
- if not os.path.exists(file_path):
29
- print(file_path)
30
- hf_hub_download(repo_id=repo_id, filename=remote_filepath,
31
- local_dir=local_dir,
32
- local_dir_use_symlinks=False)
33
-
34
  def musepose_demo(self):
35
  with gr.Blocks() as demo:
36
  with gr.Tabs():
@@ -51,13 +31,11 @@ class App:
51
 
52
  with gr.Row():
53
  btn_algin_pose = gr.Button("ALIGN POSE", variant="primary")
54
- btn_down = gr.Button("download", variant="primary")
55
 
56
  btn_algin_pose.click(fn=self.pose_alignment_infer.align_pose,
57
  inputs=[vid_dance_input, img_input, nb_detect_resolution, nb_image_resolution,
58
  nb_align_frame, nb_max_frame],
59
  outputs=[vid_dance_output, vid_dance_output_demo])
60
- btn_down.click(fn=self.download_models, inputs=None, outputs=None)
61
 
62
  with gr.TabItem('Step2: MusePose Inference'):
63
  with gr.Row():
 
11
  self.pose_alignment_infer = PoseAlignmentInference()
12
  self.musepose_infer = MusePoseInference()
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def musepose_demo(self):
15
  with gr.Blocks() as demo:
16
  with gr.Tabs():
 
31
 
32
  with gr.Row():
33
  btn_algin_pose = gr.Button("ALIGN POSE", variant="primary")
 
34
 
35
  btn_algin_pose.click(fn=self.pose_alignment_infer.align_pose,
36
  inputs=[vid_dance_input, img_input, nb_detect_resolution, nb_image_resolution,
37
  nb_align_frame, nb_max_frame],
38
  outputs=[vid_dance_output, vid_dance_output_demo])
 
39
 
40
  with gr.TabItem('Step2: MusePose Inference'):
41
  with gr.Row():
musepose_inference.py CHANGED
@@ -217,15 +217,16 @@ class MusePoseInference:
217
 
218
  def download_models(self):
219
  repo_id = 'jhj0517/MusePose'
 
220
  for name, file_path in self.musepose_model_paths.items():
221
  local_dir, filename = os.path.dirname(file_path), os.path.basename(file_path)
222
  if not os.path.exists(local_dir):
223
  os.makedirs(local_dir)
224
 
225
- remote_filepath = os.path.join("MusePose", filename)
226
  if not os.path.exists(file_path):
227
  hf_hub_download(repo_id=repo_id, filename=remote_filepath,
228
- local_dir=local_dir,
229
  local_dir_use_symlinks=False)
230
 
231
  def release_vram(self):
 
217
 
218
  def download_models(self):
219
  repo_id = 'jhj0517/MusePose'
220
+ local_model_dir = os.path.join("pretrained_weights")
221
  for name, file_path in self.musepose_model_paths.items():
222
  local_dir, filename = os.path.dirname(file_path), os.path.basename(file_path)
223
  if not os.path.exists(local_dir):
224
  os.makedirs(local_dir)
225
 
226
+ remote_filepath = f"MusePose/{filename}"
227
  if not os.path.exists(file_path):
228
  hf_hub_download(repo_id=repo_id, filename=remote_filepath,
229
+ local_dir=local_model_dir,
230
  local_dir_use_symlinks=False)
231
 
232
  def release_vram(self):
pose_align.py CHANGED
@@ -312,17 +312,16 @@ class PoseAlignmentInference:
312
 
313
  def download_models(self):
314
  repo_id = 'jhj0517/MusePose'
 
315
  for name, file_path in self.model_paths.items():
316
-
317
  local_dir, filename = os.path.dirname(file_path), os.path.basename(file_path)
318
  if not os.path.exists(local_dir):
319
  os.makedirs(local_dir)
320
 
321
- remote_filepath = os.path.join("dwpose", filename)
322
  if not os.path.exists(file_path):
323
- print(file_path)
324
  hf_hub_download(repo_id=repo_id, filename=remote_filepath,
325
- local_dir=local_dir,
326
  local_dir_use_symlinks=False)
327
 
328
  def release_vram(self):
 
312
 
313
  def download_models(self):
314
  repo_id = 'jhj0517/MusePose'
315
+ local_model_dir = os.path.join("pretrained_weights")
316
  for name, file_path in self.model_paths.items():
 
317
  local_dir, filename = os.path.dirname(file_path), os.path.basename(file_path)
318
  if not os.path.exists(local_dir):
319
  os.makedirs(local_dir)
320
 
321
+ remote_filepath = f"dwpose/{filename}"
322
  if not os.path.exists(file_path):
 
323
  hf_hub_download(repo_id=repo_id, filename=remote_filepath,
324
+ local_dir=local_model_dir,
325
  local_dir_use_symlinks=False)
326
 
327
  def release_vram(self):