from typing import Optional, Literal

import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
from ...modules import sparse as sp
from .base import SparseTransformerBase
class SparseSubdivideBlock3d(nn.Module):
    """
    Sparse upsampling block that doubles the spatial resolution of a sparse
    voxel grid via subdivision, with a convolution before and after the split.
    """
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint

        self.act_layers = nn.Sequential(
            sp.SparseConv3d(channels, self.out_channels, 3, padding=1),
            sp.SparseSiLU()
        )
        self.sub = sp.SparseSubdivide()
        self.out_layers = nn.Sequential(
            sp.SparseConv3d(self.out_channels, self.out_channels, 3, padding=1),
            sp.SparseSiLU(),
        )

    def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        h = self.act_layers(x)
        h = self.sub(h)
        h = self.out_layers(h)
        return h

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)
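
# Usage sketch (illustrative only): one subdivide block applied to a single
# active voxel. Assumes the sp backend's SparseTensor(feats, coords)
# constructor used elsewhere in this file and int32 coordinates; the channel
# sizes are arbitrary, not a real model configuration.
#
#     block = SparseSubdivideBlock3d(channels=64, out_channels=16)
#     coords = torch.tensor([[0, 3, 5, 7]], dtype=torch.int32)  # (batch, x, y, z)
#     feats = torch.randn(1, 64)
#     out = block(sp.SparseTensor(feats, coords))
#     # Subdivision doubles the resolution: each voxel becomes a 2x2x2 group,
#     # so out.coords spans roughly twice the coordinate range of the input.
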
class SparseSDFDecoder(SparseTransformerBase):
    """
    Sparse transformer decoder that maps latent features to a high-resolution
    sparse SDF grid. Three subdivide blocks upsample the transformer output
    by a factor of 8 per axis; when chunk_size > 1, decoding is done in
    spatial chunks to bound memory.
    """
    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: Optional[dict] = None,
        out_channels: int = 1,
        chunk_size: int = 1,
    ):
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config
        self.out_channels = out_channels
        self.chunk_size = chunk_size
        self.upsample = nn.ModuleList([
            SparseSubdivideBlock3d(
                channels=model_channels,
                out_channels=model_channels // 4,
                use_checkpoint=use_checkpoint,
            ),
            SparseSubdivideBlock3d(
                channels=model_channels // 4,
                out_channels=model_channels // 8,
                use_checkpoint=use_checkpoint,
            ),
            SparseSubdivideBlock3d(
                channels=model_channels // 8,
                out_channels=model_channels // 16,
                use_checkpoint=use_checkpoint,
            ),
        ])
        self.out_layer = sp.SparseLinear(model_channels // 16, self.out_channels)
        self.out_active = sp.SparseTanh()

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()
    def initialize_weights(self) -> None:
        super().initialize_weights()
        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        super().convert_to_fp16()
        self.upsample.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        super().convert_to_fp32()
        self.upsample.apply(convert_module_to_f32)
    def split_for_meshing(self, x: sp.SparseTensor, chunk_size=4, padding=4):
        """
        Split a sparse tensor into chunk_size^3 spatially contiguous chunks,
        each padded by `padding` voxels so convolutions near chunk borders see
        the same context as in the unsplit grid. Each chunk carries bounds
        metadata for cropping and re-assembly after upsampling.
        """
        sub_resolution = self.resolution // chunk_size
        upsample_ratio = 8  # 2x per subdivide block, 3 blocks (hard-coded here)
        assert sub_resolution % padding == 0
        out = []
        for i in range(chunk_size):
            for j in range(chunk_size):
                for k in range(chunk_size):
                    # Padded boundaries, clamped to the grid.
                    start_x = max(0, i * sub_resolution - padding)
                    end_x = min((i + 1) * sub_resolution + padding, self.resolution)
                    start_y = max(0, j * sub_resolution - padding)
                    end_y = min((j + 1) * sub_resolution + padding, self.resolution)
                    start_z = max(0, k * sub_resolution - padding)
                    end_z = min((k + 1) * sub_resolution + padding, self.resolution)

                    # Original (unpadded) boundaries for later cropping.
                    orig_start_x = i * sub_resolution
                    orig_end_x = (i + 1) * sub_resolution
                    orig_start_y = j * sub_resolution
                    orig_end_y = (j + 1) * sub_resolution
                    orig_start_z = k * sub_resolution
                    orig_end_z = (k + 1) * sub_resolution

                    mask = torch.logical_and(
                        torch.logical_and(
                            torch.logical_and(x.coords[:, 1] >= start_x, x.coords[:, 1] < end_x),
                            torch.logical_and(x.coords[:, 2] >= start_y, x.coords[:, 2] < end_y)
                        ),
                        torch.logical_and(x.coords[:, 3] >= start_z, x.coords[:, 3] < end_z)
                    )

                    if mask.sum() > 0:
                        # Shift the selected coordinates into the chunk's local space.
                        coords = x.coords[mask].clone()
                        coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z],
                                                                     device=coords.device).view(1, 3)
                        chunk_tensor = sp.SparseTensor(x.feats[mask], coords)

                        # Boundaries and offsets (in upsampled coordinates) for reconstruction.
                        # The upper bounds extend by (upsample_ratio - 1), so adjacent chunks
                        # overlap by a voxel border, which avoids seams when meshing.
                        chunk_tensor.bounds = {
                            'original': (orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1),
                                         orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1),
                                         orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1)),
                            'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio),
                        }
                        out.append(chunk_tensor)

                    del mask
                    torch.cuda.empty_cache()
        return out
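
    # Worked example (illustrative numbers, not a fixed configuration): with
    # resolution=64 and chunk_size=4, sub_resolution is 16. The chunk at
    # (i, j, k) = (1, 0, 0) with padding=4 gets:
    #   padded x-range:   [12, 36)  -> coords shifted left by start_x = 12
    #   unpadded x-range: [16, 32)
    #   'original' x-bounds: (16*8, 32*8 + 7) = (128, 263) in output voxels
    #   'offsets'  x: 12*8 = 96, added back to the local upsampled coordinates
    #   before cropping to the 'original' bounds.
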
    def split_single_chunk(self, x: sp.SparseTensor, chunk_size=4, padding=4):
        """
        Sample one random sub-volume (of the size a chunk would have under
        `chunk_size` splits), resampling until it contains at least one active
        voxel, and return its bounds metadata. Used during training, where a
        single random chunk is decoded instead of the full grid.
        """
        sub_resolution = self.resolution // chunk_size
        upsample_ratio = 8  # 2x per subdivide block, 3 blocks (hard-coded here)
        assert sub_resolution % padding == 0
        mask_sum = -1
        while mask_sum < 1:
            orig_start_x = random.randint(0, self.resolution - sub_resolution)
            orig_end_x = orig_start_x + sub_resolution
            orig_start_y = random.randint(0, self.resolution - sub_resolution)
            orig_end_y = orig_start_y + sub_resolution
            orig_start_z = random.randint(0, self.resolution - sub_resolution)
            orig_end_z = orig_start_z + sub_resolution

            start_x = max(0, orig_start_x - padding)
            end_x = min(orig_end_x + padding, self.resolution)
            start_y = max(0, orig_start_y - padding)
            end_y = min(orig_end_y + padding, self.resolution)
            start_z = max(0, orig_start_z - padding)
            end_z = min(orig_end_z + padding, self.resolution)

            mask_ori = torch.logical_and(
                torch.logical_and(
                    torch.logical_and(x.coords[:, 1] >= orig_start_x, x.coords[:, 1] < orig_end_x),
                    torch.logical_and(x.coords[:, 2] >= orig_start_y, x.coords[:, 2] < orig_end_y)
                ),
                torch.logical_and(x.coords[:, 3] >= orig_start_z, x.coords[:, 3] < orig_end_z)
            )
            mask_sum = mask_ori.sum()

        # Boundaries and offsets (in upsampled coordinates) for reconstruction.
        bounds = {
            'original': (orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1),
                         orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1),
                         orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1)),
            'start': (start_x, end_x, start_y, end_y, start_z, end_z),
            'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio),
        }
        return bounds
    def forward_single_chunk(self, x: sp.SparseTensor, padding=4):
        """
        Decode a single random chunk of `x`: crop with padding, upsample,
        shift back to global coordinates, and drop the padded border.
        """
        bounds = self.split_single_chunk(x, self.chunk_size, padding=padding)
        start_x, end_x, start_y, end_y, start_z, end_z = bounds['start']
        mask = torch.logical_and(
            torch.logical_and(
                torch.logical_and(x.coords[:, 1] >= start_x, x.coords[:, 1] < end_x),
                torch.logical_and(x.coords[:, 2] >= start_y, x.coords[:, 2] < end_y)
            ),
            torch.logical_and(x.coords[:, 3] >= start_z, x.coords[:, 3] < end_z)
        )
        # Shift to local coordinates.
        coords = x.coords.clone()
        coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z],
                                                     device=coords.device).view(1, 3)
        chunk = sp.SparseTensor(x.feats[mask], coords[mask])
        chunk_result = self.upsamples(chunk)

        # Restore global coordinates.
        coords = chunk_result.coords.clone()
        offsets = torch.tensor(bounds['offsets'], device=coords.device).view(1, 3)
        coords[:, 1:] = coords[:, 1:] + offsets

        # Keep only points within the original (unpadded) bounds.
        original = bounds['original']
        within_bounds = torch.logical_and(
            torch.logical_and(
                torch.logical_and(coords[:, 1] >= original[0], coords[:, 1] < original[1]),
                torch.logical_and(coords[:, 2] >= original[2], coords[:, 2] < original[3])
            ),
            torch.logical_and(coords[:, 3] >= original[4], coords[:, 3] < original[5])
        )
        final_coords = coords[within_bounds]
        final_feats = chunk_result.feats[within_bounds]
        return sp.SparseTensor(final_feats, final_coords)
    def upsamples(self, x, return_feat: bool = False):
        """
        Run the subdivide stack (8x total upsampling) and the output head.
        """
        dtype = x.dtype
        for block in self.upsample:
            x = block(x)
        x = x.type(dtype)
        output = self.out_active(self.out_layer(x))
        if return_feat:
            return output, x
        return output
    def forward(self, x: sp.SparseTensor, factor: Optional[float] = None, return_feat: bool = False):
        h = super().forward(x, factor)
        if self.chunk_size <= 1:
            # Decode the whole grid at once.
            for block in self.upsample:
                h = block(h)
            h = h.type(x.dtype)
            if return_feat:
                return self.out_active(self.out_layer(h)), h
            h = self.out_layer(h)
            h = self.out_active(h)
            return h
        else:
            if self.training:
                # Training: decode a single random chunk to bound memory.
                return self.forward_single_chunk(h)
            else:
                # Inference: decode chunk_size^3 padded chunks and stitch them.
                batch_size = x.shape[0]
                chunks = self.split_for_meshing(h, chunk_size=self.chunk_size)
                all_coords, all_feats = [], []
                for chunk in chunks:
                    chunk_result = self.upsamples(chunk)
                    for b in range(batch_size):
                        mask = torch.nonzero(chunk_result.coords[:, 0] == b).squeeze(-1)
                        if mask.numel() > 0:
                            coords = chunk_result.coords[mask].clone()
                            # Restore global coordinates.
                            offsets = torch.tensor(chunk.bounds['offsets'],
                                                   device=coords.device).view(1, 3)
                            coords[:, 1:] = coords[:, 1:] + offsets
                            # Keep only points within the chunk's original bounds.
                            bounds = chunk.bounds['original']
                            within_bounds = torch.logical_and(
                                torch.logical_and(
                                    torch.logical_and(coords[:, 1] >= bounds[0], coords[:, 1] < bounds[1]),
                                    torch.logical_and(coords[:, 2] >= bounds[2], coords[:, 2] < bounds[3])
                                ),
                                torch.logical_and(coords[:, 3] >= bounds[4], coords[:, 3] < bounds[5])
                            )
                            if within_bounds.any():
                                all_coords.append(coords[within_bounds])
                                all_feats.append(chunk_result.feats[mask][within_bounds])
                    # Free chunk buffers between iterations.
                    torch.cuda.empty_cache()
                final_coords = torch.cat(all_coords)
                final_feats = torch.cat(all_feats)
                return sp.SparseTensor(final_feats, final_coords)
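
# Usage sketch (illustrative only): constructing and running the decoder.
# All sizes below are hypothetical, not a real checkpoint configuration, and
# `latent` stands for an sp.SparseTensor with latent_channels feature dims.
#
#     decoder = SparseSDFDecoder(
#         resolution=64,
#         model_channels=768,
#         latent_channels=8,
#         num_blocks=12,
#         chunk_size=4,  # > 1 enables chunked decoding at inference time
#     )
#     decoder.eval()
#     with torch.no_grad():
#         sdf = decoder(latent)
#     # sdf is a sparse tensor of tanh-activated SDF values at 8x resolution.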