File size: 3,670 Bytes
fbd466e
 
1c137a4
fbd466e
 
 
 
 
 
 
8024618
 
 
fbd466e
 
 
 
 
 
 
 
 
 
 
 
 
 
0623a6a
 
fbd466e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8024618
 
 
 
 
 
fbd466e
0623a6a
 
fbd466e
 
 
 
 
 
 
8024618
 
 
 
 
fbd466e
8024618
fbd466e
8024618
 
fbd466e
 
 
 
0623a6a
8024618
12e4a4d
 
fbd466e
8024618
fbd466e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import numpy as np
import os
from PIL import Image
import cv2
from moviepy.editor import VideoFileClip
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import export_to_video

# Shared secret that API callers must present; read from the SECRET_TOKEN
# environment variable.
# NOTE(review): when the variable is unset this falls back to a fixed,
# publicly visible default -- make sure it is always set in deployment.
SECRET_TOKEN = os.environ.get('SECRET_TOKEN', 'default_secret')

# Markdown banner rendered at the top of the page.
DESCRIPTION = (
    'This space is an API service meant to be used by VideoChain and VideoQuest.\n'
    'Want to use this space for yourself? Please use the original code: '
    '[https://huggingface.co/spaces/fffiloni/zeroscope-XL](https://huggingface.co/spaces/fffiloni/zeroscope-XL)'
)

# Load the Zeroscope v2 XL pipeline in half precision from the pinned Hub
# revision.
pipe_xl = DiffusionPipeline.from_pretrained(
    "cerspense/zeroscope_v2_XL",
    torch_dtype=torch.float16,
    revision="refs/pr/17",
)
pipe_xl.vae.enable_slicing()
pipe_xl.scheduler = DPMSolverMultistepScheduler.from_config(pipe_xl.scheduler.config)
# Model CPU offload places each sub-model on the GPU only while it runs.
# The pipeline is deliberately NOT moved to CUDA by hand afterwards: doing so
# loads every sub-model onto the GPU at once and gives up the memory savings
# (diffusers warns against exactly that combination).
pipe_xl.enable_model_cpu_offload()

def convert_mp4_to_frames(video_path, duration=3, max_duration=3):
    """Decode the leading frames of a video file into RGB frames.

    Args:
        video_path: Path of the video file, opened with ``cv2.VideoCapture``.
        duration: Requested number of seconds to extract from the start.
        max_duration: Upper bound applied to ``duration``. Defaults to 3
            because longer clips do not fit on the large A10G.

    Returns:
        A numpy array built from the list of decoded frames, each converted
        from BGR to RGB. The array is empty when no frame could be read
        (unreadable file, zero reported fps, or a non-positive duration).
    """
    video = cv2.VideoCapture(video_path)
    try:
        fps = video.get(cv2.CAP_PROP_FPS)

        # Clamp at zero so a negative duration cannot turn into an unbounded
        # read of the whole file.
        num_frames = max(0, int(fps * min(duration, max_duration)))

        frames = []
        # Check the budget before reading so no extra frame is decoded.
        while len(frames) < num_frames:
            ret, frame = video.read()
            if not ret:
                # End of stream (or decode failure) before the budget ran out.
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always free the capture handle, even if decoding raised.
        video.release()

    return np.array(frames)

def infer(prompt, video_in, denoise_strength, duration, secret_token: str = '') -> str:
    """Re-render the start of a video with the Zeroscope XL pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        video_in: Source video; handed to ``convert_mp4_to_frames``, which
            opens it with OpenCV.
        denoise_strength: Forwarded to the pipeline as ``strength``.
        duration: Number of seconds of the source video to process.
        secret_token: Must match ``SECRET_TOKEN`` for the call to proceed.

    Returns:
        Path of the exported MP4 file.

    Raises:
        gr.Error: If the secret token is wrong, or if no frame could be read
            from the input video.
    """
    import hmac  # local import keeps the function self-contained

    # Constant-time comparison so response timing does not leak how much of
    # the token matched. Encode to bytes because compare_digest rejects
    # non-ASCII str arguments.
    token_ok = isinstance(secret_token, str) and hmac.compare_digest(
        secret_token.encode('utf-8'), SECRET_TOKEN.encode('utf-8')
    )
    if not token_ok:
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    negative_prompt = "text, watermark, copyright, blurry, cropped, noisy, pixelated, nsfw"

    frames = convert_mp4_to_frames(video_in, duration)
    if len(frames) == 0:
        # Fail with a readable message instead of an opaque pipeline error.
        raise gr.Error('Could not read any frames from the input video.')

    # Resize every frame to the 1024x576 resolution fed to the XL pipeline.
    video_resized = [Image.fromarray(frame).resize((1024, 576)) for frame in frames]
    video_frames = pipe_xl(prompt, negative_prompt=negative_prompt, video=video_resized, strength=denoise_strength).frames

    # NOTE(review): the output filename is fixed, so overlapping calls would
    # overwrite each other's result.
    # Only the path is returned: the click handler is wired to one output.
    return export_to_video(video_frames, output_video_path="xl_result.mp4")


# Build the UI. The submit handler is also exposed under the API name
# "zero_xl".
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Column():
        secret_token = gr.Text(label='Secret Token', max_lines=1)
        # No `type` argument here: gr.Video does not define one, and the
        # handler passes the value straight to cv2.VideoCapture.
        video_in = gr.Video(source="upload")
        prompt_in = gr.Textbox(label="Prompt", elem_id="prompt-in")
        denoise_strength = gr.Slider(label="Denoise strength", minimum=0.6, maximum=0.9, step=0.01, value=0.66)
        duration = gr.Slider(label="Duration", minimum=0.5, maximum=3, step=0.5, value=3)
        # NOTE: the number of inference steps is not exposed; infer() calls
        # the pipeline without overriding it.
        submit_btn = gr.Button("Submit")
        video_result = gr.Video(label="Video Output", elem_id="video-output")

    submit_btn.click(
        fn=infer,
        inputs=[prompt_in, video_in, denoise_strength, duration, secret_token],
        outputs=[video_result],
        api_name="zero_xl",
    )

# Allow at most 6 queued requests, then start the server.
demo.queue(max_size=6).launch()