import re

import einops
import gradio as gr
import matplotlib.cm as cm
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
import torchdiffeq
DESCRIPTION = """
<div class="head">
<div class="title">Fast LiDAR Data Generation with Rectified Flows</div>
<div class="conference">ICRA 2025</div>
<div class="authors">
<a href="https://kazuto1011.github.io/" target="_blank" rel="noopener">Kazuto Nakashima</a><sup>1</sup>

<a>Xiaowen Liu</a><sup>1</sup>

<a>Tomoya Miyawaki</a><sup>1</sup>

<a>Yumi Iwashita</a><sup>2</sup>

<a>Ryo Kurazume</a><sup>1</sup>
</div>
<div class="affiliations">
<sup>1</sup>Kyushu University

<sup>2</sup>NASA Jet Propulsion Laboratory
</div>
<div class="materials">
<a href="https://kazuto1011.github.io/r2flow">Project</a> |
<a href="https://arxiv.org/abs/2412.02241">Paper</a> |
<a href="https://github.com/kazuto1011/r2flow">Code</a>
</div>
<br>
<div class="description">
This is a demo of our paper "Fast LiDAR Data Generation with Rectified Flows", accepted to ICRA 2025.<br>
We propose <strong>R2Flow</strong>, a rectified flow-based generative model that produces LiDAR range/reflectance images.<br>
</div>
<br>
</div>
"""
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
torch.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True
device = torch.device(device)
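
# Available checkpoints. "k-RF" denotes a k-times rectified flow; "n-TD" denotes
# timestep distillation down to n sampling steps.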
model_dict = {
    "1-RF": "r2flow-kitti360-1rf",
    "2-RF": "r2flow-kitti360-2rf",
    "2-RF + 4-TD": "r2flow-kitti360-2rf-4td",
    "2-RF + 2-TD": "r2flow-kitti360-2rf-2td",
    "2-RF + 1-TD": "r2flow-kitti360-2rf-1td",
}
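
# Arguments shared by every torch.hub.load() call; each entry of model_dict
# selects a pretrained config from the official r2flow repository.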
torch_hub_kwargs = dict(
    repo_or_dir="kazuto1011/r2flow",
    model="pretrained_r2flow",
    device=device,
    show_info=False,
)
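
# Map values in [0, 1] to RGB uint8 colors via a 256-entry lookup table
# built from a matplotlib colormap.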
def colorize(tensor: torch.Tensor, cmap_fn=cm.turbo):
    colors = cmap_fn(np.linspace(0, 1, 256))[:, :3]
    colors = torch.from_numpy(colors).to(tensor)
    tensor = tensor.squeeze(1) if tensor.ndim == 4 else tensor
    ids = (tensor * 256).clamp(0, 255).long()
    tensor = F.embedding(ids, colors).permute(0, 3, 1, 2)
    tensor = tensor.mul(255).clamp(0, 255).byte()
    return tensor
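
# Wrap the model so that each ODE function evaluation ticks the Gradio progress bar.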
def model_verbose(model, nfe, progress):
    handler = progress.tqdm(range(nfe), desc="Generating...")

    def _model(t, x):
        handler.update(1)
        return model(t, x)

    return _model
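
# Generate one LiDAR sample: integrate the flow ODE from Gaussian noise (t=0)
# toward data (t=1), then decode the result into depth/reflectance images and
# a 3D point cloud.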
def generate(nfe: int, solver: str, phase: str, progress=gr.Progress()):
    model, lidar_utils, _ = torch.hub.load(config=model_dict[phase], **torch_hub_kwargs)
    with torch.inference_mode():
        x1 = torchdiffeq.odeint(
            func=model_verbose(model, int(nfe), progress),
            y0=torch.randn(1, model.in_channels, *model.resolution, device=device),
            t=torch.linspace(0, 1, int(nfe) + 1, device=device),
            method=solver,
        )[-1]
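        # Channel 0 of the sample is normalized depth, channel 1 is normalized reflectance.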
        depth = lidar_utils.restore_metric_depth(x1[:, [0]])
        rflct = lidar_utils.denormalize(x1[:, [1]])
        point = lidar_utils.convert_metric_depth(depth, format="cartesian")
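        # Color the points by height, clipped to a fixed z window (presumably meters).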
        z_min, z_max = -2, 0.5
        z = (point[:, [2]] - z_min) / (z_max - z_min)
        color = colorize(z.clamp(0, 1), cm.viridis) / 255
        point = einops.rearrange(point, "1 c h w -> (h w) c").cpu().numpy()
        color = einops.rearrange(color, "1 c h w -> (h w) c").cpu().numpy()
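        # Build an interactive Plotly scatter of the point cloud; the x/y signs
        # are flipped for display and all axes are hidden.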
        fig = go.Figure(
            data=[
                go.Scatter3d(
                    x=-point[..., 0],
                    y=-point[..., 1],
                    z=point[..., 2],
                    mode="markers",
                    marker=dict(size=1, color=color),
                )
            ],
            layout=dict(
                scene=dict(
                    xaxis=dict(showticklabels=False, visible=False),
                    yaxis=dict(showticklabels=False, visible=False),
                    zaxis=dict(showticklabels=False, visible=False),
                    aspectmode="data",
                ),
                margin=dict(l=0, r=0, b=0, t=0),
                paper_bgcolor="white",
                plot_bgcolor="white",
            ),
        )
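        # Normalize depth by the sensor's maximum range and colorize both 2D views.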
        depth = depth / lidar_utils.max_depth
        depth = colorize(depth, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()
        rflct = colorize(rflct, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()
    model.cpu()
    lidar_utils.cpu()
    return depth, rflct, fig
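
# Distilled (TD) checkpoints run a fixed number of Euler steps, so the solver and
# step-count dropdowns are locked; plain RF checkpoints also allow the adaptive
# dopri5 solver and power-of-two step counts.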
def setup_dropdown(value):
    if "TD" in value:
        solver_choices = ["euler"]
        solver_default = "euler"
        num_step = re.findall(r"(\d+)-TD", value)[0]
        nfe_choices = [num_step]
        nfe_default = num_step
    else:
        solver_choices = ["euler", "dopri5"]
        solver_default = "euler"
        nfe_choices = [2**i for i in range(0, 9)]
        nfe_default = 256
    dropdown_solver = gr.Dropdown(
        choices=solver_choices,
        value=solver_default,
        label="ODE solver",
        info="Fixed if TD enabled",
    )
    dropdown_nfe = gr.Dropdown(
        choices=nfe_choices,
        value=nfe_default,
        label="Number of sampling steps",
        info="Fixed if TD enabled",
    )
    return dropdown_solver, dropdown_nfe
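
# Assemble the Gradio UI: model/solver controls on the left, the generated range
# and reflectance images plus the interactive point cloud on the right.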
with gr.Blocks(
    css="""
    .head {
        text-align: center;
        display: block;
        font-size: var(--text-xl);
    }
    .title {
        font-size: var(--text-xxl);
        font-weight: bold;
        margin-top: 2rem;
    }
    .description {
        font-size: var(--text-lg);
    }
    """,
    theme=gr.themes.Ocean(),
) as demo:
    gr.HTML(DESCRIPTION)
    with gr.Row(variant="panel"):
        with gr.Column():
            gr.Textbox(device, label="Running device")
            dropdown_model = gr.Dropdown(
                choices=list(model_dict.keys()),
                value="2-RF + 4-TD",
                label="Model checkpoint",
                info="RF: rectified flow, TD: timestep distillation",
            )
            dropdown_solver, dropdown_nfe = setup_dropdown(dropdown_model.value)
            dropdown_model.change(
                setup_dropdown,
                inputs=[dropdown_model],
                outputs=[dropdown_solver, dropdown_nfe],
            )
            btn = gr.Button(value="Generate", variant="primary")
        with gr.Column():
            range_view = gr.Image(type="numpy", label="Range image")
            rflct_view = gr.Image(type="numpy", label="Reflectance image")
            point_view = gr.Plot(label="Point cloud")
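    # Inputs follow generate()'s signature: (nfe, solver, phase).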
    btn.click(
        generate,
        inputs=[dropdown_nfe, dropdown_solver, dropdown_model],
        outputs=[range_view, rflct_view, point_view],
    )

demo.queue()
demo.launch()