# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import trimesh
from skimage import measure

from ...modules import sparse as sp
from .encoder import SparseSDFEncoder
from .decoder import SparseSDFDecoder
from .distributions import DiagonalGaussianDistribution


class SparseSDFVAE(nn.Module):
    """Variational auto-encoder over sparse SDF voxels.

    Wraps a ``SparseSDFEncoder`` and a ``SparseSDFDecoder`` that share the
    same latent width (``embed_dim``) and adds helpers that turn the decoder
    output into ``trimesh.Trimesh`` objects via marching cubes.
    """

    def __init__(self, *,
                 embed_dim: int = 0,
                 resolution: int = 64,
                 model_channels_encoder: int = 512,
                 num_blocks_encoder: int = 4,
                 num_heads_encoder: int = 8,
                 num_head_channels_encoder: int = 64,
                 model_channels_decoder: int = 512,
                 num_blocks_decoder: int = 4,
                 num_heads_decoder: int = 8,
                 num_head_channels_decoder: int = 64,
                 out_channels: int = 1,
                 use_fp16: bool = False,
                 use_checkpoint: bool = False,
                 chunk_size: int = 1,
                 latents_scale: float = 1.0,
                 latents_shift: float = 0.0):
        """Build the encoder/decoder pair.

        All arguments are keyword-only. The ``*_encoder`` / ``*_decoder``
        arguments are forwarded to ``SparseSDFEncoder`` / ``SparseSDFDecoder``
        respectively; their exact meaning is defined by those classes.

        Args:
            embed_dim: Passed as ``latent_channels`` to both sub-modules.
            resolution: Passed as ``resolution`` to both sub-modules.
            out_channels: Passed to the decoder only.
            use_fp16: Passed to both sub-modules.
            use_checkpoint: Passed to both sub-modules and stored on ``self``.
            chunk_size: Passed to the decoder only.
            latents_scale: Stored on ``self``; not used by any method of this
                class (presumably consumed by callers — TODO confirm).
            latents_shift: Stored on ``self``; same remark as
                ``latents_scale``.
        """
        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.resolution = resolution
        self.latents_scale = latents_scale
        self.latents_shift = latents_shift

        self.encoder = SparseSDFEncoder(
            resolution=resolution,
            # NOTE(review): in_channels is tied to model_channels_encoder
            # here; kept as in the original — confirm this is intended.
            in_channels=model_channels_encoder,
            model_channels=model_channels_encoder,
            latent_channels=embed_dim,
            num_blocks=num_blocks_encoder,
            num_heads=num_heads_encoder,
            num_head_channels=num_head_channels_encoder,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
        )

        self.decoder = SparseSDFDecoder(
            resolution=resolution,
            model_channels=model_channels_decoder,
            latent_channels=embed_dim,
            num_blocks=num_blocks_decoder,
            num_heads=num_heads_decoder,
            num_head_channels=num_head_channels_decoder,
            out_channels=out_channels,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            chunk_size=chunk_size,
        )
        self.embed_dim = embed_dim

    def forward(self, batch):
        """Encode ``batch`` (sampling the posterior) and decode the sample.

        Args:
            batch: Mapping accepted by :meth:`encode`.

        Returns:
            Dict with ``'reconst_x'`` (the decoder output) and ``'posterior'``
            (the ``DiagonalGaussianDistribution`` produced by the encoder).
        """
        z, posterior = self.encode(batch)

        reconst_x = self.decoder(z)
        outputs = {'reconst_x': reconst_x, 'posterior': posterior}
        return outputs

    def encode(self, batch, sample_posterior: bool = True):
        """Encode a batch of sparse SDF samples into latents.

        Args:
            batch: Mapping with keys ``'sparse_sdf'``, ``'sparse_index'`` and
                ``'batch_idx'``; an optional ``'factor'`` entry is forwarded
                to the encoder (``None`` when absent).
                Assumes ``sparse_index`` is (N, 3) and ``batch_idx`` is (N,)
                — TODO confirm against the data loader.
            sample_posterior: Draw a sample from the posterior when True,
                otherwise take its mode.

        Returns:
            Tuple ``(z, posterior)`` where ``z`` is the encoder output with
            its features replaced by the sampled/mode latents.
        """
        feat, xyz, batch_idx = batch['sparse_sdf'], batch['sparse_index'], batch['batch_idx']
        # A 1-D feature vector gets a trailing channel axis.
        if feat.ndim == 1:
            feat = feat.unsqueeze(-1)
        # Coordinates are laid out as [batch index, *xyz] and stored as int32.
        coords = torch.cat([batch_idx.unsqueeze(-1), xyz], dim=-1).int()

        x = sp.SparseTensor(feat, coords)
        h = self.encoder(x, batch.get('factor', None))
        posterior = DiagonalGaussianDistribution(h.feats, feat_dim=1)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        z = h.replace(z)

        return z, posterior

    def decode_mesh(self,
                    latents,
                    voxel_resolution: int = 512,
                    mc_threshold: float = 0.2,
                    return_feat: bool = False,
                    factor: float = 1.0):
        """Decode latents and, unless features are requested, mesh the result.

        Args:
            latents: Input handed to the decoder unchanged.
            voxel_resolution: Dense grid edge length before applying
                ``factor``; the grid actually used is
                ``int(voxel_resolution / factor)``.
            mc_threshold: Iso-level for marching cubes. Note the default here
                (0.2) differs from :meth:`sparse2mesh` (0.0).
            return_feat: When True, return the raw decoder output instead of
                meshes.
            factor: Forwarded to the decoder and used to rescale
                ``voxel_resolution``.

        Returns:
            The decoder output when ``return_feat`` is True, otherwise the
            list of meshes produced by :meth:`sparse2mesh`.
        """
        voxel_resolution = int(voxel_resolution / factor)
        reconst_x = self.decoder(latents, factor=factor, return_feat=return_feat)
        if return_feat:
            return reconst_x
        outputs = self.sparse2mesh(reconst_x, voxel_resolution=voxel_resolution, mc_threshold=mc_threshold)

        return outputs

    def sparse2mesh(self,
                    reconst_x: torch.FloatTensor,
                    voxel_resolution: int = 512,
                    mc_threshold: float = 0.0):
        """Scatter sparse SDF values into dense grids and run marching cubes.

        Args:
            reconst_x: Object exposing ``.feats`` (SDF values) and ``.coords``
                (rows of ``[batch index, x, y, z]``).
                NOTE(review): the ``torch.FloatTensor`` annotation does not
                match this usage; it looks like a sparse tensor from the
                project's ``sparse`` module is expected — confirm.
            voxel_resolution: Edge length of the dense cubic grid. Coordinates
                are used as grid indices directly, so they must be smaller
                than this value.
            mc_threshold: Iso-level passed to ``measure.marching_cubes``.

        Returns:
            One ``trimesh.Trimesh`` per batch index in
            ``0 .. max(batch index)``, with vertices rescaled by
            ``v / voxel_resolution * 2 - 1``. An empty list when
            ``reconst_x`` holds no voxels.

        Note:
            Errors raised by ``measure.marching_cubes`` (presumably when the
            iso-level lies outside a grid's value range, e.g. for a batch
            index without voxels — verify) are propagated unchanged.
        """
        sparse_sdf, sparse_index = reconst_x.feats.float(), reconst_x.coords
        # Nothing to mesh; also avoids calling max() on an empty tensor.
        if sparse_index.shape[0] == 0:
            return []

        batch_ids = sparse_index[..., 0]
        batch_size = int(batch_ids.max().item()) + 1

        meshes = []
        for i in range(batch_size):
            mask = batch_ids == i
            # Detach the values as well as the indices: otherwise the dense
            # grid inherits autograd history and ``.numpy()`` below raises
            # when gradients are enabled.
            sdf_values = sparse_sdf[mask].squeeze(-1).detach().cpu()
            # Index tensors are cast to int64 for portable advanced indexing.
            voxel_idx = sparse_index[mask][..., 1:].detach().cpu().long()

            # Grid cells without a sample keep the fill value 1.
            sdf = torch.ones((voxel_resolution, voxel_resolution, voxel_resolution))
            sdf[voxel_idx[..., 0], voxel_idx[..., 1], voxel_idx[..., 2]] = sdf_values

            vertices, faces, _, _ = measure.marching_cubes(
                sdf.numpy(),
                mc_threshold,
                method="lewiner",
            )
            # Rescale grid coordinates from [0, voxel_resolution) to [-1, 1).
            vertices = vertices / voxel_resolution * 2 - 1
            meshes.append(trimesh.Trimesh(vertices, faces))

        return meshes