File size: 10,477 Bytes
91126af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
#!/usr/bin/env python3
# Copyright © Niantic, Inc. 2023.

import logging
import shutil
from pathlib import Path
import os
import numpy as np
import argparse
from distutils.util import strtobool
import time
import ace_zero_util as zutil
from ace_util import load_npz_file, compute_knn_mask
from joblib import Parallel, delayed
# from ace_trainer import TrainerACE
# import dataset_io
from PIL import Image
from ace_visualizer import ACEVisualizer
import subprocess
import matplotlib
_logger = logging.getLogger(__name__)
import imageio



def colorize(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
    """Render a depth map as an RGB image by colour-mapping its inverse depth.

    Args:
        depth: 2-D array of depth values. Entries that are not strictly
            positive are treated as invalid.
        mask: Optional array combined with ``depth > 0`` via ``&``; entries
            where the combined mask is false are treated as invalid.
        normalize: When true, rescale the disparity (``1 / depth``) so that its
            0.1% and 99.9% quantiles map to 0 and 1 respectively.
        cmap: Name of the matplotlib colormap to look up in
            ``matplotlib.colormaps``.

    Returns:
        ``uint8`` array with the first three channels of the colormap output
        for every pixel. Invalid pixels go through the colormap as NaN.
    """
    valid = depth > 0
    if mask is not None:
        valid = valid & mask
    depth = np.where(valid, depth, np.nan)
    disp = 1 / depth
    if normalize:
        min_disp, max_disp = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.999)
        span = max_disp - min_disp
        if span > 0:
            disp = (disp - min_disp) / span
        else:
            # Degenerate range (constant depth, or no valid pixel at all):
            # dividing by a zero span would turn every valid pixel into NaN.
            # Map valid pixels to a single disparity value instead and keep
            # invalid pixels as NaN.
            disp = np.where(valid, 0.0, np.nan)
    # The colormap is evaluated on ``1.0 - disp``, i.e. large disparity maps to
    # the low end of the colormap.
    # ``nan`` must be passed by keyword: the second positional parameter of
    # ``np.nan_to_num`` is ``copy``, not the replacement value.
    colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp), nan=0.0)
    colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3]
    return colored

def _strtobool(x):
    return bool(strtobool(x))

def npz2image(npz_file, save_path):
    """Write every image stored in an ``.npz`` archive to ``save_path`` as PNGs.

    The images are read from the ``images_gt`` entry when present, otherwise
    from ``images``. Files are named ``gt_image_<idx>.png`` where ``<idx>`` is
    the position along the first axis.

    Args:
        npz_file: Path of the ``.npz`` archive to read.
        save_path: Output directory; created if it does not exist.

    Raises:
        KeyError: If the archive has neither ``images_gt`` nor ``images``.
    """
    # Read the array inside a ``with`` so the archive's file handle is closed
    # as soon as the data is in memory.
    with np.load(npz_file) as data:
        image_data = data['images_gt'] if 'images_gt' in data else data['images']
    os.makedirs(save_path, exist_ok=True)

    # A second axis of length 3 is taken to be the channel axis and is moved
    # last. NOTE(review): a channels-last batch whose height is exactly 3 would
    # also match this test - confirm the inputs never look like that.
    if image_data.shape[1] == 3:
        image_data = np.transpose(image_data, (0, 2, 3, 1))

    for idx, img in enumerate(image_data):
        lo, hi = img.min(), img.max()
        if hi <= 1.0 and lo >= 0.0:
            # Values in [0, 1]: scale up to [0, 255].
            img = (img * 255).astype(np.uint8)
        elif hi <= 1.0 and lo < 0.0:
            # Values treated as [-1, 1]: shift and scale to [0, 255]. Clip
            # before casting so anything below -1 saturates at 0 instead of
            # being cast from a negative float to uint8.
            img = np.clip((img + 1.0) * 255 / 2.0, 0, 255).astype(np.uint8)
        else:
            # Anything else is cast to uint8 as-is.
            img = img.astype(np.uint8)
        Image.fromarray(img).save(f"{save_path}/gt_image_{idx}.png")


class Options:
    """Plain attribute container: each keyword argument becomes an attribute."""

    def __init__(self, **fields):
        # Store the keywords directly in the instance namespace.
        vars(self).update(fields)


