import re

import einops
import gradio as gr
import matplotlib.cm as cm
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
import torchdiffeq

DESCRIPTION = """
Fast LiDAR Data Generation with Rectified Flows
ICRA 2025
Kazuto Nakashima1     Xiaowen Liu1     Tomoya Miyawaki1     Yumi Iwashita2     Ryo Kurazume1
1Kyushu University     2NASA Jet Propulsion Laboratory
Project | Paper | Code

This is a demo of our paper "Fast LiDAR Data Generation with Rectified Flows" accepted to ICRA 2025.
We propose R2Flow, a rectified flow-based LiDAR generative model which generates the LiDAR range/reflectance images.

"""

# Select the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# NOTE: grad mode is thread-local, so this only affects the main thread.
# Gradio runs event handlers in worker threads, which is why `generate`
# additionally wraps all of its tensor work in `torch.inference_mode()`.
torch.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True
device = torch.device(device)

# UI label -> config name passed to the `pretrained_r2flow` hub entry point.
# "k-TD" labels denote timestep-distilled checkpoints; `setup_dropdown` pins
# their step count to k.
model_dict = {
    "1-RF": "r2flow-kitti360-1rf",
    "2-RF": "r2flow-kitti360-2rf",
    "2-RF + 4-TD": "r2flow-kitti360-2rf-4td",
    "2-RF + 2-TD": "r2flow-kitti360-2rf-2td",
    "2-RF + 1-TD": "r2flow-kitti360-2rf-1td",
}

torch_hub_kwargs = dict(
    repo_or_dir="kazuto1011/r2flow",
    model="pretrained_r2flow",
    device=device,
    show_info=False,
)

# Height range mapped linearly onto the point-cloud colormap. Units follow
# `lidar_utils.convert_metric_depth` (presumably metres — TODO confirm).
Z_RANGE = (-2.0, 0.5)


def colorize(tensor: torch.Tensor, cmap_fn=cm.turbo) -> torch.Tensor:
    """Map a normalized scalar image to 8-bit RGB via a matplotlib colormap.

    Args:
        tensor: Values expected in [0, 1], shape (B, H, W) or (B, 1, H, W).
            Out-of-range values are clamped to the colormap ends; NaNs are
            mapped to the lowest color.
        cmap_fn: Matplotlib colormap callable (e.g. ``cm.turbo``).

    Returns:
        uint8 tensor of shape (B, 3, H, W) on the same device as ``tensor``.
    """
    # 256-entry RGB lookup table, cast to the input's dtype/device.
    lut = torch.from_numpy(cmap_fn(np.linspace(0, 1, 256))[:, :3]).to(tensor)
    if tensor.ndim == 4:
        tensor = tensor.squeeze(1)
    # NaN survives clamp() and would turn into an invalid index after long().
    ids = (tensor.nan_to_num(0.0) * 256).clamp(0, 255).long()
    rgb = F.embedding(ids, lut).permute(0, 3, 1, 2)
    return rgb.mul(255).clamp(0, 255).byte()


def model_verbose(model, nfe, progress):
    """Wrap ``model`` as an ODE function that advances a Gradio progress bar.

    Args:
        model: Callable ``model(t, x)`` returning the velocity field.
        nfe: Expected number of function evaluations (progress-bar total).
        progress: ``gr.Progress`` instance of the running event.

    Returns:
        A callable ``(t, x) -> model(t, x)`` that ticks the progress bar once
        per evaluation.

    NOTE: for adaptive solvers (e.g. dopri5) the real number of evaluations
    is decided by the solver, so the bar may not end exactly at ``nfe``.
    """
    handler = progress.tqdm(range(nfe), desc="Generating...")

    def _model(t, x):
        handler.update(1)
        return model(t, x)

    return _model


def _load_model(phase):
    """Load the pretrained model and its LiDAR utilities for a UI label."""
    model, lidar_utils, _ = torch.hub.load(config=model_dict[phase], **torch_hub_kwargs)
    return model, lidar_utils


def _sample(model, nfe, solver, progress):
    """Integrate from Gaussian noise at t=0 to a sample at t=1; return the final state."""
    func = model_verbose(model, nfe, progress)
    x0 = torch.randn(1, model.in_channels, *model.resolution, device=device)
    # nfe + 1 time points -> nfe steps for fixed-grid solvers such as euler.
    t = torch.linspace(0, 1, nfe + 1, device=device)
    return torchdiffeq.odeint(func=func, y0=x0, t=t, method=solver)[-1]


