ISO / app.py
hongxiaoy's picture
Upload app.py
a57b6ab verified
raw
history blame
6.72 kB
import copy
import gradio as gr
import os
import pickle
from matplotlib import pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
# Sorted names of the entries found in iso_output/NYU at import time. They are
# the choices of the scene dropdown and are later appended to 'iso_output/NYU/'
# by predict() to locate the pickled result to plot.
# NOTE: os.listdir raises if the directory is missing, so the app needs it to
# exist before start-up.
all_test_scenes = sorted(os.listdir('iso_output/NYU'))
def get_grid_coords(dims, resolution):
    """Return the centre coordinates of every voxel of a regular 3-D grid.

    Rows are ordered like ``grid.reshape(-1)`` for an array of shape ``dims``
    (first axis varies slowest, last axis fastest), so labels flattened in
    C order line up row-for-row with the returned coordinates for any grid
    shape, not only when ``dims[0] == dims[1]``.

    :param dims: the dimensions of the grid [x, y, z] (i.e. [256, 256, 32])
    :param resolution: edge length of one voxel; the centre of voxel index
        ``i`` along an axis is ``i * resolution + resolution / 2``
    :return coords_grid: float array of shape (x * y * z, 3) holding the
        centre coords (x, y, z) of the voxels in the grid
    """
    # "ij" indexing keeps the axis order of ``dims``. The default "xy"
    # indexing transposes the first two axes, which is only harmless for
    # grids that are square in x/y.
    xx, yy, zz = np.meshgrid(
        np.arange(dims[0]),
        np.arange(dims[1]),
        np.arange(dims[2]),
        indexing="ij",
    )
    voxel_indices = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1)
    # Shift from the voxel's lower corner to its centre.
    return (voxel_indices * resolution) + resolution / 2
def draw(
    voxels,
    cam_pose,
    vox_origin,
    voxel_size=0.08,
    d=0.75,  # 0.75m - determine the size of the mesh representing the camera
):
    """Build a Plotly 3-D scatter figure of the occupied voxels of a label grid.

    Voxels labelled 0 or 255 (empty / unknown) are dropped; every remaining
    voxel becomes one square marker coloured by its class label.

    :param voxels: 3-D array of class labels. Labels other than 0 and 255
        index the 14-entry colour table below, so they must be smaller
        than 14.
    :param cam_pose: camera pose; not used because no camera mesh is added to
        the figure. Kept so existing callers keep working.
    :param vox_origin: origin of the voxel grid; not used (see ``cam_pose``).
    :param voxel_size: edge length of one voxel, used to scale voxel indices
        into plot coordinates.
    :param d: size of the camera mesh; not used (see ``cam_pose``).
    :return: a ``plotly.graph_objects.Figure`` holding a single Scatter3d trace.
    """
    # Compute the voxels coordinates. Axes 1 and 2 of `voxels` are swapped both
    # in the grid dimensions and in the flattened labels so the two stay
    # aligned row by row.
    grid_coords = get_grid_coords(
        [voxels.shape[0], voxels.shape[2], voxels.shape[1]], voxel_size
    )
    flat_labels = np.moveaxis(voxels, [0, 1, 2], [0, 2, 1]).reshape(-1)

    # Attach the predicted class to every voxel (4th column).
    grid_coords = np.vstack((grid_coords.T, flat_labels)).T

    # Remove empty and unknown voxels
    keep = (grid_coords[:, 3] > 0) & (grid_coords[:, 3] < 255)
    occupied_voxels = grid_coords[keep]

    # RGBA colour per class label; only the RGB part is used below.
    colors = np.array(
        [
            [22, 191, 206, 255],
            [214, 38, 40, 255],
            [43, 160, 43, 255],
            [158, 216, 229, 255],
            [114, 158, 206, 255],
            [204, 204, 91, 255],
            [255, 186, 119, 255],
            [147, 102, 188, 255],
            [30, 119, 181, 255],
            [188, 188, 33, 255],
            [255, 127, 12, 255],
            [196, 175, 214, 255],
            [153, 153, 153, 255],
            [255, 255, 255, 255],
        ]
    )
    # Format each class colour once instead of once per voxel.
    class_rgb = [f'rgb({r}, {g}, {b})' for r, g, b, _ in colors.tolist()]
    pts_colors = [class_rgb[int(label)] for label in occupied_voxels[:, 3]]

    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=occupied_voxels[:, 0],
                y=occupied_voxels[:, 1],
                z=occupied_voxels[:, 2],
                mode='markers',
                marker=dict(
                    size=5,
                    color=pts_colors,  # one CSS rgb() string per voxel
                    opacity=1.0,
                    symbol='square',
                ),
            )
        ]
    )

    # All three axes share the same (hidden) styling and range.
    axis_style = dict(
        backgroundcolor="rgb(255, 255, 255)",
        gridcolor="black",
        showbackground=True,
        zerolinecolor="black",
        nticks=4,
        visible=False,
        range=[-5, 5],
    )
    fig.update_layout(
        autosize=True,
        scene=dict(
            aspectmode='data',
            xaxis=dict(axis_style),
            yaxis=dict(axis_style),
            zaxis=dict(axis_style),
            bgcolor="black",
        ),
    )
    return fig
