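"""Gradio demo for ISO (Monocular Occupancy Prediction for Scalable Indoor Scenes).

Loads precomputed ISO predictions for the NYUv2 test set from ``iso_output/NYU``
and plots the predicted and ground-truth occupancy grids next to the input
color image.
"""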
import os
import pickle

import gradio as gr
import numpy as np
import plotly.graph_objects as go
from matplotlib import pyplot as plt

all_test_scenes = sorted(os.listdir('iso_output/NYU'))
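# Each listed file is a pickle saved offline by ISO inference; predict() below
# expects it to hold the camera pose ('cam_pose'), the voxel-grid origin
# ('vox_origin'), the ground-truth grid ('target'), and the prediction ('y_pred').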
def get_grid_coords(dims, resolution):
    """
    :param dims: the dimensions of the grid [x, y, z] (e.g. [256, 256, 32])
    :param resolution: the side length of a voxel, in meters
    :return coords_grid: the center coordinates of the voxels in the grid
    """
    g_xx = np.arange(0, dims[0])
    g_yy = np.arange(0, dims[1])
    g_zz = np.arange(0, dims[2])

    # Obtain the integer voxel indices of the grid
    xx, yy, zz = np.meshgrid(g_xx, g_yy, g_zz)
    coords_grid = np.array([xx.flatten(), yy.flatten(), zz.flatten()]).T

    # Convert indices to metric voxel-center coordinates
    coords_grid = (coords_grid * resolution) + resolution / 2

    # Swap the x and y columns to match the plotting convention used in draw()
    temp = np.copy(coords_grid)
    temp[:, 0] = coords_grid[:, 1]
    temp[:, 1] = coords_grid[:, 0]
    coords_grid = np.copy(temp)

    return coords_grid
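# Example: get_grid_coords([4, 4, 2], 0.5) returns a (4*4*2, 3) = (32, 3)
# array of voxel centers; its first row is [0.25, 0.25, 0.25].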
def draw(
    voxels,
    cam_pose,
    vox_origin,
    voxel_size=0.08,
    d=0.75,  # 0.75m - determines the size of the mesh representing the camera
):
    # Compute the coordinates of the mesh representing the camera,
    # using the 640x480 NYUv2 images and their focal length of 518.8579 px
    y = d * 480 / (2 * 518.8579)
    x = d * 640 / (2 * 518.8579)
    tri_points = np.array(
        [
            [0, 0, 0],
            [x, y, d],
            [-x, y, d],
            [-x, -y, d],
            [x, -y, d],
        ]
    )

    # Transform the frustum corners into world space and shift them
    # into the voxel grid's coordinate frame
    tri_points = np.hstack([tri_points, np.ones((5, 1))])
    tri_points = (cam_pose @ tri_points.T).T
    x = tri_points[:, 0] - vox_origin[0]
    y = tri_points[:, 1] - vox_origin[1]
    z = tri_points[:, 2] - vox_origin[2]
    triangles = [
        (0, 1, 2),
        (0, 1, 4),
        (0, 3, 4),
        (0, 2, 3),
    ]
    # Compute the voxel-center coordinates
    grid_coords = get_grid_coords(
        [voxels.shape[0], voxels.shape[2], voxels.shape[1]], voxel_size
    )

    # Attach the predicted class to every voxel
    grid_coords = np.vstack(
        (grid_coords.T, np.moveaxis(voxels, [0, 1, 2], [0, 2, 1]).reshape(-1))
    ).T

    # Remove empty (label 0) and unknown (label 255) voxels
    occupied_voxels = grid_coords[(grid_coords[:, 3] > 0) & (grid_coords[:, 3] < 255)]

    # One RGBA color per semantic class
    colors = np.array(
        [
            [22, 191, 206, 255],
            [214, 38, 40, 255],
            [43, 160, 43, 255],
            [158, 216, 229, 255],
            [114, 158, 206, 255],
            [204, 204, 91, 255],
            [255, 186, 119, 255],
            [147, 102, 188, 255],
            [30, 119, 181, 255],
            [188, 188, 33, 255],
            [255, 127, 12, 255],
            [196, 175, 214, 255],
            [153, 153, 153, 255],
            [255, 255, 255, 255],
        ]
    )
    pts_colors = [
        f'rgb({colors[int(i)][0]}, {colors[int(i)][1]}, {colors[int(i)][2]})'
        for i in occupied_voxels[:, 3]
    ]
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=occupied_voxels[:, 0],
                y=occupied_voxels[:, 1],
                z=occupied_voxels[:, 2],
                mode='markers',
                marker=dict(
                    size=5,
                    color=pts_colors,  # one color per voxel, from its class
                    opacity=1.0,
                    symbol='square',
                ),
            )
        ]
    )
    fig.update_layout(
        autosize=True,
        scene=dict(
            aspectmode='data',
            xaxis=dict(
                backgroundcolor="rgb(255, 255, 255)",
                gridcolor="black",
                showbackground=True,
                zerolinecolor="black",
                nticks=4,
                visible=False,
                range=[-5, 5],
            ),
            yaxis=dict(
                backgroundcolor="rgb(255, 255, 255)",
                gridcolor="black",
                showbackground=True,
                zerolinecolor="black",
                nticks=4,
                visible=False,
                range=[-5, 5],
            ),
            zaxis=dict(
                backgroundcolor="rgb(255, 255, 255)",
                gridcolor="black",
                showbackground=True,
                zerolinecolor="black",
                nticks=4,
                visible=False,
                range=[-5, 5],
            ),
            bgcolor="black",
        ),
    )
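    # The frustum vertices and triangles computed above are otherwise unused;
    # the sketch below renders the camera as a mesh (the red color and 0.5
    # opacity are arbitrary choices, not part of the original demo).
    i, j, k = zip(*triangles)
    fig.add_trace(
        go.Mesh3d(
            x=x, y=y, z=z,
            i=i, j=j, k=k,
            color='red',
            opacity=0.5,
            name='camera',
        )
    )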
    return fig
def predict(scan):
    if scan is None:
        return None, None, None

    scan = os.path.join('iso_output/NYU', scan)
    with open(scan, "rb") as handle:
        b = pickle.load(handle)

    cam_pose = b["cam_pose"]
    vox_origin = b["vox_origin"]
    gt_scene = b["target"]
    pred_scene = b["y_pred"]

    # The first 12 characters of the file name identify the scene and
    # index the corresponding input color image
    scan = os.path.basename(scan)[:12]
    img = plt.imread('iso_input/' + scan + '_color.jpg')

    pred_scene[(gt_scene == 255)] = 255  # only draw the scene inside the room

    fig = draw(
        pred_scene,
        cam_pose,
        vox_origin,
        voxel_size=0.08,
        d=0.75,
    )
    fig2 = draw(
        gt_scene,
        cam_pose,
        vox_origin,
        voxel_size=0.08,
        d=0.75,
    )
    return fig, fig2, img
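# For example, predict(all_test_scenes[0]) returns the prediction figure,
# the ground-truth figure, and the input color image, exactly as wired up
# to the Submit button below.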
description = """
ISO demo on the NYUv2 test set.
For fast rendering, we generate the outputs for the test-set scenes offline and only provide an interface for plotting the results.
For better interactivity, we recommend running the visualization scripts locally on your own machine.
<center>
<a href="https://hongxiaoy.github.io/ISO/">
<img style="display:inline" alt="Project page" src="https://img.shields.io/badge/Project%20Page-ISO-blue">
</a>
<a href="https://arxiv.org/abs/2407.11730"><img style="display:inline" src="https://img.shields.io/badge/arXiv-ISO-red"></a>
<a href="https://github.com/hongxiaoy/ISO"><img style="display:inline" src="https://img.shields.io/github/stars/hongxiaoy/ISO?style=social"></a>
</center>
"""
title = """
<center>
<h1>Monocular Occupancy Prediction for Scalable Indoor Scenes</h1>
</center>
"""
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            scan_input = gr.Dropdown(all_test_scenes, label='input scan')
            submit_btn = gr.Button("Submit", render=True)
            img = gr.Image(label='color image')
        with gr.Column():
            output = gr.Plot(label='prediction')
            label = gr.Plot(label='ground truth')
    submit_btn.click(fn=predict, inputs=scan_input, outputs=[output, label, img])

demo.launch()