File size: 3,354 Bytes
4e424ea
ca753f0
4e424ea
7bedcdd
4e424ea
 
 
 
 
 
 
 
ca753f0
f0f4c78
4e424ea
f0f4c78
4e424ea
 
 
 
 
f0f4c78
4e424ea
 
 
f0f4c78
 
 
 
 
 
 
 
4e424ea
f20624c
 
 
 
 
 
ca753f0
 
 
f20624c
6c641ac
 
 
f20624c
 
6c641ac
ca753f0
f20624c
ca753f0
 
f20624c
 
 
 
 
ca753f0
f20624c
 
 
 
 
 
 
f0f4c78
f20624c
 
 
4f2bf09
f0f4c78
4e424ea
d701afa
4e424ea
 
f0f4c78
4e424ea
 
 
 
c81f025
4e424ea
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
import re 
import subprocess
from tqdm import tqdm
from huggingface_hub import snapshot_download

# Fetch the Wan2.1 text-to-video 1.3B checkpoint from the Hugging Face Hub into
# ./Wan2.1-T2V-1.3B. This runs at import time; infer() passes the same folder
# to the generator as --ckpt_dir.
snapshot_download(local_dir="./Wan2.1-T2V-1.3B", repo_id="Wan-AI/Wan2.1-T2V-1.3B")

def infer(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate a video for ``prompt`` by running the ``generate`` module.

    The generator runs as a child process (``<interpreter> -u -m generate``)
    against the checkpoint in ``./Wan2.1-T2V-1.3B``. Its combined
    stdout/stderr is streamed line by line:

    * lines matching a tqdm-style progress line (``NN%|...| cur/total``) drive
      a local "Video Generation Progress" tqdm bar;
    * every other non-empty line is echoed to the server console via
      ``tqdm.write`` and counted by a second "Logs" bar.

    Because ``progress`` is created with ``track_tqdm=True``, Gradio mirrors
    the tqdm bars created here in the UI.

    Args:
        prompt: Text prompt forwarded to the generator via ``--prompt``.
        progress: Gradio progress tracker injected by Gradio. It is not used
            directly; it only needs to exist so tqdm bars are tracked.

    Returns:
        The output path ``"generated_video.mp4"`` passed to ``--save_file``.

    Raises:
        gr.Error: If the child process exits with a non-zero status. It
            subclasses ``Exception``, so ``except Exception`` still catches it.
    """
    import sys  # function-scope import keeps the file's import block unchanged

    output_path = "generated_video.mp4"
    command = [
        # sys.executable runs the child under the same interpreter/virtualenv as
        # this app; a bare "python" could resolve to a different one on PATH.
        # -u makes the child's output unbuffered so progress arrives promptly.
        sys.executable, "-u", "-m", "generate",
        "--task", "t2v-1.3B",
        "--size", "832*480",
        "--ckpt_dir", "./Wan2.1-T2V-1.3B",
        "--sample_shift", "8",
        "--sample_guide_scale", "6",
        "--prompt", prompt,
        "--save_file", output_path,
    ]

    # stderr is merged into stdout (tqdm writes to stderr by default). In text
    # mode the pipe uses universal newlines, so the "\r" redraws of a progress
    # bar come back from readline() as separate lines.
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,  # line-buffered
    )

    progress_pattern = re.compile(r"(\d+)%\|.*\| (\d+)/(\d+)")
    gen_progress_bar = None
    log_progress_bar = None

    try:
        # Pseudo-bar that ticks once per non-progress log line. Its total starts
        # at 0 and grows alongside it.
        log_progress_bar = tqdm(total=0, desc="Logs", position=1, dynamic_ncols=True, leave=True)

        for line in iter(process.stdout.readline, ""):
            stripped_line = line.strip()
            if not stripped_line:
                continue

            match = progress_pattern.search(stripped_line)
            if match is None:
                log_progress_bar.total += 1
                log_progress_bar.update(1)
                # Print the log line above the bars, in arrival order.
                tqdm.write(stripped_line)
                continue

            current = int(match.group(2))
            total = int(match.group(3))
            if gen_progress_bar is None:
                gen_progress_bar = tqdm(
                    total=total,
                    desc="Video Generation Progress",
                    position=0,
                    dynamic_ncols=True,
                    leave=True,
                )
            elif total != gen_progress_bar.total or current < gen_progress_bar.n:
                # A different progress loop started (new total, or the counter
                # went backwards). Restart the bar instead of overshooting its
                # total or applying a negative update.
                gen_progress_bar.reset(total=total)
            # Advance only by the difference from what the bar already shows.
            gen_progress_bar.update(current - gen_progress_bar.n)
            gen_progress_bar.refresh()

        returncode = process.wait()
    finally:
        # Runs on success and on any error while streaming. It makes sure the
        # generator is not left running in the background and that the pipe
        # and bars are closed.
        if process.poll() is None:
            process.kill()
            process.wait()
        process.stdout.close()
        for bar in (gen_progress_bar, log_progress_bar):
            if bar is not None:
                bar.close()

    if returncode != 0:
        print("Error executing command.")
        raise gr.Error(f"Video generation failed (exit code {returncode}).")

    print("Command executed successfully.")
    return output_path

# UI layout: a heading, a prompt box, a submit button and a video player,
# stacked in one column.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Wan 2.1")
        prompt = gr.Textbox(label="Prompt")
        submit_btn = gr.Button("Submit")
        video_res = gr.Video(label="Generated Video")

    # Clicking the button calls infer() with the textbox value. The returned
    # file path is displayed in the video component.
    submit_btn.click(infer, inputs=[prompt], outputs=[video_res])

# queue() configures the Blocks in place (it returns the same object), so
# launching demo afterwards is equivalent to chaining the two calls.
demo.queue()
demo.launch(show_error=True, show_api=False, ssr_mode=False)