jbilcke-hf's picture
jbilcke-hf HF staff
Update app.py
a0b90fa
raw
history blame
3.64 kB
import gradio as gr
import numpy as np
import os
from PIL import Image
import cv2
from moviepy.editor import VideoFileClip
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import export_to_video
# Shared secret that callers must send with each request; read from the
# SECRET_TOKEN environment variable.
# NOTE(review): when the variable is unset this falls back to the literal
# 'default_secret' -- confirm deployments always set it.
SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
# Markdown text describing the purpose of this space.
DESCRIPTION = 'This space is an API service meant to be used by VideoChain and VideoQuest.\nWant to use this space for yourself? Please use the original code: [https://huggingface.co/spaces/fffiloni/zeroscope-XL](https://huggingface.co/spaces/fffiloni/zeroscope-XL)'
# Load the zeroscope v2 XL pipeline in half precision from a pinned revision.
pipe_xl = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_XL", torch_dtype=torch.float16, revision="refs/pr/17")
# Turn on sliced VAE processing.
pipe_xl.vae.enable_slicing()
# Replace the default scheduler with DPM-Solver multistep, reusing the existing scheduler config.
pipe_xl.scheduler = DPMSolverMultistepScheduler.from_config(pipe_xl.scheduler.config)
pipe_xl.enable_model_cpu_offload()
# NOTE(review): the diffusers documentation advises against moving a pipeline
# to CUDA manually when model CPU offload is enabled -- confirm this call is
# intentional before changing it.
pipe_xl.to("cuda")
def convert_mp4_to_frames(video_path, duration=3, max_duration=3):
    """Decode the leading frames of a video file into an RGB numpy array.

    Args:
        video_path: Path of the video file to open with OpenCV.
        duration: Requested length, in seconds, of the clip to extract.
        max_duration: Upper bound, in seconds, applied to ``duration``.
            Note: we cannot go beyond 3 seconds on the large A10G, hence
            the default.

    Returns:
        A numpy array built from the list of extracted RGB frames. The array
        is empty when no frame could be read or the effective duration is
        not positive.
    """
    video = cv2.VideoCapture(video_path)
    try:
        # Frames per second as reported by the container.
        fps = video.get(cv2.CAP_PROP_FPS)
        # Clamp to zero so a negative duration cannot disable the frame cap.
        num_frames = max(0, int(fps * min(duration, max_duration)))

        frames = []
        # Check the budget before reading so no frame is decoded needlessly.
        while len(frames) < num_frames:
            ret, frame = video.read()
            if not ret:
                # End of stream or decode failure: keep what we have so far.
                break
            # OpenCV yields BGR; convert to RGB before storing.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always release the capture handle, even if a conversion fails.
        video.release()

    return np.array(frames)
def infer(prompt, video_in, denoise_strength, secret_token):
    """Run the zeroscope XL video-to-video pipeline on an uploaded clip.

    Args:
        prompt: Text prompt guiding the generation.
        video_in: Path of the input video, as provided by the video component.
        denoise_strength: Value forwarded to the pipeline as ``strength``.
        secret_token: Token supplied by the caller; must equal SECRET_TOKEN.

    Returns:
        A tuple of the generated video's path and a ``gr.Group`` visibility
        update.

    Raises:
        gr.Error: If the token is wrong, or no usable input video was given.
    """
    if secret_token != SECRET_TOKEN:
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    if not video_in:
        raise gr.Error('Please provide an input video.')

    # The duration slider is disabled in the UI, so use its former default;
    # previously `duration` was an undefined name here and raised NameError.
    duration = 3

    negative_prompt = "text, watermark, copyright, blurry, cropped, noisy, pixelated, nsfw"

    video = convert_mp4_to_frames(video_in, duration)
    if len(video) == 0:
        # Fail early with a readable message instead of deep inside the pipeline.
        raise gr.Error('Could not read any frame from the input video.')

    # Resize every frame to 1024x576 before handing it to the pipeline.
    video_resized = [Image.fromarray(frame).resize((1024, 576)) for frame in video]
    video_frames = pipe_xl(
        prompt,
        negative_prompt=negative_prompt,
        video=video_resized,
        strength=denoise_strength,
    ).frames
    video_path = export_to_video(video_frames, output_video_path="xl_result.mp4")

    # NOTE(review): the click handler in this file wires a single output
    # component; the second value appears unused -- confirm before removing.
    return video_path, gr.Group.update(visible=True)
# Gradio UI / API definition: one form that forwards its values to `infer`.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Column():
        secret_token = gr.Text(label='Secret Token', max_lines=1)
        video_in = gr.Video(type="numpy", source="upload")
        prompt_in = gr.Textbox(label="Prompt", elem_id="prompt-in")
        denoise_strength = gr.Slider(label="Denoise strength", minimum=0.6, maximum=0.9, step=0.01, value=0.66)
        # Disabled controls, kept for reference:
        #duration = gr.Slider(label="Duration", minimum=0.5, maximum=3, step=0.5, value=3)
        #inference_steps = gr.Slider(label="Inference Steps", minimum=7, maximum=100, step=1, value=40, interactive=False)
        submit_btn = gr.Button("Submit")
        video_result = gr.Video(label="Video Output", elem_id="video-output")

    # The inputs must match infer's four parameters, in order. `duration` was
    # listed here while its slider is commented out, which raised NameError
    # at startup.
    submit_btn.click(
        fn=infer,
        inputs=[prompt_in, video_in, denoise_strength, secret_token],
        outputs=[video_result],
        api_name="zero_xl",
    )

# Allow at most 6 queued requests, then start the server.
demo.queue(max_size=6).launch()