|
import copy
|
|
import gradio as gr
|
|
import os
|
|
import pickle
|
|
from matplotlib import pyplot as plt
|
|
import plotly.express as px
|
|
import plotly.graph_objects as go
|
|
import numpy as np
|
|
|
|
|
|
|
|
# Entries of the pre-computed output directory, sorted alphabetically; these
# names are offered as the choices of the scan dropdown in the UI.
all_test_scenes = list(os.listdir('iso_output/NYU'))
all_test_scenes.sort()
|
|
|
|
|
|
def get_grid_coords(dims, resolution):
    """Return the center coordinates of every voxel in a regular grid.

    :param dims: number of cells along each of the three grid axes
        (e.g. [256, 256, 32]).
    :param resolution: edge length of one voxel; cell index ``i`` maps to
        ``i * resolution + resolution / 2``.
    :return: array of shape ``(dims[0] * dims[1] * dims[2], 3)``. Column 0
        steps through ``range(dims[1])``, column 1 through ``range(dims[0])``
        and column 2 through ``range(dims[2])`` (all scaled as above). Rows
        are ordered with column 0 varying slowest and column 2 fastest.
    """
    first_axis = np.arange(dims[1])
    second_axis = np.arange(dims[0])
    third_axis = np.arange(dims[2])

    # With 'ij' indexing the argument order is the axis order, so a C-order
    # flatten makes the first argument vary slowest and the last fastest.
    mesh = np.meshgrid(first_axis, second_axis, third_axis, indexing="ij")
    cell_indices = np.stack(mesh, axis=-1).reshape(-1, 3)

    # Scale integer cell indices to lengths and shift to the cell centers.
    return cell_indices * resolution + resolution / 2
|
|
|
|
|
|
# Plot color for each voxel label: row i is used for label value i.
# Only the first three (RGB) channels are used when plotting.
_LABEL_COLORS = np.array(
    [
        [22, 191, 206, 255],
        [214, 38, 40, 255],
        [43, 160, 43, 255],
        [158, 216, 229, 255],
        [114, 158, 206, 255],
        [204, 204, 91, 255],
        [255, 186, 119, 255],
        [147, 102, 188, 255],
        [30, 119, 181, 255],
        [188, 188, 33, 255],
        [255, 127, 12, 255],
        [196, 175, 214, 255],
        [153, 153, 153, 255],
        [255, 255, 255, 255],
    ]
)


def draw(
    voxels,
    cam_pose,
    vox_origin,
    voxel_size=0.08,
    d=0.75,
):
    """Render a labelled 3-D voxel grid as a Plotly scatter of square markers.

    Only voxels whose label satisfies ``0 < label < 255`` are drawn. Each one
    becomes a marker colored by ``_LABEL_COLORS[int(label)]``.

    :param voxels: 3-D array of voxel labels.
    :param cam_pose: not used by the rendering; accepted so existing callers
        keep working.
    :param vox_origin: not used by the rendering; accepted so existing
        callers keep working.
    :param voxel_size: edge length of one voxel, used to place marker centers.
    :param d: not used by the rendering; accepted so existing callers keep
        working.
    :return: a ``plotly.graph_objects.Figure`` containing one ``Scatter3d``
        trace.
    :raises IndexError: if a drawn label has no row in ``_LABEL_COLORS``.
    """
    grid_coords = get_grid_coords(
        [voxels.shape[0], voxels.shape[2], voxels.shape[1]], voxel_size
    )

    # Append each voxel's label as a 4th column; axes 1 and 2 of `voxels`
    # are swapped before flattening.
    labels = np.moveaxis(voxels, [0, 1, 2], [0, 2, 1]).reshape(-1)
    grid_coords = np.vstack((grid_coords.T, labels)).T

    # Keep only voxels with a label strictly between 0 and 255.
    occupied_voxels = grid_coords[(grid_coords[:, 3] > 0) & (grid_coords[:, 3] < 255)]

    # Look up all colors in one vectorized indexing step instead of per point.
    rgb = _LABEL_COLORS[occupied_voxels[:, 3].astype(int), :3]
    pts_colors = [f'rgb({r}, {g}, {b})' for r, g, b in rgb]

    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=occupied_voxels[:, 0],
                y=occupied_voxels[:, 1],
                z=occupied_voxels[:, 2],
                mode='markers',
                marker=dict(
                    size=5,
                    color=pts_colors,
                    opacity=1.0,
                    symbol='square',
                ),
            )
        ]
    )

    # The three scene axes share identical (hidden) styling.
    hidden_axis = dict(
        backgroundcolor="rgb(255, 255, 255)",
        gridcolor="black",
        showbackground=True,
        zerolinecolor="black",
        nticks=4,
        visible=False,
        range=[-5, 5],
    )
    fig.update_layout(
        autosize=True,
        scene=dict(
            aspectmode='data',
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(hidden_axis),
            bgcolor="black",
        ),
    )

    return fig
