import os
import tempfile
import traceback
from fractions import Fraction

import torch
import numpy as np
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import cv2
import matplotlib.pyplot as plt
import ffmpeg

from utils import load_model_without_flash_attn

# CUDA optimizations.
# NOTE: PyTorch autocast state is thread-local, so this module-wide bfloat16 context only
# covers code running on the importing thread; the SAM2 calls below open their own
# autocast contexts so they do not depend on which thread the UI callback runs on.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # Enable TF32 matmuls/convolutions on compute capability >= 8 GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Initialize models
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
image_predictor = SAM2ImagePredictor(sam2_model)

model_id = 'microsoft/Florence-2-large'
device = "cuda" if torch.cuda.is_available() else "cpu"
# Florence-2 runs in half precision on GPU and full precision on CPU; run_florence casts
# its inputs to this same dtype so model weights and inputs always agree.
florence_dtype = torch.float16 if device == "cuda" else torch.float32

# Florence-2 task token for text-grounded detection. run_florence prepends it to the user
# prompt and reads the parsed result back under the same key.
FLORENCE_TASK_PROMPT = "<CAPTION_TO_PHRASE_GROUNDING>"


def load_florence_model():
    """Load Florence-2 (`model_id`) in eval mode on `device` with `florence_dtype` weights."""
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=florence_dtype,
    ).eval().to(device)


florence_model = load_model_without_flash_attn(load_florence_model)
florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)


def apply_color_mask(frame, mask, obj_id):
    """Paint *mask* onto *frame* in a colour chosen by *obj_id*.

    Args:
        frame: H x W x 3 image array (the RGB frames decoded in process_video).
        mask: object mask. 4-D masks are squeezed and a leading singleton channel
            (1, h, w) is dropped; the result is resized to the frame's H x W.
        obj_id: integer object id; selects one of the 10 "tab10" colours (cycled).

    Returns:
        Floating-point array shaped like *frame*. Where the mask is 1 the pixel becomes
        the object colour (0-255 scale); where it is 0 the pixel is unchanged; fractional
        values produced by bilinear resizing blend the two linearly.
    """
    cmap = plt.get_cmap("tab10")
    color = np.array(cmap(obj_id % 10)[:3])  # RGB in [0, 1]; modulo cycles the 10 colours

    # Normalise the mask to 2-D.
    if mask.ndim == 4:
        mask = mask.squeeze()  # drop singleton dimensions
    if mask.ndim == 3 and mask.shape[0] == 1:
        mask = mask[0]  # (1, h, w) -> (h, w)

    # cv2.resize takes the target size as (width, height).
    mask = cv2.resize(mask.astype(np.float32), (frame.shape[1], frame.shape[0]),
                      interpolation=cv2.INTER_LINEAR)

    # Add a channel axis to the mask and shape the colour so both broadcast over RGB.
    mask = np.expand_dims(mask, axis=2)
    color = color.reshape(1, 1, 3)
    colored_mask = mask * color
    return frame * (1 - mask) + colored_mask * 255


def run_florence(image, text_input):
    """Ground *text_input* in *image* with Florence-2 and return the first matching box.

    Args:
        image: PIL image to search.
        text_input: free-text phrase describing the object (e.g. "a gymnast").

    Returns:
        The first entry of the parsed 'bboxes' list. It is presumably [x1, y1, x2, y2]
        in pixel coordinates, since the image size is passed to post-processing
        (TODO confirm against the Florence-2 processor).

    Raises:
        ValueError: if Florence-2 returns no box for the phrase.
    """
    task_prompt = FLORENCE_TASK_PROMPT
    prompt = task_prompt + text_input
    # Move inputs to the model's device and dtype. BatchFeature.to only casts floating
    # tensors, so input_ids stay integer.
    inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(device, florence_dtype)
    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = florence_processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )
    bboxes = parsed_answer.get(task_prompt, {}).get('bboxes') or []
    if not bboxes:
        raise ValueError(f"Florence-2 found no region matching prompt {text_input!r}")
    return bboxes[0]


def remove_directory_contents(directory):
    """Delete every file and subdirectory below *directory*, keeping *directory* itself.

    Walks bottom-up (topdown=False) so each folder is already empty when os.rmdir runs.
    """
    for root, dirs, files in os.walk(directory, topdown=False):
        for name in files:
            os.remove(os.path.join(root, name))
        for name in dirs:
            os.rmdir(os.path.join(root, name))


def _probe_video(video_path):
    """Return (width, height, num_frames, fps) of the first video stream in *video_path*.

    num_frames is informational only; it is 'unknown' when the container does not
    report it.
    """
    probe = ffmpeg.probe(video_path)
    video_info = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None)
    if video_info is None:
        raise ValueError(f"No video stream found in {video_path}")
    width = int(video_info['width'])
    height = int(video_info['height'])
    # Not every container reports nb_frames (e.g. WebM/MKV).
    num_frames = video_info.get('nb_frames', 'unknown')
    # r_frame_rate is a rational string such as "30000/1001". Parse it rather than
    # eval()-ing metadata that comes from a user-uploaded file.
    fps = float(Fraction(video_info['r_frame_rate']))
    return width, height, num_frames, fps


def _read_frames(video_path, width, height):
    """Decode every frame of *video_path* into a uint8 array of shape (N, height, width, 3)."""
    out, _ = (
        ffmpeg
        .input(video_path)
        .output('pipe:', format='rawvideo', pix_fmt='rgb24')
        .run(capture_stdout=True)
    )
    return np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])


