Kazuto Nakashima commited on
Commit
ffb8a4e
·
1 Parent(s): 1ee5104
Files changed (2) hide show
  1. app.py +86 -17
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,8 +1,11 @@
1
  import gradio as gr
2
  import matplotlib.cm as cm
 
3
  import numpy as np
4
  import torch
5
  import torch.nn.functional as F
 
 
6
 
7
  torch.set_grad_enabled(False)
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -21,34 +24,95 @@ def colorize(tensor, cmap_fn=cm.turbo):
21
  return tensor
22
 
23
 
24
- @torch.no_grad()
25
- def generate(num_steps) -> str:
26
- output = ddpm.sample(batch_size=1, num_steps=int(num_steps))
27
  output = lidar_utils.denormalize(output.clamp(-1, 1))
28
- range_image = lidar_utils.revert_depth(output[:, [0]])
29
- range_image = (range_image / lidar_utils.max_depth).clamp(0, 1)
30
- reflectance_image = output[:, [1]]
31
- range_image = colorize(range_image)[0].permute(1, 2, 0)
32
- reflectance_image = colorize(reflectance_image)[0].permute(1, 2, 0)
33
- return range_image.cpu().numpy(), reflectance_image.cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
 
36
  with gr.Blocks() as demo:
37
  gr.Markdown(
38
  """
39
- # R2DM Demo
40
- **LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models**<br>
 
41
  Kazuto Nakashima, Ryo Kurazume<br>
42
  [[arXiv]](https://arxiv.org/abs/2309.09256) [[Code]](https://github.com/kazuto1011/r2dm)
43
  """
44
  )
45
  with gr.Row():
46
  with gr.Column():
47
- gr.Text(f"Device: {device}", label="device")
48
  num_steps = gr.Dropdown(
49
- choices=[2**i for i in range(3, 11)],
50
- value=8,
51
- label="number of sampling steps",
 
 
 
 
 
52
  )
53
  btn = gr.Button(value="Generate random samples")
54
 
@@ -65,12 +129,17 @@ with gr.Blocks() as demo:
65
  label="Reflectance image",
66
  scale=1,
67
  )
 
 
 
 
68
 
69
  btn.click(
70
  generate,
71
- inputs=[num_steps],
72
- outputs=[range_view, rflct_view],
73
  )
74
 
75
 
 
76
  demo.launch()
 
1
  import gradio as gr
2
  import matplotlib.cm as cm
3
+ import matplotlib.pyplot as plt
4
  import numpy as np
5
  import torch
6
  import torch.nn.functional as F
7
+ import einops
8
+ import plotly.graph_objects as go
9
 
10
  torch.set_grad_enabled(False)
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
24
  return tensor
25
 
26
 
27
+ def render_point_cloud(output, cmap):
 
 
28
  output = lidar_utils.denormalize(output.clamp(-1, 1))
29
+ depth = lidar_utils.revert_depth(output[:, [0]])
30
+ rflct = output[:, [1]]
31
+ point = lidar_utils.to_xyz(depth).cpu().numpy()
32
+ point = einops.rearrange(point, "1 c h w -> c (h w)")
33
+ angle = lidar_utils.ray_angles.rad2deg()
34
+ label = [
35
+ f"depth: {float(d):.2f}m<br>"
36
+ + f"reflectance: {float(r):.2f}<br>"
37
+ + f"elevation: {float(e):.2f}°<br>"
38
+ + f"azimuth: {float(a):.2f}°"
39
+ for d, r, e, a in zip(
40
+ einops.rearrange(depth, "1 1 h w -> (h w)"),
41
+ einops.rearrange(rflct, "1 1 h w -> (h w)"),
42
+ einops.rearrange(angle[0, 0], "h w -> (h w)"),
43
+ einops.rearrange(angle[0, 1], "h w -> (h w)"),
44
+ )
45
+ ]
46
+ fig = go.Figure(
47
+ data=[
48
+ go.Scatter3d(
49
+ x=-point[0],
50
+ y=-point[1],
51
+ z=point[2],
52
+ mode="markers",
53
+ marker=dict(
54
+ size=1,
55
+ color=point[2],
56
+ colorscale="viridis",
57
+ autocolorscale=False,
58
+ cauto=False,
59
+ cmin=-2,
60
+ cmax=0.5,
61
+ ),
62
+ text=label,
63
+ hoverinfo="text",
64
+ )
65
+ ],
66
+ layout=dict(
67
+ scene=dict(
68
+ xaxis=dict(showticklabels=False, visible=False),
69
+ yaxis=dict(showticklabels=False, visible=False),
70
+ zaxis=dict(showticklabels=False, visible=False),
71
+ aspectmode="data",
72
+ ),
73
+ margin=dict(l=0, r=0, b=0, t=0),
74
+ paper_bgcolor="rgba(0,0,0,0)",
75
+ plot_bgcolor="rgba(0,0,0,0)",
76
+ ),
77
+ )
78
+ depth = depth / lidar_utils.max_depth
79
+ depth = colorize(depth, cmap)[0].permute(1, 2, 0).cpu().numpy()
80
+ rflct = colorize(rflct, cmap)[0].permute(1, 2, 0).cpu().numpy()
81
+ return depth, rflct, fig
82
+
83
+
84
+ def generate(num_steps, cmap_name, progress=gr.Progress()):
85
+ num_steps = int(num_steps)
86
+ x = ddpm.randn(1, *ddpm.sampling_shape, device=ddpm.device)
87
+ steps = torch.linspace(1.0, 0.0, num_steps + 1, device=ddpm.device)[None]
88
+ for i in progress.tqdm(range(num_steps), desc="Generating LiDAR data"):
89
+ step_t = steps[:, i]
90
+ step_s = steps[:, i + 1]
91
+ x = ddpm.p_sample(x, step_t, step_s)
92
+ return render_point_cloud(x, plt.colormaps.get_cmap(cmap_name))
93
 
94
 
95
  with gr.Blocks() as demo:
96
  gr.Markdown(
97
  """
98
+ # R2DM
99
+ R2DM is a denoising diffusion probabilistic model (DDPM) for LiDAR range/reflectance generation based on the equirectangular representation.
100
+ > **LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models**<br>
101
  Kazuto Nakashima, Ryo Kurazume<br>
102
  [[arXiv]](https://arxiv.org/abs/2309.09256) [[Code]](https://github.com/kazuto1011/r2dm)
103
  """
104
  )
105
  with gr.Row():
106
  with gr.Column():
 
107
  num_steps = gr.Dropdown(
108
+ choices=[2**i for i in range(2, 10)],
109
+ value=16,
110
+ label="number of sampling steps (>256 is recommended)",
111
+ )
112
+ cmap_name = gr.Dropdown(
113
+ choices=plt.colormaps(),
114
+ value="turbo",
115
+ label="colormap for range/reflectance images",
116
  )
117
  btn = gr.Button(value="Generate random samples")
118
 
 
129
  label="Reflectance image",
130
  scale=1,
131
  )
132
+ point_view = gr.Plot(
133
+ label="Point cloud",
134
+ scale=1,
135
+ )
136
 
137
  btn.click(
138
  generate,
139
+ inputs=[num_steps, cmap_name],
140
+ outputs=[range_view, rflct_view, point_view],
141
  )
142
 
143
 
144
+ demo.queue()
145
  demo.launch()
requirements.txt CHANGED
@@ -4,4 +4,5 @@ matplotlib
4
  numpy
5
  torch
6
  torchvision
7
- tqdm
 
 
4
  numpy
5
  torch
6
  torchvision
7
+ tqdm
8
+ plotly