# r2flow / app.py
# Kazuto Nakashima
# init
# 5acffd4
import re
import einops
import gradio as gr
import matplotlib.cm as cm
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
import torchdiffeq
# HTML header (title, venue, authors, affiliations, links, summary) rendered at the
# top of the demo page via gr.HTML.
DESCRIPTION = """
<div class="head">
<div class="title">Fast LiDAR Data Generation with Rectified Flows</div>
<div class="conference">ICRA 2025</div>
<div class="authors">
<a href="https://kazuto1011.github.io/" target="_blank" rel="noopener"> Kazuto Nakashima</a><sup>1</sup>
&nbsp;&nbsp;&nbsp;
<a> Xiaowen Liu</a><sup>1</sup>
&nbsp;&nbsp;&nbsp;
<a> Tomoya Miyawaki</a><sup>1</sup>
&nbsp;&nbsp;&nbsp;
<a> Yumi Iwashita</a><sup>2</sup>
&nbsp;&nbsp;&nbsp;
<a> Ryo Kurazume</a><sup>1</sup>
</div>
<div class="affiliations">
<sup>1</sup>Kyushu University
&nbsp;&nbsp;&nbsp;
<sup>2</sup>NASA Jet Propulsion Laboratory
</div>
<div class="materials">
<a href="https://kazuto1011.github.io/r2flow">Project</a> |
<a href="https://arxiv.org/abs/2412.02241">Paper</a> |
<a href="https://github.com/kazuto1011/r2flow">Code</a>
</div>
<br>
<div class="description">
This is a demo of our paper "Fast LiDAR Data Generation with Rectified Flows" accepted to ICRA 2025.<br>
We propose <strong>R2Flow</strong>, a rectified flow-based LiDAR generative model which generates the LiDAR range/reflectance images.<br>
</div>
<br>
</div>
"""
# Select the compute backend, preferring CUDA, then Apple MPS, then CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
# Inference-only app: disable autograd. NOTE: grad mode is thread-local in PyTorch,
# so this only affects the importing thread; callbacks executed on other threads
# must not rely on it.
torch.set_grad_enabled(False)
# Let cuDNN autotune convolution algorithms (input sizes are fixed per model).
torch.backends.cudnn.benchmark = True
# (UI label, torch.hub config name) pairs. The order here is the order shown in the
# "Model checkpoint" dropdown. RF = rectified flow, TD = timestep distillation; the
# numeric prefixes presumably denote reflow rounds / distilled step counts.
_MODEL_CHOICES = (
    ("1-RF", "r2flow-kitti360-1rf"),
    ("2-RF", "r2flow-kitti360-2rf"),
    ("2-RF + 4-TD", "r2flow-kitti360-2rf-4td"),
    ("2-RF + 2-TD", "r2flow-kitti360-2rf-2td"),
    ("2-RF + 1-TD", "r2flow-kitti360-2rf-1td"),
)
model_dict = dict(_MODEL_CHOICES)
# Keyword arguments shared by every torch.hub.load(...) call; the checkpoint itself
# is selected per call through the separate `config` argument.
torch_hub_kwargs = {
    "repo_or_dir": "kazuto1011/r2flow",
    "model": "pretrained_r2flow",
    "device": device,
    "show_info": False,
}
def colorize(tensor: torch.Tensor, cmap_fn=cm.turbo):
    """Map scalar images to uint8 RGB using a 256-entry matplotlib colormap.

    Args:
        tensor: Batched scalar images, either (B, H, W) or (B, 1, H, W). Values are
            expected in [0, 1]; anything outside is clamped to the ends of the colormap.
        cmap_fn: Colormap callable taking values in [0, 1] and returning RGBA rows.

    Returns:
        A uint8 tensor of shape (B, 3, H, W) on the same device as ``tensor``.
    """
    # Build the lookup table (RGB only, alpha dropped) in the input's dtype/device.
    lut = torch.from_numpy(cmap_fn(np.linspace(0, 1, 256))[:, :3]).to(tensor)
    if tensor.ndim == 4:
        tensor = tensor.squeeze(1)
    # Quantize [0, 1] to integer LUT indices 0..255.
    index = torch.clamp(tensor * 256, 0, 255).long()
    rgb = F.embedding(index, lut)  # (B, H, W, 3)
    rgb = rgb.permute(0, 3, 1, 2)
    return torch.clamp(rgb * 255, 0, 255).byte()
def model_verbose(model, nfe, progress):
    """Wrap an ODE vector field so that every evaluation ticks a progress bar.

    Args:
        model: Callable ``model(t, x)`` used as the ODE right-hand side.
        nfe: Expected number of evaluations; used as the progress bar length.
        progress: Object exposing ``tqdm(iterable, desc=...)`` (e.g. ``gr.Progress``).

    Returns:
        A callable with the same ``(t, x)`` signature that forwards to ``model``.

    NOTE(review): the bar assumes exactly ``nfe`` calls. That holds for a fixed-grid
    Euler solve over ``nfe`` steps. An adaptive solver may call the wrapper a
    different number of times.
    """
    bar = progress.tqdm(range(nfe), desc="Generating...")

    def _wrapped(t, x):
        bar.update(1)
        return model(t, x)

    return _wrapped