def _segment_first_frame(first_frame, prompt):
    """Box the prompted object with Florence-2, then turn that box into a boolean SAM2 mask."""
    mask_box = run_florence(first_frame, prompt)
    print("Original mask box:", mask_box)

    mask_box = np.array(mask_box)
    print("Reshaped mask box:", mask_box)

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        image_predictor.set_image(first_frame)
        masks, _, _ = image_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=mask_box[None, :],
            multimask_output=False,
        )
    print("masks.shape", masks.shape)
    mask = masks.squeeze().astype(bool)
    print("Mask shape:", mask.shape)
    return mask


def _propagate_masks(frames, first_mask):
    """Track *first_mask* (given for frame 0) through *frames* with the SAM2 video predictor.

    Returns:
        dict mapping frame index -> {obj_id: boolean mask array}.
    """
    # Use a fresh, uniquely named folder per call. A fixed shared folder could collide
    # with concurrent requests, and if an earlier run crashed before cleanup it would
    # still hold stale JPEGs that init_state (which presumably loads every JPEG in the
    # folder) would treat as extra frames. The folder is removed even on error.
    with tempfile.TemporaryDirectory(prefix="temp_frames_") as temp_dir:
        for i, frame in enumerate(frames):
            Image.fromarray(frame).save(os.path.join(temp_dir, f"{i:04d}.jpg"))
        print(f"Saved {len(frames)} temporary frames")

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            inference_state = video_predictor.init_state(video_path=temp_dir)
            video_predictor.add_new_mask(
                inference_state=inference_state,
                frame_idx=0,
                obj_id=1,
                mask=first_mask,
            )
            video_segments = {}
            for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
                # Positive logits mark the object's pixels.
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }
    print('Segmenting for main vid done')
    return video_segments


def _overlay_masks(frames, video_segments):
    """Return *frames* with each tracked object's mask painted on; segmented frames become uint8."""
    segmented = []
    for i, frame in enumerate(frames):
        if i in video_segments:
            for out_obj_id, obj_mask in video_segments[i].items():
                frame = apply_color_mask(frame, obj_mask, out_obj_id)
            segmented.append(frame.astype(np.uint8))
        else:
            segmented.append(frame)
    return segmented


def _write_video(frames, width, height, fps, output_path):
    """Pipe raw RGB24 *frames* through ffmpeg into *output_path* (yuv420p).

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status.
        ValueError: if ffmpeg reports success but the file is missing.
    """
    process = (
        ffmpeg
        .input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{width}x{height}', r=fps)
        # yuv420p subsamples chroma 2x2, so pad odd widths/heights up to even values
        # rather than letting the encoder reject the stream.
        .output(output_path, pix_fmt='yuv420p', vf='pad=ceil(iw/2)*2:ceil(ih/2)*2')
        .overwrite_output()
        .run_async(pipe_stdin=True)
    )
    try:
        for frame in frames:
            process.stdin.write(frame.tobytes())
    finally:
        # Always close stdin and reap the process, even if a write fails.
        process.stdin.close()
        process.wait()
    # A fixed output path may still hold an older file, so existence alone does not
    # prove this encode succeeded; check ffmpeg's exit status first.
    if process.returncode != 0:
        raise RuntimeError(f"ffmpeg exited with status {process.returncode} while writing {output_path}")
    if not os.path.exists(output_path):
        raise ValueError(f"Output video file was not created: {output_path}")


def process_video(video_path, prompt):
    """Segment the object described by *prompt* throughout the video at *video_path*.

    Pipeline: probe metadata -> decode frames -> Florence-2 box on frame 0 -> SAM2 mask
    -> SAM2 video propagation -> colour overlay -> re-encode with ffmpeg.

    Returns:
        "segmented_video.mp4" (relative to the working directory) on success, or None
        if any step fails; the error and its traceback are printed.
    """
    try:
        width, height, num_frames, fps = _probe_video(video_path)
        print(f"Video info: {width}x{height}, {num_frames} frames, {fps} fps")

        frames = _read_frames(video_path, width, height)
        print(f"Read {len(frames)} frames")
        if len(frames) == 0:
            raise ValueError(f"No frames could be decoded from {video_path}")

        first_frame = Image.fromarray(frames[0])
        mask = _segment_first_frame(first_frame, prompt)
        print("Frame shape:", frames[0].shape)

        video_segments = _propagate_masks(frames, mask)
        print(f"Number of segmented frames: {len(video_segments)}")

        all_segmented_frames = _overlay_masks(frames, video_segments)
        print(f"Applied masks to {len(all_segmented_frames)} frames")

        output_path = "segmented_video.mp4"
        _write_video(all_segmented_frames, width, height, fps, output_path)
        print(f"Successfully created output video: {output_path}")
        return output_path
    except Exception as e:
        # Top-level UI boundary: report the failure and return None so the output stays empty.
        print(f"Error in process_video: {str(e)}")
        print(traceback.format_exc())  # full stack trace
        return None


def segment_video(video_file, prompt):
    """Gradio callback: return the segmented video's path, or None if an input is missing."""
    if video_file is None:
        return None
    if not prompt or not prompt.strip():
        # With no phrase there is nothing to ground, so skip the expensive pipeline.
        return None
    return process_video(video_file, prompt)


demo = gr.Interface(
    fn=segment_video,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Textbox(label="Enter prompt (e.g., 'a gymnast')"),
    ],
    outputs=gr.Video(label="Segmented Video"),
    title="Video Object Segmentation with Florence and SAM2",
    description="Upload a video and provide a text prompt to segment a specific object throughout the video.",
)

demo.launch()