Wan2.1 / gradio /t2v_14B_singleGPU.py
fffiloni's picture
Migrated from GitHub
3964763 verified
raw
history blame
6.59 kB
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import sys
import warnings
import gradio as gr
# Silence all Python warnings for the lifetime of this process.
warnings.filterwarnings('ignore')
# Model
# Drop the last two components of this file's resolved path (the file name and
# its containing directory) and put the result first on sys.path; presumably
# this is what makes the ``wan`` package importable below -- confirm against
# the repository layout.
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video
# Global Var
# Both placeholders are assigned in the ``__main__`` block at the bottom of
# this file and read by the button callbacks: ``prompt_enc`` calls
# ``prompt_expander`` and ``t2v_generation`` calls ``wan_t2v.generate``.
prompt_expander = None
wan_t2v = None
# Button Func
def prompt_enc(prompt, tar_lang):
    """Rewrite ``prompt`` using the module-level ``prompt_expander``.

    Args:
        prompt: Prompt text to enhance.
        tar_lang: Target-language label (the UI offers "CH" and "EN"); it is
            lower-cased before being handed to the expander.

    Returns:
        The expander's ``prompt`` attribute when its ``status`` is truthy,
        otherwise the original ``prompt`` unchanged.
    """
    # The global is only read here, so no ``global`` declaration is required.
    prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
    # Truthiness test instead of comparing against ``False`` with ``==``.
    if not prompt_output.status:
        return prompt
    return prompt_output.prompt
def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
                   shift_scale, seed, n_prompt):
    """Generate a video with the module-level ``wan_t2v`` model and save it.

    Args:
        txt2vid_prompt: Text prompt passed positionally to ``wan_t2v.generate``.
        resolution: Size string of the form ``"<width>*<height>"``,
            e.g. ``"832*480"``.
        sd_steps: Number of sampling steps; coerced to ``int``.
        guide_scale: Forwarded as ``guide_scale``.
        shift_scale: Forwarded as ``shift``.
        seed: Forwarded as ``seed``; coerced to ``int``.
        n_prompt: Negative prompt, forwarded as ``n_prompt``.

    Returns:
        The path of the written video file, always ``"example.mp4"``.

    Raises:
        ValueError: If ``resolution`` is not exactly two integers separated
            by ``*``.
    """
    # Split once and unpack; a malformed string fails loudly here rather than
    # somewhere inside the model call.
    width, height = (int(part) for part in resolution.split("*"))

    video = wan_t2v.generate(
        txt2vid_prompt,
        size=(width, height),
        shift=shift_scale,
        # Slider values may arrive as floats; a step count and an RNG seed
        # are integral quantities, so normalise them before the model call.
        sampling_steps=int(sd_steps),
        guide_scale=guide_scale,
        n_prompt=n_prompt,
        seed=int(seed),
        offload_model=True)

    # Every call writes to the same relative path, replacing the previous file.
    save_path = "example.mp4"
    cache_video(
        tensor=video[None],  # index with None to add a leading axis
        save_file=save_path,
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1))
    return save_path
# Interface
def gradio_interface():
    """Assemble the Gradio Blocks UI for the T2V-14B demo and return it.

    The page is a header followed by one row with two columns: the first
    holds the prompt box, the prompt-enhance controls, the advanced options
    and the generate button; the second holds the video player. The
    "Prompt Enhance" button is wired to ``prompt_enc`` and the
    "Generate Video" button to ``t2v_generation``.

    NOTE(review): the container nesting below was reconstructed from the
    statement order -- confirm it against the intended layout.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    header_html = (
        "\n"
        '<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">\n'
        "Wan2.1 (T2V-14B)\n"
        "</div>\n"
        '<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">\n'
        "Wan: Open and Advanced Large-Scale Video Generative Models.\n"
        "</div>\n"
    )
    # Each entry is "<width>*<height>", the format t2v_generation splits on.
    resolution_choices = [
        '720*1280', '1280*720', '960*960', '1088*832',
        '832*1088', '480*832', '832*480', '624*624',
        '704*544', '544*704',
    ]

    with gr.Blocks() as demo:
        gr.Markdown(header_html)

        with gr.Row():
            # First column: prompt entry, enhancement and generation controls.
            with gr.Column():
                prompt_box = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate",
                )
                lang_radio = gr.Radio(
                    choices=["CH", "EN"],
                    label="Target language of prompt enhance",
                    value="CH",
                )
                enhance_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    resolution_dropdown = gr.Dropdown(
                        label='Resolution(Width*Height)',
                        choices=resolution_choices,
                        value='720*1280',
                    )
                    with gr.Row():
                        steps_slider = gr.Slider(
                            label="Diffusion steps",
                            minimum=1, maximum=1000, value=50, step=1)
                        guide_slider = gr.Slider(
                            label="Guide scale",
                            minimum=0, maximum=20, value=5.0, step=1)
                    with gr.Row():
                        shift_slider = gr.Slider(
                            label="Shift scale",
                            minimum=0, maximum=10, value=5.0, step=1)
                        seed_slider = gr.Slider(
                            label="Seed",
                            minimum=-1, maximum=2147483647, step=1, value=-1)
                    negative_box = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add",
                    )

                generate_button = gr.Button("Generate Video")

            # Second column: playback of the generated clip.
            with gr.Column():
                video_player = gr.Video(
                    label='Generated Video', interactive=False, height=600)

        # The enhanced prompt is written back into the same textbox.
        enhance_button.click(
            fn=prompt_enc,
            inputs=[prompt_box, lang_radio],
            outputs=[prompt_box],
        )
        # Input order must match the parameter order of t2v_generation.
        generation_inputs = [
            prompt_box, resolution_dropdown, steps_slider, guide_slider,
            shift_slider, seed_slider, negative_box,
        ]
        generate_button.click(
            fn=t2v_generation,
            inputs=generation_inputs,
            outputs=[video_player],
        )

    return demo
# Main
def _parse_args(argv=None):
    """Parse the command-line options of the demo script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour existing callers rely on.

    Returns:
        An ``argparse.Namespace`` with ``ckpt_dir`` (default ``"cache"``),
        ``prompt_extend_method`` (``"dashscope"`` or ``"local_qwen"``,
        default ``"local_qwen"``) and ``prompt_extend_model`` (default
        ``None``).
    """
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default="cache",
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: build the prompt expander and the 14B T2V model,
    # store both in the module globals the button callbacks read, then serve
    # the Gradio UI.
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    # Both expander classes receive the same model name and is_vl=False;
    # only the local Qwen expander is additionally given a device index.
    shared_expander_kwargs = dict(
        model_name=args.prompt_extend_model, is_vl=False)
    if args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            device=0, **shared_expander_kwargs)
    elif args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(**shared_expander_kwargs)
    else:
        raise NotImplementedError(
            f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    print("Step2: Init 14B t2v model...", end='', flush=True)
    cfg = WAN_CONFIGS['t2v-14B']
    # FSDP and USP are all switched off; device_id and rank are both 0.
    wan_t2v = wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
    )
    print("done", flush=True)

    demo = gradio_interface()
    # Listen on all interfaces, port 7860, without a public share link.
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)