import rerun as rr import rerun.blueprint as rrb import depth_pro import subprocess import torch import os import gradio as gr from gradio_rerun import Rerun import spaces from PIL import Image import tempfile import cv2 # Run the script to get pretrained models if not os.path.exists("./checkpoints/depth_pro.pt"): print("downloading pretrained model") subprocess.run(["bash", "get_pretrained_models.sh"]) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Load model and preprocessing transform print("loading model...") model, transform = depth_pro.create_model_and_transforms() model = model.to(device) model.eval() def resize_image(image_buffer, max_size=256): with Image.fromarray(image_buffer) as img: # Calculate the new size while maintaining aspect ratio ratio = max_size / max(img.size) new_size = tuple([int(x * ratio) for x in img.size]) # Resize the image img = img.resize(new_size, Image.LANCZOS) # Create a temporary file with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file: img.save(temp_file, format="PNG") return temp_file.name @spaces.GPU(duration=20) def predict_depth(input_images): results = [depth_pro.load_rgb(image) for image in input_images] images = torch.stack([transform(result[0]) for result in results]) images = images.to(device) # Run inference with torch.no_grad(): prediction = model.infer(images) depth = prediction["depth"] # Depth in [m] focallength_px = prediction["focallength_px"] # Focal length in pixels # Convert depth to numpy array if it's a torch tensor if isinstance(depth, torch.Tensor): depth = depth.cpu().numpy() # Convert focal length to a float if it's a torch tensor if isinstance(focallength_px, torch.Tensor): focallength_px = [focal_length.item() for focal_length in focallength_px] # Ensure depth is a BxHxW tensor if depth.ndim != 2: depth = depth.squeeze() # Clip depth values to 0m - 10m depth = depth.clip(0, 10) return depth, focallength_px @rr.thread_local_stream("rerun_example_ml_depth_pro") def run_rerun(path_to_video): print("video path:", path_to_video) stream = rr.binary_stream() blueprint = rrb.Blueprint( rrb.Vertical( rrb.Spatial3DView(origin="/"), rrb.Horizontal( rrb.Spatial2DView( origin="/world/camera/depth", ), rrb.Spatial2DView(origin="/world/camera/frame"), ), ), collapse_panels=True, ) rr.send_blueprint(blueprint) yield stream.read() video_asset = rr.AssetVideo(path=path_to_video) rr.log("world/video", video_asset, static=True) # Send automatically determined video frame timestamps. frame_timestamps_ns = video_asset.read_frame_timestamps_ns() cap = cv2.VideoCapture(path_to_video) num_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT) fps_video = cap.get(cv2.CAP_PROP_FPS) # limit the number of frames to 10 seconds of video max_frames = min(10 * fps_video, num_frames) free_vram, _ = torch.cuda.mem_get_info(device) free_vram = free_vram / 1024 / 1024 / 1024 # batch size is determined by the amount of free vram batch_size = int(min(min(4, free_vram // 4), max_frames)) # go through all the frames in the video, using the batch size for i in range(0, int(max_frames), batch_size): if i >= max_frames: raise gr.Error("Reached the maximum number of frames to process") frames = [] frame_indices = list(range(i, min(i + batch_size, int(max_frames)))) for _ in range(batch_size): ret, frame = cap.read() if not ret: break frames.append(frame) temp_files = [] try: # Resize the images to make the inference faster temp_files = [resize_image(frame, max_size=256) for frame in frames] depths, focal_lengths = predict_depth(temp_files) for depth, focal_length, frame_idx in zip( depths, focal_lengths, frame_indices ): # find x and y scale factors, which can be applied to image x_scale = depth.shape[1] / frames[0].shape[1] y_scale = depth.shape[0] / frames[0].shape[0] rr.set_time_nanos("video_time", frame_timestamps_ns[frame_idx]) rr.log( "world/camera/depth", rr.DepthImage(depth, meter=1), ) rr.log( "world/camera/frame", rr.VideoFrameReference( timestamp=rr.components.VideoTimestamp( nanoseconds=frame_timestamps_ns[frame_idx] ), video_reference="world/video", ), rr.Transform3D(scale=(x_scale, y_scale, 1)), ) rr.log( "world/camera", rr.Pinhole( focal_length=focal_length, width=depth.shape[1], height=depth.shape[0], principal_point=(depth.shape[1] / 2, depth.shape[0] / 2), camera_xyz=rr.ViewCoordinates.FLU, image_plane_distance=depth.max(), ), ) yield stream.read() except Exception as e: raise gr.Error(f"An error has occurred: {e}") finally: # Clean up the temporary files for temp_file in temp_files: if temp_file and os.path.exists(temp_file): os.remove(temp_file) yield stream.read() with gr.Blocks() as interface: gr.Markdown( """ # DepthPro Rerun Demo [DepthPro](https://huggingface.co/apple/DepthPro) is a fast metric depth prediction model. Simply upload a video to visualize the depth predictions in real-time. High resolution videos will be automatically resized to 256x256 pixels, to speed up the inference and visualize multiple frames. """ ) with gr.Row(): with gr.Column(variant="compact"): video = gr.Video( format="mp4", interactive=True, label="Video", include_audio=False ) visualize = gr.Button("Visualize ML Depth Pro") with gr.Column(): viewer = Rerun( streaming=True, ) visualize.click(run_rerun, inputs=[video], outputs=[viewer]) if __name__ == "__main__": interface.launch()