Spaces:
Running
Running
Kazuto Nakashima
commited on
Commit
·
ffb8a4e
1
Parent(s):
1ee5104
update
Browse files- app.py +86 -17
- requirements.txt +2 -1
app.py
CHANGED
@@ -1,8 +1,11 @@
|
|
1 |
import gradio as gr
|
2 |
import matplotlib.cm as cm
|
|
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
|
|
|
|
6 |
|
7 |
torch.set_grad_enabled(False)
|
8 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
@@ -21,34 +24,95 @@ def colorize(tensor, cmap_fn=cm.turbo):
|
|
21 |
return tensor
|
22 |
|
23 |
|
24 |
-
|
25 |
-
def generate(num_steps) -> str:
|
26 |
-
output = ddpm.sample(batch_size=1, num_steps=int(num_steps))
|
27 |
output = lidar_utils.denormalize(output.clamp(-1, 1))
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
|
36 |
with gr.Blocks() as demo:
|
37 |
gr.Markdown(
|
38 |
"""
|
39 |
-
# R2DM
|
40 |
-
|
|
|
41 |
Kazuto Nakashima, Ryo Kurazume<br>
|
42 |
[[arXiv]](https://arxiv.org/abs/2309.09256) [[Code]](https://github.com/kazuto1011/r2dm)
|
43 |
"""
|
44 |
)
|
45 |
with gr.Row():
|
46 |
with gr.Column():
|
47 |
-
gr.Text(f"Device: {device}", label="device")
|
48 |
num_steps = gr.Dropdown(
|
49 |
-
choices=[2**i for i in range(
|
50 |
-
value=
|
51 |
-
label="number of sampling steps",
|
|
|
|
|
|
|
|
|
|
|
52 |
)
|
53 |
btn = gr.Button(value="Generate random samples")
|
54 |
|
@@ -65,12 +129,17 @@ with gr.Blocks() as demo:
|
|
65 |
label="Reflectance image",
|
66 |
scale=1,
|
67 |
)
|
|
|
|
|
|
|
|
|
68 |
|
69 |
btn.click(
|
70 |
generate,
|
71 |
-
inputs=[num_steps],
|
72 |
-
outputs=[range_view, rflct_view],
|
73 |
)
|
74 |
|
75 |
|
|
|
76 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
import matplotlib.cm as cm
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
7 |
+
import einops
|
8 |
+
import plotly.graph_objects as go
|
9 |
|
10 |
torch.set_grad_enabled(False)
|
11 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
24 |
return tensor
|
25 |
|
26 |
|
27 |
+
def render_point_cloud(output, cmap):
|
|
|
|
|
28 |
output = lidar_utils.denormalize(output.clamp(-1, 1))
|
29 |
+
depth = lidar_utils.revert_depth(output[:, [0]])
|
30 |
+
rflct = output[:, [1]]
|
31 |
+
point = lidar_utils.to_xyz(depth).cpu().numpy()
|
32 |
+
point = einops.rearrange(point, "1 c h w -> c (h w)")
|
33 |
+
angle = lidar_utils.ray_angles.rad2deg()
|
34 |
+
label = [
|
35 |
+
f"depth: {float(d):.2f}m<br>"
|
36 |
+
+ f"reflectance: {float(r):.2f}<br>"
|
37 |
+
+ f"elevation: {float(e):.2f}°<br>"
|
38 |
+
+ f"azimuth: {float(a):.2f}°"
|
39 |
+
for d, r, e, a in zip(
|
40 |
+
einops.rearrange(depth, "1 1 h w -> (h w)"),
|
41 |
+
einops.rearrange(rflct, "1 1 h w -> (h w)"),
|
42 |
+
einops.rearrange(angle[0, 0], "h w -> (h w)"),
|
43 |
+
einops.rearrange(angle[0, 1], "h w -> (h w)"),
|
44 |
+
)
|
45 |
+
]
|
46 |
+
fig = go.Figure(
|
47 |
+
data=[
|
48 |
+
go.Scatter3d(
|
49 |
+
x=-point[0],
|
50 |
+
y=-point[1],
|
51 |
+
z=point[2],
|
52 |
+
mode="markers",
|
53 |
+
marker=dict(
|
54 |
+
size=1,
|
55 |
+
color=point[2],
|
56 |
+
colorscale="viridis",
|
57 |
+
autocolorscale=False,
|
58 |
+
cauto=False,
|
59 |
+
cmin=-2,
|
60 |
+
cmax=0.5,
|
61 |
+
),
|
62 |
+
text=label,
|
63 |
+
hoverinfo="text",
|
64 |
+
)
|
65 |
+
],
|
66 |
+
layout=dict(
|
67 |
+
scene=dict(
|
68 |
+
xaxis=dict(showticklabels=False, visible=False),
|
69 |
+
yaxis=dict(showticklabels=False, visible=False),
|
70 |
+
zaxis=dict(showticklabels=False, visible=False),
|
71 |
+
aspectmode="data",
|
72 |
+
),
|
73 |
+
margin=dict(l=0, r=0, b=0, t=0),
|
74 |
+
paper_bgcolor="rgba(0,0,0,0)",
|
75 |
+
plot_bgcolor="rgba(0,0,0,0)",
|
76 |
+
),
|
77 |
+
)
|
78 |
+
depth = depth / lidar_utils.max_depth
|
79 |
+
depth = colorize(depth, cmap)[0].permute(1, 2, 0).cpu().numpy()
|
80 |
+
rflct = colorize(rflct, cmap)[0].permute(1, 2, 0).cpu().numpy()
|
81 |
+
return depth, rflct, fig
|
82 |
+
|
83 |
+
|
84 |
+
def generate(num_steps, cmap_name, progress=gr.Progress()):
|
85 |
+
num_steps = int(num_steps)
|
86 |
+
x = ddpm.randn(1, *ddpm.sampling_shape, device=ddpm.device)
|
87 |
+
steps = torch.linspace(1.0, 0.0, num_steps + 1, device=ddpm.device)[None]
|
88 |
+
for i in progress.tqdm(range(num_steps), desc="Generating LiDAR data"):
|
89 |
+
step_t = steps[:, i]
|
90 |
+
step_s = steps[:, i + 1]
|
91 |
+
x = ddpm.p_sample(x, step_t, step_s)
|
92 |
+
return render_point_cloud(x, plt.colormaps.get_cmap(cmap_name))
|
93 |
|
94 |
|
95 |
with gr.Blocks() as demo:
|
96 |
gr.Markdown(
|
97 |
"""
|
98 |
+
# R2DM
|
99 |
+
R2DM is a denoising diffusion probabilistic model (DDPM) for LiDAR range/reflectance generation based on the equirectangular representation.
|
100 |
+
> **LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models**<br>
|
101 |
Kazuto Nakashima, Ryo Kurazume<br>
|
102 |
[[arXiv]](https://arxiv.org/abs/2309.09256) [[Code]](https://github.com/kazuto1011/r2dm)
|
103 |
"""
|
104 |
)
|
105 |
with gr.Row():
|
106 |
with gr.Column():
|
|
|
107 |
num_steps = gr.Dropdown(
|
108 |
+
choices=[2**i for i in range(2, 10)],
|
109 |
+
value=16,
|
110 |
+
label="number of sampling steps (>256 is recommended)",
|
111 |
+
)
|
112 |
+
cmap_name = gr.Dropdown(
|
113 |
+
choices=plt.colormaps(),
|
114 |
+
value="turbo",
|
115 |
+
label="colormap for range/reflectance images",
|
116 |
)
|
117 |
btn = gr.Button(value="Generate random samples")
|
118 |
|
|
|
129 |
label="Reflectance image",
|
130 |
scale=1,
|
131 |
)
|
132 |
+
point_view = gr.Plot(
|
133 |
+
label="Point cloud",
|
134 |
+
scale=1,
|
135 |
+
)
|
136 |
|
137 |
btn.click(
|
138 |
generate,
|
139 |
+
inputs=[num_steps, cmap_name],
|
140 |
+
outputs=[range_view, rflct_view, point_view],
|
141 |
)
|
142 |
|
143 |
|
144 |
+
demo.queue()
|
145 |
demo.launch()
|
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ matplotlib
|
|
4 |
numpy
|
5 |
torch
|
6 |
torchvision
|
7 |
-
tqdm
|
|
|
|
4 |
numpy
|
5 |
torch
|
6 |
torchvision
|
7 |
+
tqdm
|
8 |
+
plotly
|