# Spaces: Running on Zero
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
from ast import Dict | |
import math | |
import numpy as np | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from torch_scatter import scatter_mean #, scatter_max | |
from .unet_3daware import setup_unet #UNetTriplane3dAware | |
from .conv_pointnet import ConvPointnet | |
from .pc_encoder import PVCNNEncoder #PointNet | |
import einops | |
from .dnnlib_util import ScopedTorchProfiler, printarr | |
def generate_plane_features(p, c, resolution, plane='xz', padding=0.):
    """Scatter-average per-point features onto a square 2D feature plane.

    Args:
        p: point coordinates, (B, n_p, 3); ``normalize_coordinate`` selects two
            axes via ``p[:, :, axis]``. With ``padding=0`` the coordinates are
            expected in [-0.5, 0.5] so they normalize into [0, 1).
        c: per-point features, (B, C, n_p).
        resolution: side length of the output plane.
        plane: pair of axes to project onto: 'xz', 'xy' or 'yz'.
        padding: padding forwarded to ``normalize_coordinate`` (default 0).

    Returns:
        Tensor (B, C, resolution, resolution). Each cell holds the mean of the
        features of the points falling into it; cells without points stay 0.
    """
    batch_size = p.size(0)
    c_dim = c.size(1)
    # Acquire the plane-cell index of every point (coordinates mapped to [0, 1)).
    xy = normalize_coordinate(p.clone(), plane=plane, padding=padding)
    index = coordinate2index(xy, resolution)  # (B, 1, n_p), broadcast over channels
    # Scatter plane features from points; starting from zeros keeps empty cells at 0.
    fea_plane = c.new_zeros(batch_size, c_dim, resolution ** 2)
    fea_plane = scatter_mean(c, index, out=fea_plane)  # (B, C, resolution**2)
    return fea_plane.reshape(batch_size, c_dim, resolution, resolution)
def normalize_coordinate(p, padding=0.1, plane='xz'):
    ''' Normalize two coordinate axes to [0, 1) for unit cube experiments.

    Args:
        p (tensor): points indexed as ``p[:, :, axis]``, e.g. (B, n_p, 3).
        padding (float): conventional padding parameter of ONet for the unit
            cube, so [-0.5, 0.5] -> [-0.55, 0.55] for padding=0.1.
        plane (str): plane feature type, one of ['xz', 'xy', 'yz']; any other
            value selects the yz axes.

    Returns:
        Tensor (B, n_p, 2) with the two selected coordinates mapped to [0, 1).
        Values >= 1 become ``1 - 10e-6`` and negative values become 0.
        ``p`` itself is not modified (advanced indexing copies it).
    '''
    if plane == 'xz':
        xy = p[:, :, [0, 2]]
    elif plane == 'xy':
        xy = p[:, :, [0, 1]]
    else:
        xy = p[:, :, [1, 2]]
    xy_new = xy / (1 + padding + 10e-6)  # (-0.5, 0.5)
    xy_new = xy_new + 0.5  # range (0, 1)
    # Clamp outliers with out-of-place masked fills. The previous
    # ``if xy_new.max() >= 1`` guards raised on an empty point set and forced
    # a device->host sync on every call; the filled values are unchanged.
    xy_new = xy_new.masked_fill(xy_new >= 1, 1 - 10e-6)
    xy_new = xy_new.masked_fill(xy_new < 0, 0.0)
    return xy_new
def coordinate2index(x, resolution):
    ''' Convert normalized 2D coordinates into flattened plane-cell indices.

    Args:
        x (tensor): (B, n_p, 2) coordinates, expected in [0, 1) (e.g. the
            output of ``normalize_coordinate``).
        resolution (int): side length of the square grid.

    Returns:
        Long tensor (B, 1, n_p) holding ``col + resolution * row``. Here
        ``col = long(x[..., 0] * resolution)`` and
        ``row = long(x[..., 1] * resolution)``; ``long`` truncates toward zero.
        Indices lie in [0, resolution**2) when x is in [0, 1).
    '''
    cells = (x * resolution).long()
    index = cells[:, :, 0] + resolution * cells[:, :, 1]
    # Add a singleton channel dim so the index broadcasts over feature channels.
    return index[:, None, :]
def softclip(x, min, max, hardness=5):
    """Smoothly clip ``x`` into roughly [min, max] using two softplus ramps.

    Used as a soft clip for the logsigma. A softplus lower bound is applied
    first, then a softplus upper bound. Larger ``hardness`` makes the clip
    sharper. The argument names shadow the builtins but are kept for
    keyword-compatible callers.
    """
    lower, upper = min, max
    floored = lower + F.softplus(hardness * (x - lower)) / hardness
    return upper - F.softplus(-hardness * (floored - upper)) / hardness
