# r2dm / app.py: Gradio demo for R2DM LiDAR range/reflectance generation
import gradio as gr
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import einops
import plotly.graph_objects as go
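
# Inference only: disable autograd globally and select a GPU if one is available.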
torch.set_grad_enabled(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
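
# Load the pretrained R2DM diffusion model and its LiDAR utilities from torch.hub
# (the third return value is unused here).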
ddpm, lidar_utils, _ = torch.hub.load(
"kazuto1011/r2dm", "pretrained_r2dm", device=device
)
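
# Map a tensor of values in [0, 1] to RGB uint8 images of shape (B, 3, H, W)
# using a matplotlib colormap.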
def colorize(tensor, cmap_fn=cm.turbo):
colors = cmap_fn(np.linspace(0, 1, 256))[:, :3]
colors = torch.from_numpy(colors).to(tensor)
tensor = tensor.squeeze(1) if tensor.ndim == 4 else tensor
ids = (tensor * 256).clamp(0, 255).long()
tensor = F.embedding(ids, colors).permute(0, 3, 1, 2)
tensor = tensor.mul(255).clamp(0, 255).byte()
return tensor
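
# Convert a model sample in [-1, 1] into a colorized range image, a reflectance
# image, and an interactive Plotly point cloud.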
def render_point_cloud(output, cmap):
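    # Undo the [-1, 1] normalization, recover metric depth and reflectance,
    # and project the range image to XYZ points.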
output = lidar_utils.denormalize(output.clamp(-1, 1))
depth = lidar_utils.revert_depth(output[:, [0]])
rflct = output[:, [1]]
point = lidar_utils.to_xyz(depth).cpu().numpy()
point = einops.rearrange(point, "1 c h w -> c (h w)")
angle = lidar_utils.ray_angles.rad2deg()
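    # Per-point hover text: depth, reflectance, and ray elevation/azimuth in degrees.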
label = [
f"depth: {float(d):.2f}m<br>"
+ f"reflectance: {float(r):.2f}<br>"
+ f"elevation: {float(e):.2f}°<br>"
+ f"azimuth: {float(a):.2f}°"
for d, r, e, a in zip(
einops.rearrange(depth, "1 1 h w -> (h w)"),
einops.rearrange(rflct, "1 1 h w -> (h w)"),
einops.rearrange(angle[0, 0], "h w -> (h w)"),
einops.rearrange(angle[0, 1], "h w -> (h w)"),
)
]
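    # 3D scatter of the generated point cloud; markers are colored by height (z)
    # with a fixed color range.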
fig = go.Figure(
data=[
go.Scatter3d(
x=-point[0],
y=-point[1],
z=point[2],
mode="markers",
marker=dict(
size=1,
color=point[2],
colorscale="viridis",
autocolorscale=False,
cauto=False,
cmin=-2,
cmax=0.5,
),
text=label,
hoverinfo="text",
)
],
layout=dict(
scene=dict(
xaxis=dict(showticklabels=False, visible=False),
yaxis=dict(showticklabels=False, visible=False),
zaxis=dict(showticklabels=False, visible=False),
aspectmode="data",
),
margin=dict(l=0, r=0, b=0, t=0),
paper_bgcolor="rgba(0,0,0,0)",
plot_bgcolor="rgba(0,0,0,0)",
),
)
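    # Rescale depth to [0, 1] by the maximum sensor range, then colorize the
    # 2D range and reflectance views.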
depth = depth / lidar_utils.max_depth
depth = colorize(depth, cmap)[0].permute(1, 2, 0).cpu().numpy()
rflct = colorize(rflct, cmap)[0].permute(1, 2, 0).cpu().numpy()
return depth, rflct, fig
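
# Run the reverse diffusion process for num_steps steps and render the sampled scan.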
def generate(num_steps, cmap_name, progress=gr.Progress()):
num_steps = int(num_steps)
x = ddpm.randn(1, *ddpm.sampling_shape, device=ddpm.device)
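    # Evenly spaced timesteps from t=1 (pure noise) down to t=0; each iteration
    # denoises from step_t to the next step_s via p_sample.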
steps = torch.linspace(1.0, 0.0, num_steps + 1, device=ddpm.device)[None]
for i in progress.tqdm(range(num_steps), desc="Generating LiDAR data"):
step_t = steps[:, i]
step_s = steps[:, i + 1]
x = ddpm.p_sample(x, step_t, step_s)
return render_point_cloud(x, plt.colormaps.get_cmap(cmap_name))
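
# Gradio UI: sampling controls in the left column, rendered outputs in the right column.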
with gr.Blocks() as demo:
gr.Markdown(
"""
# R2DM
R2DM is a denoising diffusion probabilistic model (DDPM) for LiDAR range/reflectance generation based on the equirectangular representation.
> **LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models**<br>
Kazuto Nakashima, Ryo Kurazume<br>
[[arXiv]](https://arxiv.org/abs/2309.09256) [[Code]](https://github.com/kazuto1011/r2dm)
"""
)
with gr.Row():
with gr.Column():
num_steps = gr.Dropdown(
choices=[2**i for i in range(2, 10)],
value=16,
label="number of sampling steps (>256 is recommended)",
)
cmap_name = gr.Dropdown(
choices=plt.colormaps(),
value="turbo",
label="colormap for range/reflectance images",
)
btn = gr.Button(value="Generate random samples")
with gr.Column():
range_view = gr.Image(
type="numpy",
image_mode="RGB",
label="Range image",
scale=1,
)
rflct_view = gr.Image(
type="numpy",
image_mode="RGB",
label="Reflectance image",
scale=1,
)
point_view = gr.Plot(
label="Point cloud",
scale=1,
)
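    # Wire the button to the sampler; the three outputs fill the views above.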
btn.click(
generate,
inputs=[num_steps, cmap_name],
outputs=[range_view, rflct_view, point_view],
)
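
# Queueing enables the gr.Progress updates; then launch the demo.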
demo.queue()
demo.launch()