import os
import subprocess
import sys

import torch


def install_cuda_toolkit():
    """Download and silently install the CUDA 12.2 toolkit, then export the
    environment variables needed to compile CUDA extensions against it.

    Not called by default (see the commented-out call below). Each shell step
    uses ``check=True`` so a failed download/install raises
    ``subprocess.CalledProcessError`` instead of continuing silently.
    """
    # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"  # ! cu121 already installed
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
    subprocess.run(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE], check=True)
    subprocess.run(["chmod", "+x", CUDA_TOOLKIT_FILE], check=True)
    subprocess.run([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"], check=True)

    os.environ["CUDA_HOME"] = "/usr/local/cuda"
    os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
    os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
        os.environ["CUDA_HOME"],
        os.environ.get("LD_LIBRARY_PATH", ""),
    )
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"


# install_cuda_toolkit()  # to compile the dependencies

# Install iopath with the pip of the *running* interpreter so it lands in the
# same environment. Failure is non-fatal, as with the previous os.system call.
subprocess.run([sys.executable, "-m", "pip", "install", "iopath"], check=False)

# Tag of the matching prebuilt pytorch3d wheel folder, of the form
# "py3<minor>_cu<cuda>_pyt<torch>".
pyt_version_str = torch.__version__.split("+")[0].replace(".", "")
version_str = "".join([
    f"py3{sys.version_info.minor}_cu",
    # torch.version.cuda is None on CPU-only builds; don't crash at import time.
    (torch.version.cuda or "").replace(".", ""),
    f"_pyt{pyt_version_str}",
])
# To install pytorch3d with the right version:
# os.system("pip install -U 'git+https://github.com/facebookresearch/fvcore'")
# os.system(f'pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html')

import spaces

# Path-setup imports; they must run before any mast3r/dust3r import below.
import mast3r.utils.path_to_dust3r  # noqa: F401
import dust3r.utils.path_to_croco  # noqa: F401

import math
import os.path as path
import shutil
import tempfile

import cv2
import gradio
import matplotlib.pyplot as pl
import numpy as np
import torch.nn.functional as F
import trimesh
from scipy.spatial.transform import Rotation

from mast3r.model import AsymmetricMASt3R
from dust3r.utils.device import to_numpy
from dust3r.utils.geometry import xy_grid
from dust3r.utils.image import load_images
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes

pl.ion()

# for gpu >= Ampere and pytorch >= 1.12
torch.backends.cuda.matmul.allow_tf32 = True
batch_size = 1
inf = float('inf')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# from_pretrained returns its own instance; a hand-built AsymmetricMASt3R that
# used to precede this line was immediately overwritten and only cost
# time/memory.
model = AsymmetricMASt3R.from_pretrained("zhang3z/FLARE").to(device).eval()

image_size = 512
silent = True
gradio_delete_cache = 7200
# Parent directory for all per-request output directories; removed on shutdown.
tmpdirname = tempfile.mkdtemp(suffix='_FLARE_gradio_demo')

backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg")
backbone = backbone.eval().to(device)


class FileState:
    """Hold a file path and delete that file when the object is finalized."""

    def __init__(self, outfile_name=None):
        self.outfile_name = outfile_name

    def __del__(self):
        if self.outfile_name is not None and os.path.isfile(self.outfile_name):
            os.remove(self.outfile_name)
        self.outfile_name = None


def pad_to_square(reshaped_image):
    """Zero-pad a (B, C, H, W) tensor to (B, C, S, S) with S = max(H, W).

    Padding is split as evenly as possible between the two sides of each axis.
    """
    _, _, H, W = reshaped_image.shape
    max_dim = max(H, W)
    pad_height = max_dim - H
    pad_width = max_dim - W
    # F.pad order: (left, right, top, bottom)
    padding = (pad_width // 2, pad_width - pad_width // 2,
               pad_height // 2, pad_height - pad_height // 2)
    return F.pad(reshaped_image, padding, mode='constant', value=0)


@torch.no_grad()
def generate_rank_by_dino(reshaped_image, backbone, query_frame_num, image_size=336):
    """Return the index of the frame whose DINO features best match all others.

    Args:
        reshaped_image: (S, C, H, W) image tensor; the caller in this file
            rescales ``load_images`` output with ``x / 2 + 0.5`` first.
        backbone: DINOv2 model whose ``forward(..., is_training=True)`` returns
            a dict containing ``"x_norm_patchtokens"``.
        query_frame_num: unused; kept for interface compatibility.
        image_size: side length of the square images fed to the backbone.

    Returns:
        int: index of the frame with the largest summed similarity to the
        other frames.
    """
    # Pad to square before resizing so the aspect ratio is preserved; a low
    # resolution is sufficient for ranking.
    rgbs = pad_to_square(reshaped_image.to(device))
    rgbs = F.interpolate(rgbs, (image_size, image_size), mode="bilinear", align_corners=True)
    rgbs = _resnet_normalize_image(rgbs)

    # Patch-level features, presumably (S, P, C).
    frame_feat = backbone(rgbs, is_training=True)["x_norm_patchtokens"]
    frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)

    # One S x S similarity matrix per patch, averaged over patches.
    frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
    similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)).mean(dim=0)

    # Ignore self-pairing
    similarity_matrix.fill_diagonal_(-100)
    similarity_sum = similarity_matrix.sum(dim=1)
    return torch.argmax(similarity_sum).item()


