File size: 3,643 Bytes
fbd466e
 
 
 
 
 
 
 
 
8024618
 
 
fbd466e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8024618
 
 
 
 
 
fbd466e
327e1e8
8024618
fbd466e
 
 
 
 
 
 
8024618
 
 
 
 
fbd466e
8024618
fbd466e
8024618
 
fbd466e
 
 
 
8024618
 
12e4a4d
 
fbd466e
8024618
fbd466e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import hmac
import os

import gradio as gr
import numpy as np
from PIL import Image
import cv2
from moviepy.editor import VideoFileClip
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import export_to_video

# Shared secret that API callers must supply to `infer`; falls back to a
# placeholder value when the SECRET_TOKEN environment variable is unset.
SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
DESCRIPTION = 'This space is an API service meant to be used by VideoChain and VideoQuest.\nWant to use this space for yourself? Please use the original code: [https://huggingface.co/spaces/fffiloni/zeroscope-XL](https://huggingface.co/spaces/fffiloni/zeroscope-XL)'

# zeroscope XL pipeline in half precision, pinned to a specific repo revision.
pipe_xl = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_XL", torch_dtype=torch.float16, revision="refs/pr/17")
# Decode with the VAE in slices to lower peak memory.
pipe_xl.vae.enable_slicing()
pipe_xl.scheduler = DPMSolverMultistepScheduler.from_config(pipe_xl.scheduler.config)
# Model CPU offload moves each sub-model to the GPU only while it runs.
# Do NOT also call `pipe_xl.to("cuda")`: that moves every sub-model onto the
# GPU at once and throws away the memory savings offloading provides.
pipe_xl.enable_model_cpu_offload()

def convert_mp4_to_frames(video_path, duration=3):
    """Decode frames from the start of a video, converted to RGB.

    The number of frames read is ``int(fps * duration)``, where ``fps`` is the
    rate OpenCV reports for the file. Reading also stops early if the video
    runs out of frames.

    Args:
        video_path: Path of the video file to open with ``cv2.VideoCapture``.
        duration: Number of seconds (from the start) to extract.

    Returns:
        A numpy array built from the list of decoded RGB frames. It is empty
        when no frames could be read, e.g. the file could not be opened, the
        reported fps is 0, or ``duration`` is not positive.
    """
    video = cv2.VideoCapture(video_path)
    try:
        fps = video.get(cv2.CAP_PROP_FPS)
        num_frames = int(fps * duration)

        frames = []
        # A bounded `<` check (rather than `== num_frames`) guarantees
        # termination even when num_frames is zero or negative.
        while len(frames) < num_frames:
            ok, frame = video.read()
            if not ok:
                break
            # OpenCV decodes to BGR; downstream PIL code expects RGB.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always free the capture handle, even if decoding raises.
        video.release()

    return np.array(frames)

def infer(prompt, video_in, denoise_strength, duration, secret_token: str = '') -> str:
    """Re-render the start of an uploaded video with the zeroscope XL pipeline.

    Args:
        prompt: Text prompt passed to ``pipe_xl``.
        video_in: Path of the uploaded source video, opened with OpenCV.
        denoise_strength: Value forwarded to the pipeline as ``strength``.
        duration: Requested seconds of video to process; capped at 3.
        secret_token: Must match ``SECRET_TOKEN`` for the request to run.

    Returns:
        Path of the generated MP4 file ("xl_result.mp4").

    Raises:
        gr.Error: If the token is wrong, no video was supplied, or no frames
            could be decoded from the video.
    """
    # Constant-time comparison so the token cannot be probed via response timing.
    if not isinstance(secret_token, str) or not hmac.compare_digest(
            secret_token.encode(), SECRET_TOKEN.encode()):
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')
    if video_in is None:
        raise gr.Error('Please upload a video.')

    negative_prompt = "text, watermark, copyright, blurry, cropped, noisy, pixelated, nsfw"

    # we cannot go beyond 3 seconds on the large A10G
    video = convert_mp4_to_frames(video_in, min(duration, 3))
    if len(video) == 0:
        # Fail with a readable message instead of an opaque error inside the pipeline.
        raise gr.Error('Could not read any frames from the uploaded video.')

    video_resized = [Image.fromarray(frame).resize((1024, 576)) for frame in video]
    video_frames = pipe_xl(prompt, negative_prompt=negative_prompt, video=video_resized, strength=denoise_strength).frames
    # NOTE(review): the fixed output filename is shared by all requests; this
    # presumably relies on the queue running one job at a time — confirm.
    video_path = export_to_video(video_frames, output_video_path="xl_result.mp4")

    # Return one value to match the single `video_result` output component.
    return video_path


with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Column():
        secret_token = gr.Text(label='Secret Token', max_lines=1)
        video_in = gr.Video(type="numpy", source="upload")
        prompt_in = gr.Textbox(label="Prompt", elem_id="prompt-in")
        denoise_strength = gr.Slider(label="Denoise strength", minimum=0.6, maximum=0.9, step=0.01, value=0.66)
        duration = gr.Slider(label="Duration", minimum=0.5, maximum=3, step=0.5, value=3)
        submit_btn = gr.Button("Submit")
        video_result = gr.Video(label="Video Output", elem_id="video-output")

    # Inputs are passed to `infer` positionally, so their order must match its
    # signature: (prompt, video_in, denoise_strength, duration, secret_token).
    submit_btn.click(fn=infer,
                    inputs=[prompt_in, video_in, denoise_strength, duration, secret_token],
                    outputs=[video_result],
                    api_name="zero_xl"
                    )

# Cap the request queue at 6 pending events, then start the server.
demo.queue(max_size=6).launch()