jhj0517 commited on
Commit
d0257e3
1 Parent(s): a0f5e02
Files changed (3) hide show
  1. app.py +1 -1
  2. musepose_inference.py +2 -2
  3. pose_align.py +4 -3
app.py CHANGED
@@ -138,7 +138,7 @@ class App:
138
  if __name__ == "__main__":
139
  parser = argparse.ArgumentParser()
140
  parser.add_argument('--model_dir', type=str, default=os.path.join("pretrained_weights"), help='Pretrained models directory for MusePose')
141
- parser.add_argument('--output_dir', type=str, default=os.path.join("assets", "videos"), help='Output directory for the result')
142
  parser.add_argument('--disable_model_download_at_start', type=bool, default=False, nargs='?', const=True, help='Disable model download at start or not')
143
  args = parser.parse_args()
144
 
 
138
  if __name__ == "__main__":
139
  parser = argparse.ArgumentParser()
140
  parser.add_argument('--model_dir', type=str, default=os.path.join("pretrained_weights"), help='Pretrained models directory for MusePose')
141
+ parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Output directory for the result')
142
  parser.add_argument('--disable_model_download_at_start', type=bool, default=False, nargs='?', const=True, help='Disable model download at start or not')
143
  args = parser.parse_args()
144
 
musepose_inference.py CHANGED
@@ -83,8 +83,8 @@ class MusePoseInference:
83
  image_file_name = os.path.splitext(os.path.basename(ref_image_path))[0]
84
  pose_video_file_name = os.path.splitext(os.path.basename(pose_video_path))[0]
85
  output_file_name = f"img_{image_file_name}_pose_{pose_video_file_name}"
86
- output_path = os.path.abspath(os.path.join(self.output_dir, f'{output_file_name}.mp4'))
87
- output_path_demo = os.path.abspath(os.path.join(self.output_dir, f'{output_file_name}_demo.mp4'))
88
 
89
  if weight_dtype == "fp16":
90
  weight_dtype = torch.float16
 
83
  image_file_name = os.path.splitext(os.path.basename(ref_image_path))[0]
84
  pose_video_file_name = os.path.splitext(os.path.basename(pose_video_path))[0]
85
  output_file_name = f"img_{image_file_name}_pose_{pose_video_file_name}"
86
+ output_path = os.path.abspath(os.path.join(self.output_dir, "musepose_inference", f'{output_file_name}.mp4'))
87
+ output_path_demo = os.path.abspath(os.path.join(self.output_dir, "musepose_inference", f'{output_file_name}_demo.mp4'))
88
 
89
  if weight_dtype == "fp16":
90
  weight_dtype = torch.float16
pose_align.py CHANGED
@@ -46,9 +46,10 @@ class PoseAlignmentInference:
46
  max_frame: int,
47
  ):
48
  download_models(model_dir=self.model_dir)
49
- dt_file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
50
- outfn=os.path.abspath(os.path.join(self.output_dir, f'{dt_file_name}_demo.mp4'))
51
- outfn_align_pose_video=os.path.abspath(os.path.join(self.output_dir, f'{dt_file_name}.mp4'))
 
52
 
53
  video = cv2.VideoCapture(vidfn)
54
  width= video.get(cv2.CAP_PROP_FRAME_WIDTH)
 
46
  max_frame: int,
47
  ):
48
  download_models(model_dir=self.model_dir)
49
+ img_file_name = os.path.splitext(os.path.basename(imgfn_refer))[0]
50
+ vid_file_name = os.path.splitext(os.path.basename(vidfn))[0]
51
+ outfn=os.path.abspath(os.path.join(self.output_dir, "pose_alignment", f'img_{img_file_name}_vid_{vid_file_name}_demo.mp4'))
52
+ outfn_align_pose_video=os.path.abspath(os.path.join(self.output_dir, "pose_alignment", f'img_{img_file_name}_vid_{vid_file_name}.mp4'))
53
 
54
  video = cv2.VideoCapture(vidfn)
55
  width= video.get(cv2.CAP_PROP_FRAME_WIDTH)