def _to_rgb_image(x: torch.Tensor) -> np.ndarray:
    """Colorize a single-item batch with turbo and return it as an (H, W, 3) uint8 array."""
    return colorize(x, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()


def _point_cloud_figure(point: torch.Tensor) -> go.Figure:
    """Build a Plotly 3D scatter of a (1, C, H, W) point map colored by height.

    Channels 0/1/2 of ``point`` are used as x/y/z. The height (z) band
    [z_min, z_max] is spread over the full viridis range; values outside it are clamped.
    """
    z_min, z_max = -2, 0.5
    z = (point[:, [2]] - z_min) / (z_max - z_min)
    color = colorize(z.clamp(0, 1), cm.viridis) / 255  # float RGB in [0, 1]
    xyz = einops.rearrange(point, "1 c h w -> (h w) c").cpu().numpy()
    rgb = einops.rearrange(color, "1 c h w -> (h w) c").cpu().numpy()
    return go.Figure(
        data=[
            go.Scatter3d(
                # x and y are negated, presumably to orient the default camera view.
                x=-xyz[..., 0],
                y=-xyz[..., 1],
                z=xyz[..., 2],
                mode="markers",
                marker=dict(size=1, color=rgb),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(showticklabels=False, visible=False),
                yaxis=dict(showticklabels=False, visible=False),
                zaxis=dict(showticklabels=False, visible=False),
                aspectmode="data",
            ),
            margin=dict(l=0, r=0, b=0, t=0),
            paper_bgcolor="white",
            plot_bgcolor="white",
        ),
    )


def generate(nfe: int, solver: str, phase: str, progress=gr.Progress()):
    """Sample one LiDAR scan and render it as range/reflectance images and a point cloud.

    The checkpoint for ``phase`` is loaded through torch.hub on every call. The
    learned vector field is integrated with torchdiffeq from Gaussian noise at
    t=0 to t=1. Channel 0 of the result is decoded as depth and channel 1 as
    reflectance.

    Args:
        nfe: Number of integration steps (time grid has ``nfe + 1`` points). Cast
            with ``int()`` because the dropdown value may not arrive as an int.
        solver: torchdiffeq method name (this UI offers "euler" and "dopri5").
        phase: Key of ``model_dict`` selecting the checkpoint.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        ``(range_image, reflectance_image, figure)``: two (H, W, 3) uint8 arrays
        and a Plotly figure of the point cloud.
    """
    num_steps = int(nfe)
    model, lidar_utils, _ = torch.hub.load(config=model_dict[phase], **torch_hub_kwargs)
    try:
        # Grad mode is thread-local and Gradio runs handlers in worker threads, so
        # keep the whole computation (not only the ODE solve) under inference mode.
        with torch.inference_mode():
            x1 = torchdiffeq.odeint(
                func=model_verbose(model, num_steps, progress),
                y0=torch.randn(1, model.in_channels, *model.resolution, device=device),
                t=torch.linspace(0, 1, num_steps + 1, device=device),
                method=solver,
            )[-1]  # keep only the state at t = 1
            depth = lidar_utils.restore_metric_depth(x1[:, [0]])
            rflct = lidar_utils.denormalize(x1[:, [1]])
            point = lidar_utils.convert_metric_depth(depth, format="cartesian")
            fig = _point_cloud_figure(point)
            depth_img = _to_rgb_image(depth / lidar_utils.max_depth)
            rflct_img = _to_rgb_image(rflct)
    finally:
        # Release accelerator memory even if sampling or post-processing fails.
        model.cpu()
        lidar_utils.cpu()
    return depth_img, rflct_img, fig
def setup_dropdown(value):
if "TD" in value:
solver_choices = ["euler"]
solver_default = "euler"
num_step = re.findall(r"(\d+)-TD", value)[0]
nfe_choices = [num_step]
nfe_default = num_step
else:
solver_choices = ["euler", "dopri5"]
solver_default = "euler"
nfe_choices = [2**i for i in range(0, 9)]
nfe_default = 256
dropdown_solver = gr.Dropdown(
choices=solver_choices,
value=solver_default,
label="ODE solver",
info="Fixed if TD enabled",
)
dropdown_nfe = gr.Dropdown(
choices=nfe_choices,
value=nfe_default,
label="Number of sampling steps",
info="Fixed if TD enabled",
)
return dropdown_solver, dropdown_nfe
# Build the Gradio UI: a control column (device info, model/solver/step selection,
# Generate button) and an output column (range image, reflectance image, point cloud).
with gr.Blocks(
    css="""
.head {
text-align: center;
display: block;
font-size: var(--text-xl);
}
.title {
font-size: var(--text-xxl);
font-weight: bold;
margin-top: 2rem;
}
.description {
font-size: var(--text-lg);
}
""",
    theme=gr.themes.Ocean(),
) as demo:
    gr.HTML(DESCRIPTION)
    with gr.Row(variant="panel"):
        with gr.Column():
            # Informational only: pass a plain string and make the box read-only.
            gr.Textbox(str(device), label="Running device", interactive=False)
            dropdown_model = gr.Dropdown(
                choices=list(model_dict.keys()),
                value="2-RF + 4-TD",
                label="Model checkpoint",
                info="RF: rectified flow, TD: timestep distillation",
            )
            # Initial solver/step dropdowns match the default model; they are
            # replaced whenever the model selection changes.
            dropdown_solver, dropdown_nfe = setup_dropdown(dropdown_model.value)
            dropdown_model.change(
                setup_dropdown,
                inputs=[dropdown_model],
                outputs=[dropdown_solver, dropdown_nfe],
            )
            btn = gr.Button(value="Generate", variant="primary")
        with gr.Column():
            range_view = gr.Image(type="numpy", label="Range image")
            rflct_view = gr.Image(type="numpy", label="Reflectance image")
            point_view = gr.Plot(label="Point cloud")
    btn.click(
        generate,
        inputs=[dropdown_nfe, dropdown_solver, dropdown_model],
        outputs=[range_view, rflct_view, point_view],
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch()