# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import argparse import gc import os.path as osp import os import sys import warnings import gradio as gr warnings.filterwarnings('ignore') # Model sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) import wan from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander from wan.utils.utils import cache_video # Global Var prompt_expander = None wan_i2v_480P = None wan_i2v_720P = None # Button Func def load_model(value): global wan_i2v_480P, wan_i2v_720P if value == '------': print("No model loaded") return '------' if value == '720P': if args.ckpt_dir_720p is None: print("Please specify the checkpoint directory for 720P model") return '------' if wan_i2v_720P is not None: pass else: del wan_i2v_480P gc.collect() wan_i2v_480P = None print("load 14B-720P i2v model...", end='', flush=True) cfg = WAN_CONFIGS['i2v-14B'] wan_i2v_720P = wan.WanI2V( config=cfg, checkpoint_dir=args.ckpt_dir_720p, device_id=0, rank=0, t5_fsdp=False, dit_fsdp=False, use_usp=False, ) print("done", flush=True) return '720P' if value == '480P': if args.ckpt_dir_480p is None: print("Please specify the checkpoint directory for 480P model") return '------' if wan_i2v_480P is not None: pass else: del wan_i2v_720P gc.collect() wan_i2v_720P = None print("load 14B-480P i2v model...", end='', flush=True) cfg = WAN_CONFIGS['i2v-14B'] wan_i2v_480P = wan.WanI2V( config=cfg, checkpoint_dir=args.ckpt_dir_480p, device_id=0, rank=0, t5_fsdp=False, dit_fsdp=False, use_usp=False, ) print("done", flush=True) return '480P' def prompt_enc(prompt, img, tar_lang): print('prompt extend...') if img is None: print('Please upload an image') return prompt global prompt_expander prompt_output = prompt_expander( prompt, image=img, tar_lang=tar_lang.lower()) if prompt_output.status == False: return prompt else: return prompt_output.prompt def i2v_generation(img2vid_prompt, img2vid_image, resolution, sd_steps, guide_scale, shift_scale, seed, n_prompt): # print(f"{img2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}") if resolution == '------': print( 'Please specify at least one resolution ckpt dir or specify the resolution' ) return None else: if resolution == '720P': global wan_i2v_720P video = wan_i2v_720P.generate( img2vid_prompt, img2vid_image, max_area=MAX_AREA_CONFIGS['720*1280'], shift=shift_scale, sampling_steps=sd_steps, guide_scale=guide_scale, n_prompt=n_prompt, seed=seed, offload_model=True) else: global wan_i2v_480P video = wan_i2v_480P.generate( img2vid_prompt, img2vid_image, max_area=MAX_AREA_CONFIGS['480*832'], shift=shift_scale, sampling_steps=sd_steps, guide_scale=guide_scale, n_prompt=n_prompt, seed=seed, offload_model=True) cache_video( tensor=video[None], save_file="example.mp4", fps=16, nrow=1, normalize=True, value_range=(-1, 1)) return "example.mp4" # Interface def gradio_interface(): with gr.Blocks() as demo: gr.Markdown("""