def _point_cloud_figure(point: torch.Tensor) -> go.Figure:
    """Build a 3D scatter plot of a (1, C, H, W) point map colored by height.

    The first three channels of ``point`` are read as x, y, z.
    """
    z_min, z_max = Z_RANGE
    z = (point[:, [2]] - z_min) / (z_max - z_min)
    color = colorize(z.clamp(0, 1), cm.viridis) / 255
    xyz = einops.rearrange(point, "1 c h w -> (h w) c").cpu().numpy()
    rgb = einops.rearrange(color, "1 c h w -> (h w) c").cpu().numpy()
    hidden_axis = dict(showticklabels=False, visible=False)
    return go.Figure(
        data=[
            go.Scatter3d(
                # x and y are negated (a 180-degree turn about z), presumably
                # to orient the default camera view.
                x=-xyz[..., 0],
                y=-xyz[..., 1],
                z=xyz[..., 2],
                mode="markers",
                marker=dict(size=1, color=rgb),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=hidden_axis,
                yaxis=hidden_axis,
                zaxis=hidden_axis,
                aspectmode="data",
            ),
            margin=dict(l=0, r=0, b=0, t=0),
            paper_bgcolor="white",
            plot_bgcolor="white",
        ),
    )


def _to_rgb_image(tensor: torch.Tensor) -> np.ndarray:
    """Colorize a (1, 1, H, W) tensor with turbo into an (H, W, 3) uint8 array."""
    return colorize(tensor, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()


def generate(nfe: int, solver: str, phase: str, progress=gr.Progress()):
    """Sample one LiDAR scan and render it for the UI.

    Args:
        nfe: Number of sampling steps (may arrive from the dropdown as int or str).
        solver: torchdiffeq method name, e.g. "euler" or "dopri5".
        phase: Key of ``model_dict`` selecting the checkpoint.
        progress: Injected by Gradio to report sampling progress.

    Returns:
        Tuple ``(range_image, reflectance_image, point_cloud_figure)``. Both
        images are (H, W, 3) uint8 arrays.
    """
    nfe = int(nfe)
    model, lidar_utils = _load_model(phase)
    try:
        # All tensor work runs here: grad mode set at import time does not
        # carry over to Gradio's worker threads.
        with torch.inference_mode():
            x1 = _sample(model, nfe, solver, progress)
            depth = lidar_utils.restore_metric_depth(x1[:, [0]])
            rflct = lidar_utils.denormalize(x1[:, [1]])
            point = lidar_utils.convert_metric_depth(depth, format="cartesian")
            fig = _point_cloud_figure(point)
            depth_img = _to_rgb_image(depth / lidar_utils.max_depth)
            rflct_img = _to_rgb_image(rflct)
    finally:
        # Move weights off the accelerator even if generation failed.
        model.cpu()
        lidar_utils.cpu()
    return depth_img, rflct_img, fig


def setup_dropdown(value):
    """Build the solver and step-count dropdowns for a checkpoint label.

    For timestep-distilled labels ("... k-TD") the solver is fixed to
    "euler" and the step count to k. Otherwise euler/dopri5 are offered with
    power-of-two step counts from 1 to 256 (default 256).

    Returns:
        ``(dropdown_solver, dropdown_nfe)`` Gradio components. They are used
        both to create the UI and as update values on model change.
    """
    match = re.search(r"(\d+)-TD", value)
    if match:
        num_steps = int(match.group(1))
        solver_choices = ["euler"]
        nfe_choices = [num_steps]
        nfe_default = num_steps
    else:
        solver_choices = ["euler", "dopri5"]
        nfe_choices = [2**i for i in range(0, 9)]
        nfe_default = 256
    dropdown_solver = gr.Dropdown(
        choices=solver_choices,
        value="euler",
        label="ODE solver",
        info="Fixed if TD enabled",
    )
    dropdown_nfe = gr.Dropdown(
        choices=nfe_choices,
        value=nfe_default,
        label="Number of sampling steps",
        info="Fixed if TD enabled",
    )
    return dropdown_solver, dropdown_nfe


with gr.Blocks(
    css="""
    .head {
        text-align: center;
        display: block;
        font-size: var(--text-xl);
    }
    .title {
        font-size: var(--text-xxl);
        font-weight: bold;
        margin-top: 2rem;
    }
    .description {
        font-size: var(--text-lg);
    }
    """,
    theme=gr.themes.Ocean(),
) as demo:
    gr.HTML(DESCRIPTION)

    with gr.Row(variant="panel"):
        with gr.Column():
            gr.Textbox(str(device), label="Running device")
            dropdown_model = gr.Dropdown(
                choices=list(model_dict.keys()),
                value="2-RF + 4-TD",
                label="Model checkpoint",
                info="RF: rectified flow, TD: timestep distillation",
            )
            dropdown_solver, dropdown_nfe = setup_dropdown(dropdown_model.value)
            # Rebuild solver/step choices whenever the checkpoint changes.
            dropdown_model.change(
                setup_dropdown,
                inputs=[dropdown_model],
                outputs=[dropdown_solver, dropdown_nfe],
            )
            btn = gr.Button(value="Generate", variant="primary")
        with gr.Column():
            range_view = gr.Image(type="numpy", label="Range image")
            rflct_view = gr.Image(type="numpy", label="Reflectance image")
            point_view = gr.Plot(label="Point cloud")

    # Input order must match generate(nfe, solver, phase).
    btn.click(
        generate,
        inputs=[dropdown_nfe, dropdown_solver, dropdown_model],
        outputs=[range_view, rflct_view, point_view],
    )

demo.queue()
demo.launch()