_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]
_resnet_mean = torch.tensor(_RESNET_MEAN).view(1, 3, 1, 1).to(device)
_resnet_std = torch.tensor(_RESNET_STD).view(1, 3, 1, 1).to(device)


def _resnet_normalize_image(img: torch.Tensor) -> torch.Tensor:
    """Normalize a (B, 3, H, W) tensor with the ImageNet mean/std above."""
    return (img - _resnet_mean) / _resnet_std


def calculate_index_mappings(query_index, S, device=None):
    """
    Construct an order that we can switch [query_index] and [0] so that the
    content of query_index would be placed at [0].

    Returns a LongTensor permutation of range(S), optionally moved to `device`.
    """
    new_order = torch.arange(S)
    new_order[0] = query_index
    new_order[query_index] = 0
    if device is not None:
        new_order = new_order.to(device)
    return new_order


def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False,
                                 transparent_cams=False, silent=False):
    """Export colored geometry plus camera frustums to a GLB file.

    Args:
        outfile: destination path passed to ``trimesh.Scene.export``.
        imgs: per-view (H, W, C) color images.
        pts3d: per-view 3D points, reshapeable to the matching image's shape.
        mask: per-view boolean masks selecting the points to keep.
        focals: per-view focal lengths used to draw the cameras.
        cams2world: per-view 4x4 camera-to-world poses.
        cam_size: screen width of each drawn camera.
        cam_color: one color for all cameras, or a list with one per camera;
            defaults to cycling through CAM_COLORS.
        as_pointcloud: export a point cloud instead of a mesh.
        transparent_cams: draw cameras without their image.
        silent: suppress the export message.

    Returns:
        ``outfile``.
    """
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d = to_numpy(pts3d)
    imgs = to_numpy(imgs)
    focals = to_numpy(focals)
    mask = to_numpy(mask)
    cams2world = to_numpy(cams2world)

    scene = trimesh.Scene()

    if as_pointcloud:
        pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)]).reshape(-1, 3)
        col = np.concatenate([p[m] for p, m in zip(imgs, mask)]).reshape(-1, 3)
        valid_msk = np.isfinite(pts.sum(axis=1))  # drop points with NaN/inf coords
        scene.add_geometry(trimesh.PointCloud(pts[valid_msk], colors=col[valid_msk]))
    else:
        meshes = []
        for i in range(len(imgs)):
            pts3d_i = pts3d[i].reshape(imgs[i].shape)
            msk_i = mask[i] & np.isfinite(pts3d_i.sum(axis=-1))
            meshes.append(pts3d_to_trimesh(imgs[i], pts3d_i, msk_i))
        scene.add_geometry(trimesh.Trimesh(**cat_meshes(meshes)))

    # add each camera
    for i, pose_c2w in enumerate(cams2world):
        if isinstance(cam_color, list):
            camera_edge_color = cam_color[i]
        else:
            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
        add_scene_cam(scene, pose_c2w, camera_edge_color,
                      None if transparent_cams else imgs[i], focals[i],
                      imsize=imgs[i].shape[1::-1], screen_width=cam_size)

    # Re-express the scene relative to the first camera (OpenGL convention,
    # rotated 180 degrees about y).
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
    if not silent:
        print('(exporting 3D scene to', outfile, ')')
    scene.export(file_obj=outfile)
    return outfile


