File size: 4,338 Bytes
a72119e
 
 
 
 
 
496112d
 
 
 
8365126
 
 
a72119e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de54836
4902bd9
496112d
 
 
4902bd9
 
 
 
 
 
 
 
 
 
 
496112d
4902bd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8365126
 
4902bd9
 
 
 
8365126
4902bd9
 
 
 
 
 
 
 
de54836
2189235
a72119e
 
 
 
 
 
 
4902bd9
 
2189235
 
4902bd9
 
496112d
4902bd9
 
496112d
2189235
4902bd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2189235
4902bd9
 
 
2189235
a72119e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import moviepy.editor as mp
from pydub import AudioSegment
from PIL import Image
import numpy as np
import os
import tempfile
import uuid

# Trade a little float32 matmul precision for TF32 speed on Ampere+ GPUs
# ("high" selected from the two allowed non-default modes).
torch.set_float32_matmul_precision(["high", "highest"][0])

# BiRefNet: dichotomous image segmentation model used to predict the
# foreground mask for every video frame. Loaded once at import time and
# pinned to the GPU (this module assumes a CUDA device is available).
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")
# Preprocessing pipeline matching BiRefNet's expected input:
# fixed 1024x1024 resize, tensor conversion, ImageNet mean/std normalization.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

@spaces.GPU
def fn(vid, fps, color, progress=gr.Progress()):
    """Replace the background of every frame of a video.

    Generator used as a Gradio event handler: it periodically yields live
    preview images while processing, then yields the two finished video
    file paths on the final step.

    Args:
        vid: Path to the input video file.
        fps: Output frame rate used both for frame extraction and encoding.
        color: Background color as a hex string, e.g. "#00FF00"
            (the value produced by ``gr.ColorPicker``).
        progress: Gradio progress tracker (injected by Gradio).

    Yields:
        (preview_no_bg, preview_with_bg, no_bg_video_path, out_video_path) —
        the two paths are None until the final yield; the two previews are
        None on the final yield.
    """
    # Load the video and keep its audio track for re-attachment later.
    video = mp.VideoFileClip(vid)
    audio = video.audio
    frames = list(video.iter_frames(fps=fps))
    total_frames = len(frames)

    # BUG FIX: `color` is a hex *string*; the original code did
    # `color + (255,)` (str + tuple -> TypeError). Parse "#RRGGBB" into an
    # RGB tuple once, outside the loop, the same way `process` does.
    color_rgb = tuple(int(color[i:i + 2], 16) for i in (1, 3, 5))

    processed_frames_no_bg = []
    processed_frames_changed_bg = []

    for idx, frame in enumerate(progress.tqdm(frames)):
        pil_image = Image.fromarray(frame)
        # NOTE(review): `process` already composites the frame over the
        # chosen color, so `processed_image` is not truly background-free;
        # the "no bg" output therefore also shows the color — confirm intent.
        processed_image, mask = process(pil_image, color)

        processed_frames_no_bg.append(np.array(processed_image))

        background = Image.new("RGBA", pil_image.size, color_rgb + (255,))
        composed_image = Image.composite(pil_image, background, mask)
        processed_frames_changed_bg.append(np.array(composed_image))

        # Push a live preview every 10 frames and on the last frame.
        if idx % 10 == 0 or idx == total_frames - 1:
            yield np.array(processed_image), np.array(composed_image), None, None

    # Encode the processed frame sequences, re-attaching the original audio.
    temp_dir = "temp"
    os.makedirs(temp_dir, exist_ok=True)

    processed_video = mp.ImageSequenceClip(processed_frames_changed_bg, fps=fps)
    processed_video = processed_video.set_audio(audio)
    temp_filepath = os.path.join(temp_dir, f"{uuid.uuid4()}.mp4")
    processed_video.write_videofile(temp_filepath, codec="libx264")

    # VP8/WebM for the "no background" variant (supports an alpha channel).
    processed_video_no_bg = mp.ImageSequenceClip(processed_frames_no_bg, fps=fps)
    processed_video_no_bg = processed_video_no_bg.set_audio(audio)
    temp_filepath_no_bg = os.path.join(temp_dir, f"{uuid.uuid4()}.webm")
    processed_video_no_bg.write_videofile(temp_filepath_no_bg, codec="libvpx")

    # Final yield: clear the previews, deliver the finished video paths.
    yield None, None, temp_filepath_no_bg, temp_filepath

def process(image, color_hex):
    """Segment the foreground of *image* and place it on a solid backdrop.

    Runs BiRefNet on the (resized, normalized) image to predict a
    foreground mask, then composites the original image over a background
    filled with *color_hex*.

    Args:
        image: Input ``PIL.Image``.
        color_hex: Background color as "#RRGGBB".

    Returns:
        Tuple of (composited image, grayscale foreground mask), both at
        the original image size.
    """
    original_size = image.size

    # Inference on GPU; only the final prediction head's output is used.
    batch = transform_image(image).unsqueeze(0).to("cuda")
    with torch.no_grad():
        prediction = birefnet(batch)[-1].sigmoid().cpu()
    mask = transforms.ToPILImage()(prediction[0].squeeze()).resize(original_size)

    # "#RRGGBB" -> (r, g, b), then build an opaque backdrop of that color.
    rgb = tuple(int(color_hex[pos:pos + 2], 16) for pos in (1, 3, 5))
    backdrop = Image.new("RGBA", original_size, rgb + (255,))

    return Image.composite(image, backdrop, mask), mask

def change_color(in_video, fps_slider, color_picker):
    """Toggle component visibility: show the first two outputs, hide the rest.

    The input arguments are accepted only to match the event signature;
    they are not used.
    """
    visibility = (True, True, False, False)
    return tuple(gr.update(visible=flag) for flag in visibility)

# Gradio UI: input video + two result slots, a (hidden) live-preview row,
# FPS/color controls, and a single submit button wired to the generator `fn`.
with gr.Blocks() as demo:
    with gr.Row():
        in_video = gr.Video(label="Input Video")
        no_bg_video = gr.Video(label="No BG Video", visible=True)
        out_video = gr.Video(label="Output Video", visible=True)
    
    # NOTE(review): the row is created hidden and nothing ever sets it
    # visible (`change_color` is defined but never wired) — confirm whether
    # the previews are meant to appear during processing.
    with gr.Row(visible=False) as preview_row:
        preview_no_bg = gr.Image(label="Live Preview (No Background)", visible=True)
        preview_with_bg = gr.Image(label="Live Preview (With Background)", visible=True)
    
    with gr.Row():
        fps_slider = gr.Slider(minimum=1, maximum=60, step=1, value=12, label="Output FPS")
        color_picker = gr.ColorPicker(label="Background Color", value="#00FF00")
    
    submit_button = gr.Button("Change Background")
    
    # Handle visibility changes and processing
    # `fn` is a generator, so Gradio streams each yielded tuple into the
    # four outputs (previews during processing, video paths at the end).
    submit_button.click(
        fn=fn,
        inputs=[in_video, fps_slider, color_picker],
        outputs=[preview_no_bg, preview_with_bg, no_bg_video, out_video]
    )

if __name__ == "__main__":
    # show_error surfaces handler exceptions in the UI instead of hiding them.
    demo.launch(show_error=True)