if __name__ == '__main__':
    # Script entry point. Given a result .npz it:
    #   1. dumps the stored images as PNGs next to the .npz,
    #   2. adds a kNN point mask to the data if it has none,
    #   3. writes a colourised depth image per frame,
    #   4. runs two registration/visualisation passes and a final sweep,
    #   5. encodes the rendered frames to a video with ffmpeg,
    #   6. optionally exports a coloured point cloud.

    # Setup logging levels.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(
        description='Run ACE0 for a dataset or a scene.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--export_point_cloud', type=_strtobool, default=True,
                        help="Export the ACE0 point cloud after reconstruction, "
                             "for visualisation or to initialise splats")

    parser.add_argument('--dense_point_cloud', type=_strtobool, default=True,
                        help='when exporting a point cloud, do not filter points based on reprojection error, '
                             'bad for visualisation but good to initialise splats')

    parser.add_argument('--num_data_workers', type=int, default=12,
                        help='number of data loading workers, set according to the number of available CPU cores')

    # Registration parameters
    parser.add_argument('--ransac_iterations', type=int, default=32,
                        help="Number of RANSAC hypothesis when registering mapping frames.")
    parser.add_argument('--ransac_threshold', type=float, default=10,
                        help='RANSAC inlier threshold in pixels')

    # Visualisation parameters
    parser.add_argument('--render_visualization', type=_strtobool, default=True,
                        help="Render visualisation frames of the whole reconstruction process.")
    parser.add_argument('--render_flipped_portrait', type=_strtobool, default=False,
                        help="Dataset images are 90deg flipped (like Wayspots).")
    parser.add_argument('--render_marker_size', type=float, default=0.03,
                        help="Size of the camera marker when rendering scenes.")
    # parser.add_argument('--iterations_output', type=int, default=500,
    #                     help='how often to print the loss and render a frame')

    # --result_npz is mandatory: everything below is derived from it.
    parser.add_argument('--result_npz', type=str, required=True, help='Path to the .npz file containing the raw data (e.g., intrinsic matrices, 3D points, etc.) for visualization')
    parser.add_argument('--results_folder', type=Path, default="1217_new", help='path to output folder for result files')
    parser.add_argument('--knn_percent', type=float, default=98.0, help='percentage of points to keep in the knn mask')
    parser.add_argument('--K_neighbors', type=int, default=10, help='number of neighbors to use for knn mask')
    parser.add_argument('--pan_radius_scale', type=float, default=-1, help='factor to control the size of the pan camera')
    parser.add_argument('--pan_start_angle', type=int, default=-90, help='start angle of the pan camera')
    opt = parser.parse_args()

    # All outputs go next to the input .npz. NOTE: this overrides whatever was
    # passed as --results_folder. Path.parent also copes with a bare filename
    # (parent is '.') and with platform-specific separators.
    opt.results_folder = Path(opt.result_npz).parent
    os.makedirs(opt.results_folder, exist_ok=True)
    image_folder = f'{opt.results_folder}/gt_images'
    npz2image(opt.result_npz, image_folder)
    opt.rgb_files = f'{image_folder}/*.png'

    input_data = load_npz_file(opt.result_npz)

    # Compute the kNN point mask only when the data does not ship one, and
    # continue from a copy of the data that includes it.
    if 'pts_mask' not in input_data:
        start_time = time.time()
        input_data['pts_mask'] = compute_knn_mask(input_data['pts3d'], opt.K_neighbors, opt.knn_percent)
        end_time = time.time()
        _logger.info(f"Computing knn mask took {end_time - start_time:.2f} seconds.")
        npz_with_mask = f"{opt.results_folder}/points_with_mask.npz"
        np.savez_compressed(npz_with_mask, **input_data)
        opt.result_npz = npz_with_mask
    best_seed = 4
    iteration_id = zutil.get_seed_id(best_seed)

    # Per-frame depth visualisation: transform each frame's points by the
    # inverse of its pose and colourise the resulting z coordinate.
    # NOTE(review): assumes 'cam_poses' holds 4x4 camera-to-world matrices and
    # 'pts3d' is (frames, H, W, 3) in world coordinates - confirm.
    poses = input_data['cam_poses']
    H, W = input_data['pts3d'].shape[1:3]
    depth_folder = f"{opt.results_folder}/depth"
    os.makedirs(depth_folder, exist_ok=True)
    for idx, (pts3d, pose) in enumerate(zip(input_data['pts3d'], poses)):
        w2c = np.linalg.inv(pose)
        R = w2c[:3,:3]
        T = w2c[:3,-1]
        pts_cam = np.einsum('kl, Nl -> Nk', R, pts3d.reshape(-1,3)) + T[None]
        depth_map = pts_cam[...,-1].reshape(H, W)
        depth_vis_path = f"{depth_folder}/depth_vis_{idx}.png"
        imageio.imwrite(depth_vis_path, colorize(depth_map))
        print(f"Visualized depth map is saved to {depth_vis_path}.")

    # First registration/visualisation pass (only_frustum=True, tagged with the
    # seed session id).
    register_opt = zutil.get_register_opt(rgb_files=opt.rgb_files, result_npz=opt.result_npz)
    modify_options = {
        'rgb_files': opt.rgb_files,
        'render_visualization': opt.render_visualization,
        'render_target_path': zutil.get_render_path(opt.results_folder),
        'render_marker_size': opt.render_marker_size,
        'render_flipped_portrait': opt.render_flipped_portrait,
        'session': f"{iteration_id}",
        'hypotheses': opt.ransac_iterations,
        'threshold': opt.ransac_threshold,
        'hypotheses_max_tries': 16,
        'result_npz': opt.result_npz,
        'only_frustum': True,
        'pan_radius_scale': opt.pan_radius_scale,
        'pan_start_angle': opt.pan_start_angle
    }
    for k, v in modify_options.items():
        setattr(register_opt, k, v)

    reg_state_dict_1 = zutil.regitser_visulization(register_opt)

    # Second pass (only_frustum=False), resuming from the state returned by the
    # first pass.
    register_opt_iter = zutil.get_register_opt(rgb_files=opt.rgb_files, result_npz=opt.result_npz)
    modify_options = {
        'rgb_files': opt.rgb_files,
        'render_visualization': opt.render_visualization,
        'render_target_path': zutil.get_render_path(opt.results_folder),
        'render_marker_size': opt.render_marker_size,
        'render_flipped_portrait': opt.render_flipped_portrait,
        'hypotheses': opt.ransac_iterations,
        'threshold': opt.ransac_threshold,
        'hypotheses_max_tries': 16,
        'result_npz': opt.result_npz,
        'state_dict': reg_state_dict_1,
        'only_frustum': False,
        'pan_radius_scale': opt.pan_radius_scale,
        'pan_start_angle': opt.pan_start_angle
    }
    for k, v in modify_options.items():
        setattr(register_opt_iter, k, v)

    reg_state_dict_2 = zutil.regitser_visulization(register_opt_iter)

    _logger.info("Rendering final sweep.")

    # NOTE(review): flipped_portait is hard-coded to False here rather than
    # following --render_flipped_portrait - confirm this is intended.
    final_sweep_visualizer = ACEVisualizer(
        zutil.get_render_path(opt.results_folder),
        flipped_portait=False,
        map_depth_filter=100,
        marker_size=opt.render_marker_size,
        result_npz=opt.result_npz,
        pan_start_angle=opt.pan_start_angle,
        pan_radius_scale=opt.pan_radius_scale
    )
    poses = [final_sweep_visualizer.pose_align(pose) for pose in final_sweep_visualizer.cam_pose]
    # One entry per file found in the exported image folder.
    num_frames = len(os.listdir(image_folder))
    pose_iterations = [0] * num_frames
    final_sweep_visualizer.render_final_sweep(
        frame_count=150,
        camera_z_offset=4,
        poses=poses,
        pose_iterations=pose_iterations,
        total_poses=num_frames,
        state_dict=reg_state_dict_2,)

    _logger.info("Converting to video.")

    # get ffmpeg path; shutil.which returns None when it is not installed, so
    # fail with a clear message instead of handing None to subprocess.
    ffmpeg_path = shutil.which("ffmpeg")
    if ffmpeg_path is None:
        raise FileNotFoundError(
            "ffmpeg executable not found on PATH; cannot convert the rendered frames to a video.")

    # run ffmpeg to convert the rendered images to a video
    ffmpeg_save_cmd = [ffmpeg_path,
                    "-y",
                    "-framerate", "30",
                    "-pattern_type", "glob",
                    "-i", f"{zutil.get_render_path(opt.results_folder)}/*.png",
                    "-c:v", "libx264",
                    "-pix_fmt", "yuv420p",
                    str(opt.results_folder / "reconstruction.mp4")
                    ]

    subprocess.run(ffmpeg_save_cmd, check=True)

    print(f'The render result video is saved to {opt.results_folder}/reconstruction.mp4.')
    if opt.export_point_cloud:
        # trimesh is imported lazily so it is only required when exporting.
        import trimesh
        pts = final_sweep_visualizer.pts3d.reshape(-1, 3)
        # NOTE(review): assumes image_gt is (frames, 3, H, W) with values in
        # [-1, 1], which the line below rescales to [0, 255] - confirm.
        image = final_sweep_visualizer.image_gt.transpose(0, 2, 3, 1).reshape(-1, 3)
        clr = (( image + 1.0 ) / 2.0 * 255.0).astype('float64')
        cloud = trimesh.PointCloud(pts, colors=clr)
        cloud.export(f"{opt.results_folder}/point_cloud.ply")
        print(f'The point cloud is saved to {opt.results_folder}/point_cloud.ply.')