File size: 8,812 Bytes
0a9bdfb 3c0f460 a874577 0a9bdfb 3c0f460 0a9bdfb a874577 0a9bdfb 3c0f460 902b0d6 3c0f460 699c0d5 0a9bdfb 7405324 0a9bdfb 7405324 0a9bdfb 7405324 0a9bdfb 86760f1 0a9bdfb 7405324 0a9bdfb 7405324 0a9bdfb 7405324 0a9bdfb 902b0d6 1190e23 3c0f460 d0257e3 699c0d5 902b0d6 3c0f460 1190e23 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import gradio as gr
import argparse
import os
from musepose_inference import MusePoseInference
from pose_align import PoseAlignmentInference
from downloading_weights import download_models
class App:
    """Gradio web UI for the two-step MusePose pipeline.

    Step 1 aligns the pose sequence of a dance video to a reference image via
    ``PoseAlignmentInference.align_pose``; step 2 generates a video from a
    reference image and an aligned pose video via
    ``MusePoseInference.infer_musepose``.
    """

    def __init__(self, args):
        """Create the inference back-ends and optionally download models.

        Args:
            args: Parsed command-line namespace. ``model_dir`` and
                ``output_dir`` are forwarded to both inference objects;
                ``download_models`` is called for ``model_dir`` unless
                ``disable_model_download_at_start`` is truthy. ``share`` is
                read later by ``launch``.
        """
        self.args = args
        self.pose_alignment_infer = PoseAlignmentInference(
            model_dir=args.model_dir,
            output_dir=args.output_dir,
        )
        self.musepose_infer = MusePoseInference(
            model_dir=args.model_dir,
            output_dir=args.output_dir,
        )
        if not args.disable_model_download_at_start:
            download_models(model_dir=args.model_dir)

    def musepose_demo(self):
        """Build the Gradio ``Blocks`` UI without launching it.

        Each tab wires one pipeline step. Its example row and its action button
        share a single ``inputs`` list, so the positional arguments passed to
        the inference callback cannot drift between the two.

        Returns:
            The ``gr.Blocks`` instance containing both tabs.
        """
        with gr.Blocks() as demo:
            # Called for its side effect: the header component is rendered
            # into ``demo`` at this position.
            self.header()
            with gr.Tabs():
                with gr.TabItem('Step1: Pose Alignment'):
                    with gr.Row():
                        with gr.Column(scale=3):
                            img_ref_step1 = gr.Image(label="Input Image here", type="filepath", scale=5)
                            vid_dance_input = gr.Video(label="Input Dance Video", scale=5)
                        with gr.Column(scale=3):
                            vid_dance_output = gr.Video(label="Aligned pose output will be displayed here", scale=5)
                            vid_dance_output_demo = gr.Video(label="Output demo video will be displayed here", scale=5)
                        with gr.Column(scale=3):
                            with gr.Column():
                                nb_detect_resolution = gr.Number(label="Detect Resolution", value=512, precision=0)
                                nb_image_resolution = gr.Number(label="Image Resolution.", value=720, precision=0)
                                nb_align_frame = gr.Number(label="Align Frame", value=0, precision=0)
                                nb_max_frame = gr.Number(label="Max Frame", value=300, precision=0)
                            with gr.Row():
                                btn_align_pose = gr.Button("ALIGN POSE", variant="primary")

                    # Shared by the examples and the button; order must match
                    # align_pose's positional parameters.
                    step1_inputs = [vid_dance_input, img_ref_step1, nb_detect_resolution,
                                    nb_image_resolution, nb_align_frame, nb_max_frame]
                    step1_outputs = [vid_dance_output, vid_dance_output_demo]

                    with gr.Column():
                        step1_examples = [
                            [os.path.join("assets", "videos", "dance.mp4"), os.path.join("assets", "images", "ref.png"),
                             512, 720, 0, 300]]
                        gr.Examples(examples=step1_examples,
                                    inputs=step1_inputs,
                                    outputs=step1_outputs,
                                    fn=self.pose_alignment_infer.align_pose,
                                    cache_examples="lazy")
                    btn_align_pose.click(fn=self.pose_alignment_infer.align_pose,
                                         inputs=step1_inputs,
                                         outputs=step1_outputs)

                with gr.TabItem('Step2: MusePose Inference'):
                    with gr.Row():
                        with gr.Column(scale=3):
                            img_ref_step2 = gr.Image(label="Input Image here", type="filepath", scale=5)
                            vid_pose_input = gr.Video(label="Input Aligned Pose Video here", scale=5)
                        with gr.Column(scale=3):
                            vid_output = gr.Video(label="Output Video will be displayed here", scale=5)
                            vid_output_demo = gr.Video(label="Output demo video will be displayed here", scale=5)
                        with gr.Column(scale=3):
                            with gr.Column():
                                weight_dtype = gr.Dropdown(label="Compute Type", choices=["fp16", "fp32"],
                                                           value="fp16")
                                nb_width = gr.Number(label="Width.", value=512, precision=0)
                                nb_height = gr.Number(label="Height.", value=512, precision=0)
                                nb_video_frame_length = gr.Number(label="Video Frame Length", value=300, precision=0)
                                nb_video_slice_frame_length = gr.Number(label="Video Slice Frame Number ", value=48,
                                                                        precision=0)
                                nb_video_slice_overlap_frame_number = gr.Number(
                                    label="Video Slice Overlap Frame Number", value=4, precision=0)
                                # CFG is fractional (default 3.5). With precision=0 Gradio
                                # rounds it to an int (3.5 -> 4), so leave it unrounded.
                                nb_cfg = gr.Number(label="CFG (Classifier Free Guidance)", value=3.5)
                                nb_seed = gr.Number(label="Seed", value=99, precision=0)
                                nb_steps = gr.Number(label="DDIM Sampling Steps", value=20, precision=0)
                                nb_fps = gr.Number(label="FPS (Frames Per Second) ", value=-1, precision=0,
                                                   info="Set to '-1' to use same FPS with pose's")
                                nb_skip = gr.Number(label="SKIP (Frame Sample Rate = SKIP+1)", value=1, precision=0)
                            with gr.Row():
                                btn_generate = gr.Button("GENERATE", variant="primary")

                    # Shared by the examples and the button; order must match
                    # infer_musepose's positional parameters.
                    step2_inputs = [img_ref_step2, vid_pose_input, weight_dtype, nb_width, nb_height,
                                    nb_video_frame_length, nb_video_slice_frame_length,
                                    nb_video_slice_overlap_frame_number, nb_cfg, nb_seed, nb_steps,
                                    nb_fps, nb_skip]
                    step2_outputs = [vid_output, vid_output_demo]

                    with gr.Column():
                        step2_examples = [
                            [os.path.join("assets", "images", "ref.png"), os.path.join("assets", "videos", "pose.mp4"),
                             "fp16", 512, 512, 300, 48, 4, 3.5, 99, 20, -1, 1]]
                        gr.Examples(examples=step2_examples,
                                    inputs=step2_inputs,
                                    outputs=step2_outputs,
                                    fn=self.musepose_infer.infer_musepose,
                                    cache_examples="lazy")
                    btn_generate.click(fn=self.musepose_infer.infer_musepose,
                                       inputs=step2_inputs,
                                       outputs=step2_outputs)
        return demo

    @staticmethod
    def header():
        """Create and return the HTML banner shown at the top of the UI."""
        header = gr.HTML(
            """
<style>
p, li {
font-size: 16px;
}
</style>
<h2>Gradio demo for <a href="https://github.com/TMElyralab/MusePose">MusePose</a></h2>
<p>Demo list you can try in other environment:</p>
<ul>
<li><a href="https://github.com/jhj0517/MusePose-WebUI"><strong>MusePose WebUI</strong></a> (This repository, you can try in local)</li>
<li><a href="https://github.com/jhj0517/stable-diffusion-webui-MusePose.git"><strong>stable-diffusion-webui-MusePose</strong></a> (SD WebUI extension)</li>
<li><a href="https://github.com/TMElyralab/Comfyui-MusePose"><strong>Comfyui-MusePose</strong></a> (ComfyUI custom node)</li>
</ul>
"""
        )
        return header

    def launch(self):
        """Build the UI and serve it with Gradio's request queue enabled.

        ``self.args.share`` is passed through as Gradio's ``share`` flag.
        """
        demo = self.musepose_demo()
        demo.queue().launch(
            share=self.args.share
        )
def str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``yes``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string as ``True``,
    so ``--share False`` would still enable sharing. This converter interprets
    the text instead.

    Args:
        value: The raw argument string. A ``bool`` is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean
            spelling; argparse reports it as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def build_parser():
    """Return the command-line parser for the MusePose web UI.

    The boolean flags accept either the bare form (``--share`` -> True, via
    ``const``) or an explicit value (``--share false``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default="pretrained_weights", help='Pretrained models directory for MusePose')
    parser.add_argument('--output_dir', type=str, default="outputs", help='Output directory for the result')
    parser.add_argument('--disable_model_download_at_start', type=str2bool, default=False, nargs='?', const=True, help='Disable model download at start or not')
    parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio makes sharable link if it is true')
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    app = App(args=args)
    app.launch()