def sample_triplane_feat(feature_triplane, normalized_pos):
    '''
    Bilinearly sample each plane of a triplane and sum the three samples.

    Args:
        feature_triplane: tensor whose dim 1 indexes the planes. Each plane
            ``feature_triplane[:, i]`` is fed to ``F.grid_sample`` with a 4D
            grid, so it must be (B, C, H, W).
        normalized_pos: (B, N, 3) query points normalized to [-1, 1].
            Out-of-range coordinates are clamped to the border.

    Returns:
        Tensor (B, N, C): plane 0 sampled at (x, y), plus plane 1 sampled at
        (y, z), plus plane 2 sampled at (x, z).
    '''
    planes = torch.unbind(feature_triplane, dim=1)
    # Coordinate pair used for plane 0, 1, 2 respectively.
    axis_pairs = ((0, 1), (1, 2), (0, 2))
    total = None
    for plane_idx, (first, second) in enumerate(axis_pairs):
        grid = normalized_pos[:, :, [first, second]].unsqueeze(dim=1)  # (B, 1, N, 2)
        sampled = F.grid_sample(
            planes[plane_idx], grid,
            padding_mode='border', align_corners=True)
        total = sampled if total is None else total + sampled
    # (B, C, 1, N) -> (B, N, C)
    return total.squeeze(dim=2).permute(0, 2, 1)
