import os
import random
import tempfile
import time

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from skimage import io
from torchvision.transforms.functional import normalize

from briarmbg import BriaRMBG

# Square input resolution the frames are resized to before inference.
MODEL_INPUT_SIZE = (1024, 1024)
# RGBA colour painted behind the extracted foreground (solid green).
BACKGROUND_RGBA = (0, 255, 0, 255)
# Stop slightly before the 20-minute GPU limit so the frames processed so far
# can still be returned instead of losing the whole job.
TIME_BUDGET_SECONDS = 20 * 60 - 5
# Used when OpenCV cannot report the input frame rate (it returns 0.0).
FALLBACK_FPS = 30.0

bgrm = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bgrm.to(device)
# Inference only; make sure training-mode layer behaviour is off.
bgrm.eval()
print("device:", device)


def resize_image(image, size=MODEL_INPUT_SIZE):
    """Return *image* converted to RGB and bilinearly resized to *size*.

    Args:
        image: A PIL image in any mode.
        size: Target ``(width, height)``; defaults to the model input size.
    """
    image = image.convert("RGB")
    return image.resize(size, Image.BILINEAR)


def process(image):
    """Replace the background of a single RGB frame with ``BACKGROUND_RGBA``.

    Args:
        image: ``H x W x 3`` uint8 RGB array.

    Returns:
        An RGBA ``PIL.Image`` of the original size. Pixels of *image* are pasted
        over the solid background, weighted by the predicted mask.
    """
    orig_image = Image.fromarray(image)
    w, h = orig_image.size

    # HWC uint8 -> 1xCxHxW float in [0, 1], then normalise with mean 0.5 / std 1.0.
    im_np = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    im_tensor = normalize(im_tensor / 255.0, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    im_tensor = im_tensor.to(device)

    # no_grad: do not build (and keep in memory) an autograd graph per frame.
    with torch.no_grad():
        result = bgrm(im_tensor)

    # Upsample the first prediction back to the frame size, then min-max scale.
    mask = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
    ma = torch.max(mask)
    mi = torch.min(mask)
    if ma > mi:
        mask = (mask - mi) / (ma - mi)
    else:
        # Constant prediction: min-max scaling would divide by zero and the NaN
        # result has an undefined uint8 cast. Keep the raw value instead
        # (assumes the model output is a probability in [0, 1] -- TODO confirm).
        mask = torch.clamp(mask, 0.0, 1.0)

    im_array = (mask * 255).detach().cpu().numpy().astype(np.uint8)
    pil_mask = Image.fromarray(np.squeeze(im_array))

    # Paste the original frame through the mask onto the solid background.
    new_im = Image.new("RGBA", pil_mask.size, BACKGROUND_RGBA)
    new_im.paste(orig_image, mask=pil_mask)
    return new_im


def process_video(video, progress=gr.Progress()):
    """Remove the background from every frame of *video* and write a new MP4.

    Args:
        video: Path to the input video, as supplied by the Gradio video input.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Path to a per-call temporary ``.mp4`` file with the processed frames.
        If ``TIME_BUDGET_SECONDS`` is exceeded, the frames processed so far
        are returned.

    Raises:
        gr.Error: If no frame could be processed.
    """
    cap = cv2.VideoCapture(video)
    writer = None
    # Unique output file per request: a fixed name would let concurrent
    # requests overwrite each other's results.
    fd, tmpname = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    try:
        # The frame count can be 0 or approximate for some containers.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or FALLBACK_FPS
        start_time = time.time()
        processed_frames = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if time.time() - start_time >= TIME_BUDGET_SECONDS:
                print("GPU Timeout is coming")
                break

            # OpenCV decodes frames as BGR; the model pipeline works in RGB.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if writer is None:
                height, width = rgb.shape[:2]
                writer = cv2.VideoWriter(
                    tmpname, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
                )

            processed_frames += 1
            print(f"Processing frame {processed_frames}")
            if total_frames > 0:
                progress(
                    min(processed_frames / total_frames, 1.0),
                    desc=f"Processing frame {processed_frames}/{total_frames}",
                )

            out = process(rgb)
            # `out` is RGBA; VideoWriter expects 3-channel BGR.
            writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_RGBA2BGR))
    finally:
        # Always release native handles, even if processing a frame raised.
        cap.release()
        if writer is not None:
            writer.release()

    if writer is None:
        os.remove(tmpname)
        raise gr.Error("No frames could be processed from the video.")
    return tmpname


title = "🎞️ Video Background Removal Tool 🎥"
description = """Please note that if your video file is long (has a high number of frames), processing may stop early because of the GPU time limit; in that case only the frames processed so far are returned. Consider trimming long videos before uploading."""

examples = [['./input.mp4']]

iface = gr.Interface(
    fn=process_video,
    inputs=["video"],
    outputs="video",
    examples=examples,
    title=title,
    description=description,
)
iface.launch()