import sys sys.path.append("../../") import os import json import time import psutil import argparse import cv2 import torch import torchvision import numpy as np import gradio as gr from tools.painter import mask_painter from track_anything import TrackingAnything from model.misc import get_device from utils.download_util import load_file_from_url def parse_augment(): parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default=None) parser.add_argument('--sam_model_type', type=str, default="vit_h") parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications") parser.add_argument('--mask_save', default=False) args = parser.parse_args() if not args.device: args.device = str(get_device()) return args # convert points input to prompt state def get_prompt(click_state, click_input): inputs = json.loads(click_input) points = click_state[0] labels = click_state[1] for input in inputs: points.append(input[:2]) labels.append(input[2]) click_state[0] = points click_state[1] = labels prompt = { "prompt_type":["click"], "input_point":click_state[0], "input_label":click_state[1], "multimask_output":"True", } return prompt # extract frames from upload video def get_frames_from_video(video_input, video_state): """ Args: video_path:str timestamp:float64 Return [[0:nearest_frame], [nearest_frame:], nearest_frame] """ video_path = video_input frames = [] user_name = time.time() operation_log = [("",""),("Video uploaded! Try to click the image shown in step2 to add masks.","Normal")] try: cap = cv2.VideoCapture(video_path) fps = cap.get(cv2.CAP_PROP_FPS) while cap.isOpened(): ret, frame = cap.read() if ret == True: current_memory_usage = psutil.virtual_memory().percent frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) # if current_memory_usage > 90: # operation_log = [("Memory usage is too high (>90%). Stop the video extraction. Please reduce the video resolution or frame rate.", "Error")] # print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.") # break else: break except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e: print("read_frame_source:{} error. {}\n".format(video_path, str(e))) image_size = (frames[0].shape[0],frames[0].shape[1]) # initialize video_state video_state = { "user_name": user_name, "video_name": os.path.split(video_path)[-1], "origin_images": frames, "painted_images": frames.copy(), "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames), "logits": [None]*len(frames), "select_frame_number": 0, "fps": fps } video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size) model.samcontroler.sam_controler.reset_image() model.samcontroler.sam_controler.set_image(video_state["origin_images"][0]) return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \ gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True),\ gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True, choices=[], value=[]), \ gr.update(visible=True, value=operation_log), gr.update(visible=True, value=operation_log) # get the select frame from gradio slider def select_template(image_selection_slider, video_state, interactive_state, mask_dropdown): # images = video_state[1] image_selection_slider -= 1 video_state["select_frame_number"] = image_selection_slider # once select a new template frame, set the image in sam model.samcontroler.sam_controler.reset_image() model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider]) operation_log = [("",""), ("Select tracking start frame {}. Try to click the image to add masks for tracking.".format(image_selection_slider),"Normal")] return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log, operation_log # set the tracking end frame def get_end_number(track_pause_number_slider, video_state, interactive_state): interactive_state["track_end_number"] = track_pause_number_slider operation_log = [("",""),("Select tracking finish frame {}.Try to click the image to add masks for tracking.".format(track_pause_number_slider),"Normal")] return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log, operation_log # use sam to get the mask def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData): """ Args: template_frame: PIL.Image point_prompt: flag for positive or negative button click click_state: [[points], [labels]] """ if point_prompt == "Positive": coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1]) interactive_state["positive_click_times"] += 1 else: coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1]) interactive_state["negative_click_times"] += 1 # prompt for sam model model.samcontroler.sam_controler.reset_image() model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]]) prompt = get_prompt(click_state=click_state, click_input=coordinate) mask, logit, painted_image = model.first_frame_click( image=video_state["origin_images"][video_state["select_frame_number"]], points=np.array(prompt["input_point"]), labels=np.array(prompt["input_label"]), multimask=prompt["multimask_output"], ) video_state["masks"][video_state["select_frame_number"]] = mask video_state["logits"][video_state["select_frame_number"]] = logit video_state["painted_images"][video_state["select_frame_number"]] = painted_image operation_log = [("",""), ("You can try to add positive or negative points by clicking, click Clear clicks button to refresh the image, click Add mask button when you are satisfied with the segment, or click Remove mask button to remove all added masks.","Normal")] return painted_image, video_state, interactive_state, operation_log, operation_log def add_multi_mask(video_state, interactive_state, mask_dropdown): try: mask = video_state["masks"][video_state["select_frame_number"]] interactive_state["multi_mask"]["masks"].append(mask) interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) select_frame, _, _ = show_mask(video_state, interactive_state, mask_dropdown) operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")] except: operation_log = [("Please click the image in step2 to generate masks.", "Error"), ("","")] return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log, operation_log def clear_click(video_state, click_state): click_state = [[],[]] template_frame = video_state["origin_images"][video_state["select_frame_number"]] operation_log = [("",""), ("Cleared points history and refresh the image.","Normal")] return template_frame, click_state, operation_log, operation_log def remove_multi_mask(interactive_state, mask_dropdown): interactive_state["multi_mask"]["mask_names"]= [] interactive_state["multi_mask"]["masks"] = [] operation_log = [("",""), ("Remove all masks. Try to add new masks","Normal")] return interactive_state, gr.update(choices=[],value=[]), operation_log, operation_log def show_mask(video_state, interactive_state, mask_dropdown): mask_dropdown.sort() select_frame = video_state["origin_images"][video_state["select_frame_number"]] for i in range(len(mask_dropdown)): mask_number = int(mask_dropdown[i].split("_")[1]) - 1 mask = interactive_state["multi_mask"]["masks"][mask_number] select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2) operation_log = [("",""), ("Added masks {}. If you want to do the inpainting with current masks, please go to step3, and click the Tracking button first and then Inpainting button.".format(mask_dropdown),"Normal")] return select_frame, operation_log, operation_log # tracking vos def vos_tracking_video(video_state, interactive_state, mask_dropdown): operation_log = [("",""), ("Tracking finished! Try to click the Inpainting button to get the inpainting result.","Normal")] model.cutie.clear_memory() if interactive_state["track_end_number"]: following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] else: following_frames = video_state["origin_images"][video_state["select_frame_number"]:] if interactive_state["multi_mask"]["masks"]: if len(mask_dropdown) == 0: mask_dropdown = ["mask_001"] mask_dropdown.sort() template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1])) for i in range(1,len(mask_dropdown)): mask_number = int(mask_dropdown[i].split("_")[1]) - 1 template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1) video_state["masks"][video_state["select_frame_number"]]= template_mask else: template_mask = video_state["masks"][video_state["select_frame_number"]] fps = video_state["fps"] # operation error if len(np.unique(template_mask))==1: template_mask[0][0]=1 operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")] # return video_output, video_state, interactive_state, operation_error masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask) # clear GPU memory model.cutie.clear_memory() if interactive_state["track_end_number"]: video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images else: video_state["masks"][video_state["select_frame_number"]:] = masks video_state["logits"][video_state["select_frame_number"]:] = logits video_state["painted_images"][video_state["select_frame_number"]:] = painted_images video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video interactive_state["inference_times"] += 1 print("For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(interactive_state["inference_times"], interactive_state["positive_click_times"]+interactive_state["negative_click_times"], interactive_state["positive_click_times"], interactive_state["negative_click_times"])) #### shanggao code for mask save if interactive_state["mask_save"]: if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])): os.makedirs('./result/mask/{}'.format(video_state["video_name"].split('.')[0])) i = 0 print("save mask") for mask in video_state["masks"]: np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask) i+=1 # save_mask(video_state["masks"], video_state["video_name"]) #### shanggao code for mask save return video_output, video_state, interactive_state, operation_log, operation_log # inpaint def inpaint_video(video_state, resize_ratio_number, dilate_radius_number, raft_iter_number, subvideo_length_number, neighbor_length_number, ref_stride_number, mask_dropdown): operation_log = [("",""), ("Inpainting finished!","Normal")] frames = np.asarray(video_state["origin_images"]) fps = video_state["fps"] inpaint_masks = np.asarray(video_state["masks"]) if len(mask_dropdown) == 0: mask_dropdown = ["mask_001"] mask_dropdown.sort() # convert mask_dropdown to mask numbers inpaint_mask_numbers = [int(mask_dropdown[i].split("_")[1]) for i in range(len(mask_dropdown))] # interate through all masks and remove the masks that are not in mask_dropdown unique_masks = np.unique(inpaint_masks) num_masks = len(unique_masks) - 1 for i in range(1, num_masks + 1): if i in inpaint_mask_numbers: continue inpaint_masks[inpaint_masks==i] = 0 # inpaint for videos inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=resize_ratio_number, dilate_radius=dilate_radius_number, raft_iter=raft_iter_number, subvideo_length=subvideo_length_number, neighbor_length=neighbor_length_number, ref_stride=ref_stride_number) # numpy array, T, H, W, 3 video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video return video_output, operation_log, operation_log # generate video after vos inference def generate_video_from_frames(frames, output_path, fps=30): """ Generates a video from a list of frames. Args: frames (list of numpy arrays): The frames to include in the video. output_path (str): The path to save the generated video. fps (int, optional): The frame rate of the output video. Defaults to 30. """ frames = torch.from_numpy(np.asarray(frames)) if not os.path.exists(os.path.dirname(output_path)): os.makedirs(os.path.dirname(output_path)) torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264") return output_path def restart(): operation_log = [("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")] return { "user_name": "", "video_name": "", "origin_images": None, "painted_images": None, "masks": None, "inpaint_masks": None, "logits": None, "select_frame_number": 0, "fps": 30 }, { "inference_times": 0, "negative_click_times" : 0, "positive_click_times": 0, "mask_save": args.mask_save, "multi_mask": { "mask_names": [], "masks": [] }, "track_end_number": None, }, [[],[]], None, None, None, \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), "", \ gr.update(visible=True, value=operation_log), gr.update(visible=False, value=operation_log) # args, defined in track_anything.py args = parse_augment() pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/' sam_checkpoint_url_dict = { 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" } checkpoint_fodler = os.path.join('..', '..', 'weights') sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_fodler) cutie_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'cutie-base-mega.pth'), checkpoint_fodler) propainter_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'ProPainter.pth'), checkpoint_fodler) raft_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'raft-things.pth'), checkpoint_fodler) flow_completion_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'recurrent_flow_completion.pth'), checkpoint_fodler) # initialize sam, cutie, propainter models model = TrackingAnything(sam_checkpoint, cutie_checkpoint, propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, args) title = r"""