def predict(scan):
    """Load one precomputed scene and return its plots and colour image.

    :param scan: file name of a pickled result inside ``iso_output/NYU/``, or
        ``None`` when nothing is selected.
    :return: ``(prediction figure, ground-truth figure, colour image)``; three
        ``None`` values when ``scan`` is ``None``.
    """
    if scan is None:
        return None, None, None

    result_path = 'iso_output/NYU/' + scan
    # NOTE(review): pickle.load executes arbitrary code from the file; this
    # presumably only ever reads the app's own offline outputs - confirm the
    # directory cannot contain untrusted files.
    with open(result_path, "rb") as result_file:
        result = pickle.load(result_file)

    cam_pose, vox_origin, gt_scene, pred_scene = (
        result[key] for key in ("cam_pose", "vox_origin", "target", "y_pred")
    )

    # The first 12 characters of the file name select the matching input image
    # (presumably the scene id - verify against the data layout).
    scene_id = os.path.basename(result_path)[:12]
    img = plt.imread('iso_input/' + scene_id + '_color.jpg')

    # Copy the ground truth's 255 labels into the prediction so that draw(),
    # which skips label 255, only draws the scene inside the room.
    pred_scene[gt_scene == 255] = 255

    pred_fig, gt_fig = (
        draw(scene, cam_pose, vox_origin, voxel_size=0.08, d=0.75)
        for scene in (pred_scene, gt_scene)
    )
    return pred_fig, gt_fig, img
# Markdown/HTML text passed to gr.Markdown below the title: a short note on how
# the demo works plus badge links to the project page, the arXiv paper and the
# GitHub repository.
description = """
ISO Demo on NYUv2 test set.
For a fast rendering, we generate the output of test set scenes offline, and just provide a interface for plotting the output result.
We recommend you try visualization scripts locally in your computer for a better interaction.
<center>
<a href="https://hongxiaoy.github.io/ISO/">
<img style="display:inline" alt="Project page" src="https://img.shields.io/badge/Project%20Page-ISO-blue">
</a>
<a href="https://arxiv.org/abs/2407.11730"><img style="display:inline" src="https://img.shields.io/badge/arXiv-ISO-red"></a>
<a href="https://github.com/hongxiaoy/ISO"><img style="display:inline" src="https://img.shields.io/github/stars/hongxiaoy/ISO?style=social"></a>
</center>
"""
# Centered HTML heading passed to gr.Markdown at the top of the page.
title = """
<center>
<h1>Monocular Occupancy Prediction for Scalable Indoor Scenes</h1>
</center>
"""
# Page layout: heading and description on top, then one row with two columns.
# Left column: scene selector, submit button and the scene's colour image.
# Right column: the predicted and the ground-truth occupancy plots.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            # Not named `input`, so the builtin of that name is not shadowed.
            scan_dropdown = gr.Dropdown(all_test_scenes, label='input scan')
            submit_btn = gr.Button("Submit", render=True)
            img = gr.Image(label='color image')
        with gr.Column():
            output = gr.Plot(label='prediction')
            label = gr.Plot(label='ground truth')
    # predict() returns (prediction figure, ground-truth figure, image), in the
    # same order as the output components listed here.
    submit_btn.click(
        fn=predict,
        inputs=scan_dropdown,
        outputs=[output, label, img],
    )

# Start the server only when run as a script, so importing this module (e.g.
# from tooling or tests) does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()