|
|
|
|
|
|
def predict(scan):
    """Load one pre-computed scene and render prediction, ground truth and image.

    :param scan: plain file name (no directory part) of a pickle inside
        ``iso_output/NYU``, as selected in the UI dropdown. ``None`` or an
        empty string means nothing is selected.
    :return: ``(prediction_figure, ground_truth_figure, color_image)``, or
        ``(None, None, None)`` when no scan is selected.
    :raises ValueError: if ``scan`` is not a plain file name (contains a
        directory part, or is ``.`` / ``..``).
    """
    if not scan:
        return None, None, None

    # The value arrives from the web client; refuse anything that could
    # escape the output directory before it reaches open()/pickle.load().
    if os.path.basename(scan) != scan or scan in (os.curdir, os.pardir):
        raise ValueError(f"invalid scan name: {scan!r}")

    scan_path = os.path.join('iso_output/NYU', scan)
    # NOTE(security): pickle.load can execute arbitrary code; only load files
    # produced by this project, never user-uploaded data.
    with open(scan_path, "rb") as handle:
        b = pickle.load(handle)

    cam_pose = b["cam_pose"]
    vox_origin = b["vox_origin"]
    gt_scene = b["target"]
    pred_scene = b["y_pred"]

    # The matching RGB image is found via the first 12 characters of the
    # scan file name.
    scene_id = scan[:12]
    img = plt.imread(os.path.join('iso_input', scene_id + '_color.jpg'))

    # Voxels labelled 255 in the ground truth are also set to 255 in the
    # prediction; draw() does not plot label 255.
    pred_scene[gt_scene == 255] = 255

    fig = draw(
        pred_scene,
        cam_pose,
        vox_origin,
        voxel_size=0.08,
        d=0.75,
    )
    fig2 = draw(
        gt_scene,
        cam_pose,
        vox_origin,
        voxel_size=0.08,
        d=0.75,
    )

    return fig, fig2, img
|
|
|
|
# Markdown shown below the page title: what the demo does, plus badge links
# to the project page, the arXiv paper and the GitHub repository.
description = """
ISO Demo on NYUv2 test set.

For fast rendering, we generate the outputs for the test-set scenes offline and only provide an interface for plotting the results.
We recommend trying the visualization scripts locally on your computer for better interactivity.

<center>
<a href="https://hongxiaoy.github.io/ISO/">
<img style="display:inline" alt="Project page" src="https://img.shields.io/badge/Project%20Page-ISO-blue">
</a>
<a href="https://arxiv.org/abs/2407.11730"><img style="display:inline" src="https://img.shields.io/badge/arXiv-ISO-red"></a>
<a href="https://github.com/hongxiaoy/ISO"><img style="display:inline" src="https://img.shields.io/github/stars/hongxiaoy/ISO?style=social"></a>
</center>
"""

# Centered HTML heading rendered at the top of the page.
title = """
<center>
<h1>Monocular Occupancy Prediction for Scalable Indoor Scenes</h1>
</center>
"""
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            # Named scan_input rather than `input` to avoid shadowing the builtin.
            scan_input = gr.Dropdown(all_test_scenes, label='input scan')
            submit_btn = gr.Button("Submit", render=True)
            img = gr.Image(label='color image')
        with gr.Column():
            output = gr.Plot(label='prediction')
            label = gr.Plot(label='ground truth')

    # predict() returns (prediction figure, ground-truth figure, color image),
    # in the same order as these output components.
    submit_btn.click(fn=predict, inputs=scan_input, outputs=[output, label, img])

demo.launch()