|
|
|
import argparse |
|
import gc |
|
import os.path as osp |
|
import os |
|
import sys |
|
import warnings |
|
|
|
import gradio as gr |
|
|
|
# Suppress every Python warning so the demo console output stays readable.
warnings.filterwarnings('ignore')

# Put the directory two levels above this file (the parent of the folder that
# holds this script, presumably the repository root) at the front of sys.path
# so the `import wan` that follows resolves against this checkout.
_here_parts = osp.realpath(__file__).split(os.path.sep)
sys.path.insert(0, os.path.sep.join(_here_parts[:-2]))
del _here_parts
|
import wan |
|
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS |
|
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander |
|
from wan.utils.utils import cache_video |
|
|
|
|
|
# Mutable module-level state shared by the Gradio callbacks. Everything starts
# empty: the prompt expander is assigned in the `__main__` section, and the two
# i2v pipelines are assigned by `load_model` when a resolution is selected.
prompt_expander = None
wan_i2v_480P = wan_i2v_720P = None
|
|
|
|
|
|
|
def _build_i2v_model(checkpoint_dir):
    """Construct a WanI2V pipeline from ``checkpoint_dir`` on device 0.

    Both resolutions use the same ``'i2v-14B'`` config and differ only in the
    checkpoint directory; FSDP/USP flags are off and rank is 0.
    """
    return wan.WanI2V(
        config=WAN_CONFIGS['i2v-14B'],
        checkpoint_dir=checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
    )


def load_model(value):
    """Make sure the i2v model for the selected resolution is loaded.

    Wired to the resolution dropdown; the return value is written back into
    that dropdown, so it must always be one of its choices.

    Reads the module-level ``args`` (set in the ``__main__`` section) for the
    checkpoint directories, and rebinds the ``wan_i2v_480P`` /
    ``wan_i2v_720P`` globals.

    Args:
        value: The dropdown selection: ``'------'``, ``'720P'`` or ``'480P'``.

    Returns:
        The resolution whose model is now loaded (also when it was already
        loaded), or ``'------'`` when nothing is selected, the matching
        checkpoint directory was not given, or ``value`` is unrecognised.
    """
    global wan_i2v_480P, wan_i2v_720P

    if value == '------':
        print("No model loaded")
        return '------'

    if value == '720P':
        if args.ckpt_dir_720p is None:
            print("Please specify the checkpoint directory for 720P model")
            return '------'
        if wan_i2v_720P is None:
            # Release our reference to the 480P model before building the 720P
            # one so the two are not kept alive at the same time.
            wan_i2v_480P = None
            gc.collect()
            print("load 14B-720P i2v model...", end='', flush=True)
            wan_i2v_720P = _build_i2v_model(args.ckpt_dir_720p)
            print("done", flush=True)
        # Return the label even when the model was already loaded, so the
        # dropdown keeps its value instead of being cleared.
        return '720P'

    if value == '480P':
        if args.ckpt_dir_480p is None:
            print("Please specify the checkpoint directory for 480P model")
            return '------'
        if wan_i2v_480P is None:
            # Release our reference to the 720P model before building the 480P
            # one so the two are not kept alive at the same time.
            wan_i2v_720P = None
            gc.collect()
            print("load 14B-480P i2v model...", end='', flush=True)
            wan_i2v_480P = _build_i2v_model(args.ckpt_dir_480p)
            print("done", flush=True)
        return '480P'

    # Unrecognised selection: report "nothing loaded" rather than None.
    return '------'
|
|
|
|
|
def prompt_enc(prompt, img, tar_lang):
    """Rewrite ``prompt`` with the global prompt expander, conditioned on ``img``.

    Args:
        prompt: The user's prompt text.
        img: The uploaded image, or ``None`` when nothing has been uploaded.
        tar_lang: Target-language label from the UI radio (e.g. ``"CH"`` or
            ``"EN"``); it is lower-cased before being passed to the expander.

    Returns:
        The expanded prompt on success; otherwise ``prompt`` unchanged (no
        image, expander not initialised, or the expander reported failure).
    """
    print('prompt extend...')
    if img is None:
        print('Please upload an image')
        return prompt
    if prompt_expander is None:
        # The expander is only created when this file runs as a script; guard
        # against calling None if the UI is built some other way.
        print('Prompt expander is not initialized')
        return prompt

    prompt_output = prompt_expander(
        prompt, image=img, tar_lang=tar_lang.lower())
    if not prompt_output.status:
        return prompt
    return prompt_output.prompt
