File size: 3,716 Bytes
a72119e
 
 
 
 
 
496112d
 
 
 
8365126
 
 
a72119e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d754a8
de54836
6d754a8
496112d
 
6d754a8
 
496112d
6d754a8
 
 
 
 
4902bd9
 
6d754a8
496112d
10b6ea4
 
 
6d754a8
 
4902bd9
6d754a8
 
4902bd9
6d754a8
 
 
 
 
 
8365126
6d754a8
 
4902bd9
 
6d754a8
4902bd9
6d754a8
 
 
de54836
2189235
a72119e
 
6d754a8
a72119e
 
 
 
 
6d754a8
 
4902bd9
6d754a8
 
2189235
6d754a8
 
2189235
6d754a8
 
 
 
2189235
4902bd9
 
6d754a8
 
 
4902bd9
 
 
6d754a8
 
2189235
9ac006e
2189235
a72119e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import moviepy.editor as mp
from pydub import AudioSegment
from PIL import Image
import numpy as np
import os
import tempfile
import uuid

# Let float32 matmuls use the faster "high" internal precision setting
# (the original indexed ["high", "highest"][0], which is simply "high").
torch.set_float32_matmul_precision("high")

# BiRefNet segmentation model. trust_remote_code=True executes Python code
# shipped in the model repository, so that repository must be trusted.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# Per-frame preprocessing: fixed 1024x1024 resize, conversion to a tensor,
# then per-channel normalisation with the common ImageNet mean/std values.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


@spaces.GPU
def fn(vid, fps, color):
    """Replace the background of every sampled frame of a video.

    Args:
        vid: Path of the uploaded video (the value of the ``gr.Video`` input),
            or ``None`` when nothing was uploaded.
        fps: Rate at which frames are sampled from ``vid``; both outputs are
            written at this same rate.
        color: Background colour string, forwarded unchanged to ``process``.

    Returns:
        ``(mask_video_path, changed_bg_video_path)``: a WebM (``libvpx``) video
        of the per-frame foreground masks and an MP4 (``libx264``) video with
        the background replaced by ``color``. Both files are written under the
        relative ``temp`` directory and carry the source's audio track.

    Raises:
        gr.Error: If no video was supplied or no frames could be decoded.
    """
    if vid is None:
        raise gr.Error("Please upload a video first.")

    video = mp.VideoFileClip(vid)
    try:
        # May be None for a source without audio; set_audio(None) then simply
        # produces silent outputs.
        audio = video.audio

        processed_frames_no_bg = []
        processed_frames_changed_bg = []
        for frame in video.iter_frames(fps=fps):
            processed_image, mask = process(Image.fromarray(frame), color)
            # The mask is a single-channel image (a 2-D array), but
            # ImageSequenceClip and its ffmpeg writer expect H x W x 3 frames,
            # so expand it to RGB before collecting it.
            processed_frames_no_bg.append(np.array(mask.convert("RGB")))
            processed_frames_changed_bg.append(np.array(processed_image))

        if not processed_frames_changed_bg:
            # ImageSequenceClip cannot be built from an empty sequence.
            raise gr.Error("No frames could be read from the input video.")

        temp_dir = "temp"
        os.makedirs(temp_dir, exist_ok=True)

        # Changed-background video.
        temp_filepath = _write_clip(
            processed_frames_changed_bg,
            fps,
            audio,
            os.path.join(temp_dir, f"{uuid.uuid4()}.mp4"),
            codec="libx264",
        )
        # Mask ("No BG") video.
        temp_filepath_no_bg = _write_clip(
            processed_frames_no_bg,
            fps,
            audio,
            os.path.join(temp_dir, f"{uuid.uuid4()}.webm"),
            codec="libvpx",
        )
    finally:
        # Release the readers held by the source clip (including its audio);
        # both outputs have been fully written by this point.
        video.close()

    return temp_filepath_no_bg, temp_filepath


def _write_clip(frames, fps, audio, path, codec):
    """Encode ``frames`` at ``fps`` with ``audio`` attached to ``path``; return ``path``."""
    clip = mp.ImageSequenceClip(frames, fps=fps).set_audio(audio)
    clip.write_videofile(path, codec=codec)
    return path


def _parse_color(color):
    """Convert a colour string to an ``(r, g, b)`` tuple of ints in 0-255.

    Accepts ``#RRGGBB`` (identical results to the previous slicing code), the
    ``#RGB`` shorthand, ``#RRGGBBAA`` (alpha ignored), the same forms without
    the leading ``#``, and CSS-style ``rgb(r, g, b)`` / ``rgba(r, g, b, a)``.

    Raises:
        ValueError: If ``color`` is not in one of the accepted formats.
    """
    text = color.strip()
    if text.lower().startswith(("rgb(", "rgba(")) and text.endswith(")"):
        parts = text[text.index("(") + 1:-1].split(",")
        if len(parts) < 3:
            raise ValueError(f"Unrecognised colour: {color!r}")
        # Components may be written as floats; clamp to the valid byte range.
        return tuple(max(0, min(255, round(float(p)))) for p in parts[:3])

    digits = text.lstrip("#")
    if len(digits) == 3:
        # "#0f0" -> "00ff00"
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(f"Unrecognised colour: {color!r}")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


def process(image, color_hex):
    """Segment the foreground of ``image`` and composite it over a solid colour.

    Args:
        image: PIL image of a single frame.
        color_hex: Background colour, normally ``"#RRGGBB"`` from the colour
            picker (see ``_parse_color`` for the other accepted forms).

    Returns:
        ``(composited, mask)``: an RGBA image in which the area outside the
        predicted mask is filled with ``color_hex``, and the predicted mask as
        a single-channel PIL image resized to ``image.size``.

    Raises:
        ValueError: If ``color_hex`` cannot be parsed as a colour.
    """
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    # Prediction: take the model's last output and squash it to [0, 1].
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    # Back to a PIL mask at the frame's original resolution.
    mask = transforms.ToPILImage()(pred).resize(image_size)

    background = Image.new("RGBA", image_size, _parse_color(color_hex) + (255,))

    # Image.composite requires both images to share a mode; frames arrive as
    # RGB while the background is RGBA, so convert the frame explicitly.
    composited = Image.composite(image.convert("RGBA"), background, mask)

    return composited, mask  # Return both the processed image and the mask


with gr.Blocks() as demo:
    # Top row: the uploaded clip alongside the two generated videos.
    with gr.Row():
        in_video = gr.Video(label="Input Video")
        no_bg_video = gr.Video(label="No BG Video")
        out_video = gr.Video(label="Output Video")  # changed-background result
    submit_button = gr.Button("Change Background")
    # Bottom row: processing controls.
    with gr.Row():
        fps_slider = gr.Slider(
            minimum=1,
            maximum=60,
            step=1,
            value=12,
            label="Output FPS",
        )
        color_picker = gr.ColorPicker(label="Background Color", value="#00FF00")

    # The outputs list must follow the order of fn's return tuple:
    # (no-background video path, changed-background video path).
    submit_button.click(
        fn=fn,
        inputs=[in_video, fps_slider, color_picker],
        outputs=[no_bg_video, out_video],
    )

if __name__ == "__main__":
    demo.launch(show_error=True)