from typing import List, Literal, Optional

import random
import torch
import torch.nn as nn

from ...modules.utils import convert_module_to_f16, convert_module_to_f32
from ...modules import sparse as sp
from .base import SparseTransformerBase


class SparseSubdivideBlock3d(nn.Module):
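    """
    Sparse 3D upsampling block: a 3x3x3 sparse convolution with SiLU,
    a 2x spatial subdivision, then another 3x3x3 sparse convolution with SiLU.
    """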
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
use_checkpoint: bool = False,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_checkpoint = use_checkpoint
self.act_layers = nn.Sequential(
sp.SparseConv3d(channels, self.out_channels, 3, padding=1),
sp.SparseSiLU()
)
self.sub = sp.SparseSubdivide()
self.out_layers = nn.Sequential(
sp.SparseConv3d(self.out_channels, self.out_channels, 3, padding=1),
sp.SparseSiLU(),
)
def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
h = self.act_layers(x)
h = self.sub(h)
h = self.out_layers(h)
return h
    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
else:
return self._forward(x)
class SparseSDFDecoder(SparseTransformerBase):
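    """
    Sparse transformer decoder that maps latent sparse features to an SDF
    grid at 8x the latent resolution: a SparseTransformerBase torso followed
    by three 2x SparseSubdivideBlock3d stages and a Tanh-bounded linear head.
    """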
def __init__(
self,
resolution: int,
model_channels: int,
latent_channels: int,
num_blocks: int,
num_heads: Optional[int] = None,
num_head_channels: Optional[int] = 64,
mlp_ratio: float = 4,
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
window_size: int = 8,
pe_mode: Literal["ape", "rope"] = "ape",
use_fp16: bool = False,
use_checkpoint: bool = False,
qk_rms_norm: bool = False,
        representation_config: Optional[dict] = None,
out_channels: int = 1,
chunk_size: int = 1,
):
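        """
        Args mirror SparseTransformerBase; decoder-specific options:
        `resolution` is the latent grid resolution (output is 8x larger),
        `out_channels` is the number of SDF output channels, and
        `chunk_size` > 1 enables chunked upsampling (one random chunk per
        step in training, tiled chunks at inference).
        """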
super().__init__(
in_channels=latent_channels,
model_channels=model_channels,
num_blocks=num_blocks,
num_heads=num_heads,
num_head_channels=num_head_channels,
mlp_ratio=mlp_ratio,
attn_mode=attn_mode,
window_size=window_size,
pe_mode=pe_mode,
use_fp16=use_fp16,
use_checkpoint=use_checkpoint,
qk_rms_norm=qk_rms_norm,
)
self.resolution = resolution
self.rep_config = representation_config
self.out_channels = out_channels
self.chunk_size = chunk_size
self.upsample = nn.ModuleList([
SparseSubdivideBlock3d(
channels=model_channels,
out_channels=model_channels // 4,
use_checkpoint=use_checkpoint,
),
SparseSubdivideBlock3d(
channels=model_channels // 4,
out_channels=model_channels // 8,
use_checkpoint=use_checkpoint,
),
SparseSubdivideBlock3d(
channels=model_channels // 8,
out_channels=model_channels // 16,
use_checkpoint=use_checkpoint,
)
])
self.out_layer = sp.SparseLinear(model_channels // 16, self.out_channels)
self.out_active = sp.SparseTanh()
self.initialize_weights()
if use_fp16:
self.convert_to_fp16()
def initialize_weights(self) -> None:
super().initialize_weights()
# Zero-out output layers:
nn.init.constant_(self.out_layer.weight, 0)
nn.init.constant_(self.out_layer.bias, 0)
def convert_to_fp16(self) -> None:
"""
Convert the torso of the model to float16.
"""
super().convert_to_fp16()
self.upsample.apply(convert_module_to_f16)
def convert_to_fp32(self) -> None:
"""
Convert the torso of the model to float32.
"""
super().convert_to_fp32()
self.upsample.apply(convert_module_to_f32)
@torch.no_grad()
    def split_for_meshing(self, x: sp.SparseTensor, chunk_size: int = 4, padding: int = 4) -> List[sp.SparseTensor]:
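        """
        Tile the latent grid into chunk_size**3 chunks, each padded by
        `padding` voxels per side so neighbors share context at the seams.
        Each chunk carries `bounds` metadata (output-space crop box and
        offsets) used to map its upsampled coordinates back to global space.
        Empty chunks are skipped.
        """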
sub_resolution = self.resolution // chunk_size
        upsample_ratio = 8  # 2**3: three subdivide stages, each doubling resolution
assert sub_resolution % padding == 0
out = []
for i in range(chunk_size):
for j in range(chunk_size):
for k in range(chunk_size):
# Calculate padded boundaries
start_x = max(0, i * sub_resolution - padding)
end_x = min((i + 1) * sub_resolution + padding, self.resolution)
start_y = max(0, j * sub_resolution - padding)
end_y = min((j + 1) * sub_resolution + padding, self.resolution)
start_z = max(0, k * sub_resolution - padding)
end_z = min((k + 1) * sub_resolution + padding, self.resolution)
# Store original (unpadded) boundaries for later cropping
orig_start_x = i * sub_resolution
orig_end_x = (i + 1) * sub_resolution
orig_start_y = j * sub_resolution
orig_end_y = (j + 1) * sub_resolution
orig_start_z = k * sub_resolution
orig_end_z = (k + 1) * sub_resolution
                    mask = (
                        (x.coords[:, 1] >= start_x) & (x.coords[:, 1] < end_x)
                        & (x.coords[:, 2] >= start_y) & (x.coords[:, 2] < end_y)
                        & (x.coords[:, 3] >= start_z) & (x.coords[:, 3] < end_z)
                    )
if mask.sum() > 0:
# Get the coordinates and shift them to local space
coords = x.coords[mask].clone()
# Shift to local coordinates
coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z],
device=coords.device).view(1, 3)
chunk_tensor = sp.SparseTensor(x.feats[mask], coords)
                        # Store boundaries and offsets as metadata for later reconstruction.
                        # 'original' is the chunk's crop box in output (8x) coordinates,
                        # extended by upsample_ratio - 1 voxels so adjacent chunks overlap.
                        chunk_tensor.bounds = {
                            'original': (
                                orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1),
                                orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1),
                                orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1),
                            ),
                            # Offsets map local chunk coordinates back to global output coordinates
                            'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio),
                        }
out.append(chunk_tensor)
del mask
torch.cuda.empty_cache()
return out
@torch.no_grad()
    def split_single_chunk(self, x: sp.SparseTensor, chunk_size: int = 4, padding: int = 4) -> dict:
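        """
        Randomly sample bounds for one training chunk, resampling until the
        unpadded region contains at least one active voxel. Returns a dict
        with the padded input-space box ('start'), the output-space crop box
        ('original'), and the output-space offsets ('offsets').
        """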
sub_resolution = self.resolution // chunk_size
        upsample_ratio = 8  # 2**3: three subdivide stages, each doubling resolution
assert sub_resolution % padding == 0
mask_sum = -1
while mask_sum < 1:
orig_start_x = random.randint(0, self.resolution - sub_resolution)
orig_end_x = orig_start_x + sub_resolution
orig_start_y = random.randint(0, self.resolution - sub_resolution)
orig_end_y = orig_start_y + sub_resolution
orig_start_z = random.randint(0, self.resolution - sub_resolution)
orig_end_z = orig_start_z + sub_resolution
start_x = max(0, orig_start_x - padding)
end_x = min(orig_end_x + padding, self.resolution)
start_y = max(0, orig_start_y - padding)
end_y = min(orig_end_y + padding, self.resolution)
start_z = max(0, orig_start_z - padding)
end_z = min(orig_end_z + padding, self.resolution)
            mask_ori = (
                (x.coords[:, 1] >= orig_start_x) & (x.coords[:, 1] < orig_end_x)
                & (x.coords[:, 2] >= orig_start_y) & (x.coords[:, 2] < orig_end_y)
                & (x.coords[:, 3] >= orig_start_z) & (x.coords[:, 3] < orig_end_z)
            )
mask_sum = mask_ori.sum()
        # Store boundaries and offsets as metadata for later reconstruction.
        bounds = {
            'original': (
                orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1),
                orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1),
                orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1),
            ),
            'start': (start_x, end_x, start_y, end_y, start_z, end_z),
            'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio),
        }
return bounds
    def forward_single_chunk(self, x: sp.SparseTensor, padding: int = 4) -> sp.SparseTensor:
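        """
        Training path: crop the padded chunk sampled by split_single_chunk,
        run the subdivision stack on it, restore global output coordinates,
        and keep only the points inside the chunk's unpadded bounds.
        """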
bounds = self.split_single_chunk(x, self.chunk_size, padding=padding)
start_x, end_x, start_y, end_y, start_z, end_z = bounds['start']
        mask = (
            (x.coords[:, 1] >= start_x) & (x.coords[:, 1] < end_x)
            & (x.coords[:, 2] >= start_y) & (x.coords[:, 2] < end_y)
            & (x.coords[:, 3] >= start_z) & (x.coords[:, 3] < end_z)
        )
# Shift to local coordinates
coords = x.coords.clone()
coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z],
device=coords.device).view(1, 3)
chunk = sp.SparseTensor(x.feats[mask], coords[mask])
chunk_result = self.upsamples(chunk)
coords = chunk_result.coords.clone()
# Restore global coordinates
offsets = torch.tensor(bounds['offsets'],
device=coords.device).view(1, 3)
coords[:, 1:] = coords[:, 1:] + offsets
# Filter points within original bounds
original = bounds['original']
        within_bounds = (
            (coords[:, 1] >= original[0]) & (coords[:, 1] < original[1])
            & (coords[:, 2] >= original[2]) & (coords[:, 2] < original[3])
            & (coords[:, 3] >= original[4]) & (coords[:, 3] < original[5])
        )
final_coords = coords[within_bounds]
final_feats = chunk_result.feats[within_bounds]
return sp.SparseTensor(final_feats, final_coords)
    def upsamples(self, x: sp.SparseTensor, return_feat: bool = False):
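        """
        Apply the subdivision stack followed by the output head, casting
        activations back to the input dtype after the (possibly fp16)
        blocks. If return_feat is True, also return the pre-head features.
        """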
dtype = x.dtype
for block in self.upsample:
x = block(x)
x = x.type(dtype)
output = self.out_active(self.out_layer(x))
if return_feat:
return output, x
else:
return output
    def forward(self, x: sp.SparseTensor, factor: Optional[float] = None, return_feat: bool = False):
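        """
        Decode latent sparse features into SDF values. With chunk_size <= 1
        the whole grid is upsampled at once; otherwise training decodes one
        randomly sampled chunk, while inference tiles the grid with
        split_for_meshing and stitches the chunk results back together.
        """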
h = super().forward(x, factor)
if self.chunk_size <= 1:
for block in self.upsample:
h = block(h)
h = h.type(x.dtype)
if return_feat:
return self.out_active(self.out_layer(h)), h
h = self.out_layer(h)
h = self.out_active(h)
return h
else:
if self.training:
return self.forward_single_chunk(h)
else:
batch_size = x.shape[0]
chunks = self.split_for_meshing(h, chunk_size=self.chunk_size)
all_coords, all_feats = [], []
            for chunk in chunks:
chunk_result = self.upsamples(chunk)
for b in range(batch_size):
mask = torch.nonzero(chunk_result.coords[:, 0] == b).squeeze(-1)
if mask.numel() > 0:
coords = chunk_result.coords[mask].clone()
# Restore global coordinates
offsets = torch.tensor(chunk.bounds['offsets'],
device=coords.device).view(1, 3)
coords[:, 1:] = coords[:, 1:] + offsets
# Filter points within original bounds
bounds = chunk.bounds['original']
                        within_bounds = (
                            (coords[:, 1] >= bounds[0]) & (coords[:, 1] < bounds[1])
                            & (coords[:, 2] >= bounds[2]) & (coords[:, 2] < bounds[3])
                            & (coords[:, 3] >= bounds[4]) & (coords[:, 3] < bounds[5])
                        )
if within_bounds.any():
all_coords.append(coords[within_bounds])
all_feats.append(chunk_result.feats[mask][within_bounds])
                torch.cuda.empty_cache()  # inference-only branch; free chunk memory eagerly
final_coords = torch.cat(all_coords)
final_feats = torch.cat(all_feats)
return sp.SparseTensor(final_feats, final_coords)