|
|
|
|
|
def i2v_generation(img2vid_prompt, img2vid_image, resolution, sd_steps,
                   guide_scale, shift_scale, seed, n_prompt):
    """Generate a video from an image and prompt with the loaded i2v model.

    Args:
        img2vid_prompt: Positive text prompt.
        img2vid_image: Conditioning image, or ``None`` if nothing uploaded.
        resolution: ``'720P'`` selects the 720P model; ``'------'`` aborts;
            any other value falls back to the 480P model.
        sd_steps: Forwarded to ``generate`` as ``sampling_steps``.
        guide_scale: Forwarded to ``generate`` as ``guide_scale``.
        shift_scale: Forwarded to ``generate`` as ``shift``.
        seed: Forwarded to ``generate`` as ``seed``.
        n_prompt: Negative prompt, forwarded as ``n_prompt``.

    Returns:
        ``"example.mp4"`` (written to the working directory) on success, or
        ``None`` when generation could not run.
    """
    if resolution == '------':
        print(
            'Please specify at least one resolution ckpt dir or specify the resolution'
        )
        return None
    if img2vid_image is None:
        print('Please upload an image')
        return None

    # Model globals are read at call time because load_model rebinds them.
    if resolution == '720P':
        label, model, area_key = '720P', wan_i2v_720P, '720*1280'
    else:
        label, model, area_key = '480P', wan_i2v_480P, '480*832'
    if model is None:
        print(f'The {label} model is not loaded, please select the resolution first')
        return None

    video = model.generate(
        img2vid_prompt,
        img2vid_image,
        max_area=MAX_AREA_CONFIGS[area_key],
        shift=shift_scale,
        sampling_steps=sd_steps,
        guide_scale=guide_scale,
        n_prompt=n_prompt,
        seed=seed,
        offload_model=True)

    # video[None] adds a leading axis (presumably the batch dim cache_video
    # expects). NOTE(review): the output path is fixed, so concurrent requests
    # overwrite each other's file.
    cache_video(
        tensor=video[None],
        save_file="example.mp4",
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1))
    return "example.mp4"
|
|
|
|
|
|
|
_HEADER_HTML = """
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (I2V-14B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
"""


def gradio_interface():
    """Build and return the Gradio Blocks app for the I2V-14B demo.

    Left column: resolution selector, image upload, prompt box, prompt-enhance
    controls, an "Advanced Options" accordion with sampling settings, and the
    generate button. Right column: the generated video.

    Events wired inside the Blocks context:
      * editing the resolution dropdown calls ``load_model`` and writes its
        result back into the dropdown;
      * "Prompt Enhance" calls ``prompt_enc`` and replaces the prompt text;
      * "Generate Video" calls ``i2v_generation`` and shows the result.
    """

    def _stepped_slider(label, lo, hi, default):
        # Every slider in this UI moves in whole-number steps.
        return gr.Slider(
            label=label, minimum=lo, maximum=hi, value=default, step=1)

    with gr.Blocks() as app:
        gr.Markdown(_HEADER_HTML)

        with gr.Row():
            with gr.Column():
                resolution = gr.Dropdown(
                    choices=['------', '720P', '480P'],
                    value='------',
                    label='Resolution')
                img2vid_image = gr.Image(
                    label="Upload Input Image",
                    type="pil",
                    elem_id="image_upload")
                img2vid_prompt = gr.Textbox(
                    placeholder="Describe the video you want to generate",
                    label="Prompt")
                tar_lang = gr.Radio(
                    label="Target language of prompt enhance",
                    value="CH",
                    choices=["CH", "EN"])
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    with gr.Row():
                        sd_steps = _stepped_slider("Diffusion steps", 1, 1000, 50)
                        guide_scale = _stepped_slider("Guide scale", 0, 20, 5.0)
                    with gr.Row():
                        shift_scale = _stepped_slider("Shift scale", 0, 10, 5.0)
                        seed = _stepped_slider("Seed", -1, 2147483647, -1)
                    n_prompt = gr.Textbox(
                        placeholder="Describe the negative prompt you want to add",
                        label="Negative Prompt")

                run_i2v_button = gr.Button("Generate Video")

            with gr.Column():
                result_gallery = gr.Video(
                    height=600, interactive=False, label='Generated Video')

        resolution.input(
            fn=load_model, inputs=[resolution], outputs=[resolution])

        enhance_inputs = [img2vid_prompt, img2vid_image, tar_lang]
        run_p_button.click(
            fn=prompt_enc, inputs=enhance_inputs, outputs=[img2vid_prompt])

        generate_inputs = [
            img2vid_prompt, img2vid_image, resolution, sd_steps,
            guide_scale, shift_scale, seed, n_prompt,
        ]
        run_i2v_button.click(
            fn=i2v_generation, inputs=generate_inputs, outputs=[result_gallery])

    return app
|
|
|
|
|
|
|
def _parse_args(): |
|
parser = argparse.ArgumentParser( |
|
description="Generate a video from a text prompt or image using Gradio") |
|
parser.add_argument( |
|
"--ckpt_dir_720p", |
|
type=str, |
|
default=None, |
|
help="The path to the checkpoint directory.") |
|
parser.add_argument( |
|
"--ckpt_dir_480p", |
|
type=str, |
|
default=None, |
|
help="The path to the checkpoint directory.") |
|
parser.add_argument( |
|
"--prompt_extend_method", |
|
type=str, |
|
default="local_qwen", |
|
choices=["dashscope", "local_qwen"], |
|
help="The prompt extend method to use.") |
|
parser.add_argument( |
|
"--prompt_extend_model", |
|
type=str, |
|
default=None, |
|
help="The prompt extend model to use.") |
|
|
|
args = parser.parse_args() |
|
assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory." |
|
|
|
return args |
|
|
|
|
|
if __name__ == '__main__':
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    # One factory per --prompt_extend_method choice; both build the expander
    # with is_vl=True because prompts are enhanced together with an image.
    _expander_factories = {
        "dashscope": lambda: DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=True),
        "local_qwen": lambda: QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=True, device=0),
    }
    _make_expander = _expander_factories.get(args.prompt_extend_method)
    if _make_expander is None:
        raise NotImplementedError(
            f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
    prompt_expander = _make_expander()
    print("done", flush=True)

    # Listen on all interfaces, port 7860, without a share link.
    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
|
|