# @persistence.persistent_class
class TriPlanePC2Encoder(torch.nn.Module):
    """Point-cloud to latent-triplane encoder, similar to ConvOccNet.

    A point encoder (PVCNN or ConvPointnet) produces three axis-aligned
    feature planes (xy, yz, xz). An optional triplane U-Net refines them, and
    ``encode``/``forward`` return them stacked along dim 1.
    """

    def __init__(
        self,
        cfg,
        device='cuda',
        shape_min=-1.0,
        shape_length=2.0,
        use_2d_feat=False,
    ):
        """Build the point encoder and the optional triplane U-Net.

        Configs read from ``cfg``:
            point_encoder_type: (str) one of ['pvcnn', 'pointnet']
            use_point_scatter: (bool) read by ``encode``. For pvcnn, scatter
                per-point features onto the planes instead of averaging the
                voxel volume. Must be True for pointnet.
            unet_cfg: passed to ``setup_unet``; ``unet_cfg.enabled`` toggles
                the U-Net.
            z_triplane_channels: (int) channels of the output latent triplane
            z_triplane_resolution: (int) spatial resolution of each plane

        Args:
            cfg: config object providing the fields above.
            device: forwarded to the PVCNN encoder.
            shape_min: lower bound of the shape bounding box.
            shape_length: edge length of the shape bounding box. ``encode``
                maps [shape_min, shape_min + shape_length] to [-0.5, 0.5].
            use_2d_feat: forwarded to the PVCNN encoder.

        Raises:
            NotImplementedError: if ``cfg.point_encoder_type`` is unknown.
        """
        super().__init__()
        self.device = device
        self.cfg = cfg
        self.shape_min = shape_min
        self.shape_length = shape_length
        self.z_triplane_resolution = cfg.z_triplane_resolution
        z_triplane_channels = cfg.z_triplane_channels
        point_encoder_out_dim = z_triplane_channels
        # Channels of the xyz + feature concatenation built in ``encode``
        # (presumably 3 coordinates + 3 feature channels — TODO confirm).
        in_channels = 6
        if cfg.point_encoder_type == 'pvcnn':
            # Encodes the point cloud into a feature volume.
            self.pc_encoder = PVCNNEncoder(
                point_encoder_out_dim,
                device=self.device, in_channels=in_channels, use_2d_feat=use_2d_feat)
        elif cfg.point_encoder_type == 'pointnet':
            # TODO the pointnet was buggy, investigate
            self.pc_encoder = ConvPointnet(
                c_dim=point_encoder_out_dim,
                dim=in_channels, hidden_dim=32,
                plane_resolution=self.z_triplane_resolution,
                padding=0)
        else:
            raise NotImplementedError(f"Point encoder {cfg.point_encoder_type} not implemented")
        if cfg.unet_cfg.enabled:
            self.unet_encoder = setup_unet(
                output_channels=point_encoder_out_dim,
                input_channels=point_encoder_out_dim,
                unet_cfg=cfg.unet_cfg)
        else:
            self.unet_encoder = None

    def encode(self, point_cloud_xyz, point_cloud_feature, mv_feat=None, pc2pc_idx=None) -> torch.Tensor:
        """Encode a point cloud into a stacked triplane tensor.

        Args:
            point_cloud_xyz: point coordinates, indexed as ``[:, :, axis]``
                (presumably (B, N, 3) — TODO confirm).
            point_cloud_feature: per-point features, concatenated to the
                normalized xyz along the last dim.
            mv_feat: optional extra features forwarded to the PVCNN encoder.
            pc2pc_idx: forwarded to the PVCNN encoder together with ``mv_feat``.

        Returns:
            Tensor with the xy, yz and xz planes stacked along dim 1.

        Raises:
            TypeError: if ``point_cloud_feature`` is None.
            NotImplementedError: if ``cfg.point_encoder_type`` is unknown.
        """
        if point_cloud_feature is None:
            # Fail with a clear message instead of an opaque torch.cat error.
            raise TypeError("point_cloud_feature is required: the point encoder "
                            "consumes xyz concatenated with per-point features")
        # Map the shape bounding box to [0, 1], then centre it: [-0.5, 0.5].
        point_cloud_xyz = (point_cloud_xyz - self.shape_min) / self.shape_length
        point_cloud_xyz = point_cloud_xyz - 0.5
        point_cloud = torch.cat([point_cloud_xyz, point_cloud_feature], dim=-1)

        if self.cfg.point_encoder_type == 'pvcnn':
            if mv_feat is not None:
                pc_feat, points_feat = self.pc_encoder(point_cloud, mv_feat, pc2pc_idx)
            else:
                pc_feat, points_feat = self.pc_encoder(point_cloud)
            if self.cfg.use_point_scatter:
                # Scatter the PVCNN per-point features onto each plane.
                # Shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
                points_feat_ = points_feat[0]
                pc_feat_1 = generate_plane_features(point_cloud_xyz, points_feat_,
                                                    resolution=self.z_triplane_resolution, plane='xy')
                pc_feat_2 = generate_plane_features(point_cloud_xyz, points_feat_,
                                                    resolution=self.z_triplane_resolution, plane='yz')
                pc_feat_3 = generate_plane_features(point_cloud_xyz, points_feat_,
                                                    resolution=self.z_triplane_resolution, plane='xz')
            else:
                # 3D feature volume, documented upstream as B x D x 32 x 32 x 32.
                pc_feat = pc_feat[0]
                # Nearest-upsampling factor from the voxel grid to the triplane
                # resolution. It is derived from the volume instead of a
                # hard-coded 32; the result is identical for 32^3 volumes.
                sf = self.z_triplane_resolution // pc_feat.shape[-1]
                pc_feat_1 = torch.mean(pc_feat, dim=-1)  # xy plane, averaged over z
                pc_feat_2 = torch.mean(pc_feat, dim=-3)  # yz plane, averaged over x
                pc_feat_3 = torch.mean(pc_feat, dim=-2)  # xz plane, averaged over y
                # nearest upsample
                pc_feat_1 = einops.repeat(pc_feat_1, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
                pc_feat_2 = einops.repeat(pc_feat_2, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
                pc_feat_3 = einops.repeat(pc_feat_3, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
        elif self.cfg.point_encoder_type == 'pointnet':
            assert self.cfg.use_point_scatter
            # ConvPointnet returns the planes keyed by name.
            pc_feat = self.pc_encoder(point_cloud)
            pc_feat_1 = pc_feat['xy']
            pc_feat_2 = pc_feat['yz']
            pc_feat_3 = pc_feat['xz']
        else:
            raise NotImplementedError()

        if self.unet_encoder is not None:
            # TODO eval adding a skip connection (input + unet output).
            # Unet expects B, 3, C, H, W
            pc_feat_tri_plane_stack_pre = torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)
            pc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre)
            pc_feat_1, pc_feat_2, pc_feat_3 = torch.unbind(pc_feat_tri_plane_stack, dim=1)
        return torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)

    def forward(self, point_cloud_xyz, point_cloud_feature=None, mv_feat=None, pc2pc_idx=None):
        """Alias for ``encode``.

        ``point_cloud_feature`` defaults to None only for signature
        compatibility; ``encode`` requires it.
        """
        return self.encode(point_cloud_xyz, point_cloud_feature=point_cloud_feature, mv_feat=mv_feat, pc2pc_idx=pc2pc_idx)