# View fields that are not moved to the device at all.
_IGNORE_KEYS = {'depthmap', 'dataset', 'label', 'instance', 'idx', 'rng', 'vid'}
# float32 view fields that must NOT be down-cast to the inference dtype.
_KEEP_FP32_KEYS = {'true_shape', 'camera_pose', 'pts3d', 'fxfycxcy', 'img_org',
                   'camera_intrinsics', 'depthmap', 'depth_anything', 'fxfycxcy_unorm'}


def _reorder_batch_reference_first(batch):
    """Return `batch` reordered so the DINO-selected reference view is first."""
    images = torch.cat([view['img'] for view in batch], dim=0) / 2 + 0.5
    reference = generate_rank_by_dino(images, backbone, query_frame_num=1)
    order = calculate_index_mappings(reference, len(batch))
    return [batch[i] for i in order.tolist()]


def _batch_to_device(batch, dtype):
    """In place: move view fields to `device`, casting float32 ones to `dtype`."""
    for view in batch:
        for name in list(view):
            if name in _IGNORE_KEYS:
                continue
            value = view[name]
            if not isinstance(value, torch.Tensor):
                value = torch.tensor(value)
            value = value.to(device, non_blocking=True)
            if value.dtype == torch.float32 and name not in _KEEP_FP32_KEYS:
                value = value.to(dtype)
            view[name] = value


def _estimate_focal(pts3d, H, W):
    """Estimate one focal length from the first view of an (N, H, W, 3) point map.

    Each pixel (u, v), measured from the image centre, votes fx = u*z/x and
    fy = v*z/y (square pixels assumed); the NaN-median of all votes is returned.
    """
    pp = torch.tensor((W / 2, H / 2)).to(pts3d)
    pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2)
    u, v = pixels[:1].unbind(dim=-1)
    x, y, z = pts3d[:1].reshape(-1, 3).unbind(dim=-1)
    f_votes = torch.cat((((u * z) / x).view(1, -1), ((v * z) / y).view(1, -1)), dim=-1)
    return torch.nanmedian(f_votes, dim=-1).values.item()


def _estimate_poses_pnp(pts3d, conf, shape, focal, H, W):
    """Recover a camera-to-world pose per view with RANSAC PnP (SQPnP).

    Args:
        pts3d: (N, H, W, 3) numpy point maps.
        conf: (N, H, W) confidence tensor; only points above each view's 60th
            confidence percentile are used.
        shape: (1, N, 2) array of per-view (height, width).
        focal, H, W: intrinsics; principal point at the image centre.

    Returns:
        (N, 4, 4) float64 tensor of camera-to-world matrices.

    Raises:
        gradio.Error: if PnP fails for any view.
    """
    K = np.float32([(focal, 0, W / 2), (0, focal, H / 2), (0, 0, 1)])
    poses = []
    for i in range(pts3d.shape[0]):
        view_h, view_w = shape[0, i]
        pixel_grid = xy_grid(view_w, view_h)
        inlier = (conf[i] > torch.quantile(conf[i], 0.6)).detach().cpu().numpy()
        success, rvec, tvec, _ = cv2.solvePnPRansac(
            pts3d[i][inlier].astype(np.float64), pixel_grid[inlier].astype(np.float64),
            K, None, flags=cv2.SOLVEPNP_SQPNP, iterationsCount=10_000,
            reprojectionError=1, confidence=0.9999)
        if not success:
            # Previously this fell through to cv2.Rodrigues(None) and crashed.
            raise gradio.Error(f"Camera pose estimation (PnP) failed for image {i}.")
        rotation = cv2.Rodrigues(rvec)[0]
        world2cam = np.r_[np.c_[rotation, tvec], [(0, 0, 0, 1)]]
        poses.append(np.linalg.inv(world2cam))
    return torch.tensor(np.stack(poses, axis=0))


