import spaces
import gradio as gr
import cv2
import numpy as np
import time
import random
from PIL import Image
import torch
import re

# Disable TorchScript: replace torch.jit.script with a pass-through so any
# function it is applied to stays plain eager Python. This assignment runs
# before `transparent_background` is imported below, so any
# `@torch.jit.script` applied while that package is being imported also gets
# the pass-through. NOTE(review): presumably a workaround for TorchScript
# failing in this environment — confirm.
def _skip_torchscript(obj, *args, **kwargs):
    """Return *obj* unchanged, accepting and ignoring torch.jit.script's optional arguments."""
    return obj


torch.jit.script = _skip_torchscript

from transparent_background import Remover

# Wall-clock seconds to spend on frames before stopping early and returning
# whatever has been written so far (just under 20 minutes).
# NOTE(review): `@spaces.GPU()` below is used without an explicit `duration=`;
# confirm the GPU allocation actually lasts as long as this budget.
_TIME_BUDGET_S = 20 * 60 - 5

# Matches "rgb(r, g, b)" or "rgba(r, g, b, a)" with integer or decimal channel
# values and arbitrary whitespace around the separators.
_RGB_FUNC_RE = re.compile(
    r'rgba?\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)'
    r'\s*(?:,\s*[\d.]+\s*)?\)'
)


def _to_remover_color(color):
    """Convert a color-picker value into the "[r, g, b]" string given to Remover.

    Accepted forms:
      * "#rgb" / "#rgba" shorthand and "#rrggbb" / "#rrggbbaa" (alpha ignored);
      * "rgb(r, g, b)" / "rgba(r, g, b, a)" with integer or decimal channels,
        which are rounded and clamped to 0..255 (alpha ignored).
    Any other non-empty string is returned unchanged (stripped of surrounding
    whitespace) and handed to ``Remover.process`` as-is.

    Raises:
        gr.Error: if no color was supplied or a "#..." value is malformed.
    """
    if not color:
        raise gr.Error("Please choose a background color.")
    color = color.strip()

    if color.startswith('#'):
        hex_digits = color[1:]
        if len(hex_digits) in (3, 4):  # shorthand such as "#0f0"
            hex_digits = ''.join(ch * 2 for ch in hex_digits)
        if len(hex_digits) not in (6, 8):
            raise gr.Error(f"Unrecognized color: {color!r}")
        try:
            rgb = [int(hex_digits[i:i + 2], 16) for i in (0, 2, 4)]
        except ValueError:
            raise gr.Error(f"Unrecognized color: {color!r}") from None
        return str(rgb)

    match = _RGB_FUNC_RE.match(color)
    if match:
        # Some color-picker versions report fractional channel values.
        rgb = [min(255, max(0, round(float(ch)))) for ch in match.groups()]
        return str(rgb)

    return color


@spaces.GPU()
def doo(video, color, mode, progress=gr.Progress()):
    """Replace the background of every frame of *video* and save it as MP4.

    Args:
        video: Path to the input video, as provided by the Gradio "video" input.
        color: Background color from the color picker; see
            ``_to_remover_color`` for the accepted formats.
        mode: 'Fast' uses ``Remover(mode='fast')``; any other value uses the
            default ``Remover()``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Path of the written ``.mp4`` file, relative to the working directory.
        If ``_TIME_BUDGET_S`` runs out, the partially processed video is
        returned.

    Raises:
        gr.Error: if the color is invalid or no frame could be read.
    """
    print(color)
    color = _to_remover_color(color)
    print(color)

    remover = Remover(mode='fast') if mode == 'Fast' else Remover()

    out_path = f"{random.randint(111111111, 999999999)}.mp4"
    cap = cv2.VideoCapture(video)
    writer = None
    try:
        # CAP_PROP_FRAME_COUNT is a container estimate: it may be 0 or inexact.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):  # 0 or NaN would produce an unusable VideoWriter
            fps = 30.0
        processed_frames = 0
        start_time = time.time()

        while cap.isOpened():
            if time.time() - start_time >= _TIME_BUDGET_S:
                print("GPU Timeout is coming")
                break

            ret, frame = cap.read()
            if not ret:
                break

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(frame).convert('RGB')

            if writer is None:
                writer = cv2.VideoWriter(
                    out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, img.size
                )

            processed_frames += 1
            print(f"Processing frame {processed_frames}")
            if total_frames > 0:
                progress(
                    min(processed_frames / total_frames, 1.0),
                    desc=f"Processing frame {processed_frames}/{total_frames}",
                )

            out = remover.process(img, type=color)
            # `img` was RGB, so `out` is written back in OpenCV's BGR order.
            writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_RGB2BGR))
    finally:
        # Release handles even if the remover raises mid-video.
        cap.release()
        if writer is not None:
            writer.release()

    if writer is None:
        raise gr.Error("Could not read any frames from the uploaded video.")
    return out_path

title = "🎞️ Video Background Removal Tool 🎥"
description = """*Please note: if your video is long (has many frames), processing may be cut short by the GPU time limit. In that case, consider trying Fast mode.*"""

# One value per input component, in order: video, background color, mode.
# A complete row keeps example runs (and example caching) from calling `doo`
# with a missing color or mode.
examples = [['./input.mp4', '#00FF00', 'Normal']]

iface = gr.Interface(
    fn=doo,
    inputs=[
        "video",
        gr.ColorPicker(label="Background color", value="#00FF00"),
        gr.Radio(
            ['Normal', 'Fast'],
            label='Select mode',
            value='Normal',
            info='Normal is more accurate, but takes longer. | Fast has lower accuracy so the process will be faster.',
        ),
    ],
    outputs="video",
    examples=examples,
    title=title,
    description=description,
)

if __name__ == "__main__":
    iface.launch()