@spaces.GPU(duration=180)
def local_get_reconstructed_scene(inputfiles, min_conf_thr, cam_size):
    """Reconstruct a 3D scene from uploaded images and export it as a GLB.

    Args:
        inputfiles: files from the ``gradio.File`` widget, passed to ``load_images``.
        min_conf_thr: quantile in [0, 1]; per view, points whose confidence is
            not above this quantile are dropped from the exported cloud.
        cam_size: screen width of the camera frustums drawn in the scene.

    Returns:
        Path of the exported ``.glb`` file.

    Raises:
        gradio.Error: if nothing was uploaded or pose estimation fails.
    """
    if not inputfiles:
        raise gradio.Error("Please upload at least one image.")

    # Unique per request; lives under tmpdirname so shutdown cleanup removes it.
    outdir = tempfile.mkdtemp(suffix='_FLARE_gradio_demo', dir=tmpdirname)
    batch = load_images(inputfiles, size=image_size, verbose=not silent)
    batch = _reorder_batch_reference_first(batch)

    dtype = torch.bfloat16
    _batch_to_device(batch, dtype)
    view1, view2 = batch[:1], batch[1:]

    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=dtype):
        _, pred2, _ = model(view1, view2, True, dtype)

    # pts3d: (B, N, H, W, 3); conf: (B, N, H, W)
    pts3d = pred2['pts3d'].detach().float().cpu()
    conf = pred2['conf'].float()
    B, N, H, W, _ = pts3d.shape
    thres = torch.quantile(conf.flatten(2, 3), min_conf_thr, dim=-1)[0]
    masks_conf = (conf > thres[None, :, None, None]).cpu()

    shape = torch.stack([view['true_shape'] for view in batch], dim=1).detach().cpu().numpy()
    images = torch.stack([view['img'] for view in batch], 1).float()
    images = images.permute(0, 1, 3, 4, 2).detach().cpu().numpy()
    images = (images / 2 + 0.5).reshape(B, N, H, W, 3)[0]

    pts3d = pts3d[0]
    masks_conf = masks_conf[0]
    focal = _estimate_focal(pts3d, H, W)
    pts3d = pts3d.numpy()
    pred_poses = _estimate_poses_pnp(pts3d, conf[0], shape, focal, H, W)

    # NOTE: the optional pytorch3d knn outlier filtering was disabled upstream
    # (and its unused import/tensors removed); only the confidence mask applies.
    focals = [focal] * len(images)
    outfile_name = os.path.join(outdir, 'scene.glb')
    _convert_scene_output_to_glb(outfile_name, images, pts3d, masks_conf, focals, pred_poses,
                                 as_pointcloud=True, transparent_cams=False,
                                 cam_size=cam_size, silent=silent)
    return outfile_name


css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
title = "FLARE Demo"

with gradio.Blocks(css=css, title=title,
                   delete_cache=(gradio_delete_cache, gradio_delete_cache)) as demo:
    gradio.HTML('<h2 style="text-align: center;">3D Reconstruction with FLARE</h2>')
    with gradio.Column():
        inputfiles = gradio.File(file_count="multiple")
        snapshot = gradio.Image(None, visible=False)
        with gradio.Row():
            # adjust the confidence threshold
            min_conf_thr = gradio.Slider(label="min_conf_thr", value=0.1, minimum=0.0, maximum=1, step=0.05)
            # adjust the camera size in the output pointcloud
            cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
        run_btn = gradio.Button("Run")
        outmodel = gradio.Model3D()
        run_btn.click(fn=local_get_reconstructed_scene,
                      inputs=[inputfiles, min_conf_thr, cam_size],
                      outputs=[outmodel])

try:
    demo.launch(show_error=True, share=None, server_name=None, server_port=None)
finally:
    # Remove every per-request output directory created under tmpdirname.
    shutil.rmtree(tmpdirname, ignore_errors=True)