# -*- coding: utf-8 -*-
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
import trimesh
from tqdm import tqdm
from skimage import measure

from .unet3d import UNet3DModel
from ...modules.utils import convert_module_to_f16, convert_module_to_f32


def adaptive_conv(inputs, weights):
    # Weighted 3x3x3 neighborhood aggregation: `weights` holds 27 per-voxel
    # coefficients (one per neighbor offset) that are broadcast over the
    # feature channels of `inputs`. Assumes a cubic volume (D == H == W).
    padding = (1, 1, 1, 1, 1, 1)
    padded_input = F.pad(inputs, padding, mode="constant", value=0)
    output = torch.zeros_like(inputs)
    size = inputs.shape[-1]
    for i in range(3):
        for j in range(3):
            for k in range(3):
                output = output + padded_input[:, :, i:i + size, j:j + size, k:k + size] * weights[:, i * 9 + j * 3 + k:i * 9 + j * 3 + k + 1]
    return output


def adaptive_block(inputs, conv, weights_=None):
    # Predict the 27 per-voxel neighborhood weights (from `weights_` if given,
    # otherwise from `inputs`), L1-normalize them over the weight dimension,
    # and apply the adaptive convolution three times.
    if weights_ is not None:
        weights = conv(weights_)
    else:
        weights = conv(inputs)
    weights = F.normalize(weights, dim=1, p=1)
    for _ in range(3):
        inputs = adaptive_conv(inputs, weights)
    return inputs
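

# Illustrative note (not part of the original code): `adaptive_block` expects
# `conv` to map its input to 27 channels, one weight per offset of the 3x3x3
# neighborhood used by `adaptive_conv`. For an input of shape (N, C, D, H, W)
# the predicted weights therefore have shape (N, 27, D, H, W) and are broadcast
# over the C feature channels, e.g. (shapes here are only an example; in this
# model the cube is the 192^3 patch):
#
#     x = torch.randn(1, 8, 16, 16, 16)
#     w = F.normalize(torch.rand(1, 27, 16, 16, 16), dim=1, p=1)
#     y = adaptive_conv(x, w)   # -> torch.Size([1, 8, 16, 16, 16])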


class GeoDecoder(nn.Module):

    def __init__(self,
                 n_features: int,
                 hidden_dim: int = 32,
                 num_layers: int = 4,
                 use_sdf: bool = False,
                 activation: nn.Module = nn.ReLU):
        super().__init__()
        self.use_sdf = use_sdf
        # Small per-voxel MLP: n_features -> hidden_dim -> ... -> 8 channels.
        self.net = nn.Sequential(
            nn.Linear(n_features, hidden_dim),
            activation(),
            *itertools.chain(*[[
                nn.Linear(hidden_dim, hidden_dim),
                activation(),
            ] for _ in range(num_layers - 2)]),
            nn.Linear(hidden_dim, 8),
        )

        # Xavier-initialize all linear weights and zero all biases.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.net(x)
        return x


class Voxel_RefinerXL(nn.Module):

    def __init__(self,
                 in_channels: int = 1,
                 out_channels: int = 1,
                 layers_per_block: int = 2,
                 layers_mid_block: int = 2,
                 patch_size: int = 192,
                 res: int = 512,
                 use_checkpoint: bool = False,
                 use_fp16: bool = False):
        super().__init__()
        self.unet3d1 = UNet3DModel(in_channels=16, out_channels=8, use_conv_out=False,
                                   layers_per_block=layers_per_block, layers_mid_block=layers_mid_block,
                                   block_out_channels=(8, 32, 128, 512), norm_num_groups=4,
                                   use_checkpoint=use_checkpoint)
        self.conv_in = nn.Conv3d(in_channels, 8, kernel_size=3, padding=1)
        self.latent_mlp = GeoDecoder(32)
        # Each adaptive head predicts the 27 per-voxel neighborhood weights
        # consumed by adaptive_block.
        self.adaptive_conv1 = nn.Sequential(nn.Conv3d(8, 8, kernel_size=3, padding=1),
                                            nn.ReLU(),
                                            nn.Conv3d(8, 27, kernel_size=3, padding=1, bias=False))
        self.adaptive_conv2 = nn.Sequential(nn.Conv3d(8, 8, kernel_size=3, padding=1),
                                            nn.ReLU(),
                                            nn.Conv3d(8, 27, kernel_size=3, padding=1, bias=False))
        self.adaptive_conv3 = nn.Sequential(nn.Conv3d(8, 8, kernel_size=3, padding=1),
                                            nn.ReLU(),
                                            nn.Conv3d(8, 27, kernel_size=3, padding=1, bias=False))
        self.mid_conv = nn.Conv3d(8, 8, kernel_size=3, padding=1)
        self.conv_out = nn.Conv3d(8, out_channels, kernel_size=3, padding=1)
        self.patch_size = patch_size
        self.res = res
        self.use_fp16 = use_fp16
        self.dtype = torch.float16 if use_fp16 else torch.float32
        if use_fp16:
            self.convert_to_fp16()

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        # self.blocks.apply(convert_module_to_f16)
        self.apply(convert_module_to_f16)

    def run(self,
            reconst_x,
            feat,
            mc_threshold=0,
            ):
        # Scatter the sparse SDF values and latent features into dense grids.
        batch_size = int(reconst_x.coords[..., 0].max()) + 1
        sparse_sdf, sparse_index = reconst_x.feats, reconst_x.coords
        sparse_feat = feat.feats
        device = sparse_sdf.device
        dtype = sparse_sdf.dtype
        res = self.res
        sdfs = []
        for i in range(batch_size):
            idx = sparse_index[..., 0] == i
            sparse_sdf_i, sparse_index_i = sparse_sdf[idx].squeeze(-1), sparse_index[idx][..., 1:]
            sdf = torch.ones((res, res, res)).to(device).to(dtype)
            sdf[sparse_index_i[..., 0], sparse_index_i[..., 1], sparse_index_i[..., 2]] = sparse_sdf_i
            sdfs.append(sdf.unsqueeze(0))
        sdfs = torch.stack(sdfs, dim=0)
        feats = torch.zeros((batch_size, sparse_feat.shape[-1], res, res, res),
                            device=device, dtype=dtype)
        feats[sparse_index[..., 0], :, sparse_index[..., 1], sparse_index[..., 2], sparse_index[..., 3]] = sparse_feat

        N = sdfs.shape[0]
        outputs = torch.ones([N, 1, res, res, res], dtype=dtype, device=device)
        stride = 160
        patch_size = self.patch_size
        step = 3
        sdfs = sdfs.to(dtype)
        feats = feats.to(dtype)

        # Refine the volume patch by patch: 3 x 3 x 3 overlapping crops of
        # `patch_size` voxels, spaced `stride` voxels apart (32-voxel overlap).
        patches = []
        for i in range(step):
            for j in range(step):
                for k in tqdm(range(step)):
                    sdf = sdfs[:, :, stride * i: stride * i + patch_size,
                               stride * j: stride * j + patch_size,
                               stride * k: stride * k + patch_size]
                    crop_feats = feats[:, :, stride * i: stride * i + patch_size,
                                       stride * j: stride * j + patch_size,
                                       stride * k: stride * k + patch_size]
                    inputs = self.conv_in(sdf)
                    crop_feats = self.latent_mlp(crop_feats.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
                    inputs = torch.cat([inputs, crop_feats], dim=1)
                    mid_feat = self.unet3d1(inputs)
                    mid_feat = adaptive_block(mid_feat, self.adaptive_conv1)
                    mid_feat = self.mid_conv(mid_feat)
                    mid_feat = adaptive_block(mid_feat, self.adaptive_conv2)
                    final_feat = self.conv_out(mid_feat)
                    final_feat = adaptive_block(final_feat, self.adaptive_conv3, weights_=mid_feat)
                    output = torch.tanh(final_feat)
                    patches.append(output)

        # Stitch the patches back together, linearly blending the 32-voxel
        # overlaps: first along the innermost (k) axis into lines, then along
        # the middle (j) axis into layers, then along the outer (i) axis.
        weights = torch.linspace(0, 1, steps=32, device=device, dtype=dtype)
        lines = []
        for i in range(9):
            out1 = patches[i * 3]
            out2 = patches[i * 3 + 1]
            out3 = patches[i * 3 + 2]
            line = torch.ones([N, 1, 192, 192, res], dtype=dtype, device=device) * 2
            line[:, :, :, :, :160] = out1[:, :, :, :, :160]
            line[:, :, :, :, 192:320] = out2[:, :, :, :, 32:160]
            line[:, :, :, :, 352:] = out3[:, :, :, :, 32:]
            line[:, :, :, :, 160:192] = out1[:, :, :, :, 160:] * (1 - weights.reshape(1, 1, 1, 1, -1)) + out2[:, :, :, :, :32] * weights.reshape(1, 1, 1, 1, -1)
            line[:, :, :, :, 320:352] = out2[:, :, :, :, 160:] * (1 - weights.reshape(1, 1, 1, 1, -1)) + out3[:, :, :, :, :32] * weights.reshape(1, 1, 1, 1, -1)
            lines.append(line)
        layers = []
        for i in range(3):
            line1 = lines[i * 3]
            line2 = lines[i * 3 + 1]
            line3 = lines[i * 3 + 2]
            layer = torch.ones([N, 1, 192, res, res], device=device, dtype=dtype) * 2
            layer[:, :, :, :160] = line1[:, :, :, :160]
            layer[:, :, :, 192:320] = line2[:, :, :, 32:160]
            layer[:, :, :, 352:] = line3[:, :, :, 32:]
            layer[:, :, :, 160:192] = line1[:, :, :, 160:] * (1 - weights.reshape(1, 1, 1, -1, 1)) + line2[:, :, :, :32] * weights.reshape(1, 1, 1, -1, 1)
            layer[:, :, :, 320:352] = line2[:, :, :, 160:] * (1 - weights.reshape(1, 1, 1, -1, 1)) + line3[:, :, :, :32] * weights.reshape(1, 1, 1, -1, 1)
            layers.append(layer)
        outputs[:, :, :160] = layers[0][:, :, :160]
        outputs[:, :, 192:320] = layers[1][:, :, 32:160]
        outputs[:, :, 352:] = layers[2][:, :, 32:]
        outputs[:, :, 160:192] = layers[0][:, :, 160:] * (1 - weights.reshape(1, 1, -1, 1, 1)) + layers[1][:, :, :32] * weights.reshape(1, 1, -1, 1, 1)
        outputs[:, :, 320:352] = layers[1][:, :, 160:] * (1 - weights.reshape(1, 1, -1, 1, 1)) + layers[2][:, :, :32] * weights.reshape(1, 1, -1, 1, 1)
        # outputs = -outputs

        # Extract one mesh per batch element with marching cubes and map the
        # vertex coordinates to [-1, 1].
        meshes = []
        for i in range(outputs.shape[0]):
            vertices, faces, _, _ = measure.marching_cubes(outputs[i, 0].cpu().numpy(),
                                                           level=mc_threshold, method='lewiner')
            vertices = vertices / res * 2 - 1
            meshes.append(trimesh.Trimesh(vertices, faces))
        return meshes
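

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. `run` only reads
    # `.coords` / `.feats` from its two sparse inputs, so plain namespaces are
    # used here as stand-ins for the real sparse-tensor class; the coordinate
    # layout (batch, x, y, z), the voxel count, and the 32-channel latent width
    # (matching GeoDecoder(32) above) are assumptions. Intended to be executed
    # as a module inside its package so the relative imports resolve; a full
    # 512^3 refinement is memory-heavy, so treat this as illustrative only.
    from types import SimpleNamespace

    device = "cuda" if torch.cuda.is_available() else "cpu"
    refiner = Voxel_RefinerXL().to(device).eval()

    n_voxels = 1024  # hypothetical number of occupied sparse voxels
    coords = torch.zeros((n_voxels, 4), dtype=torch.long, device=device)
    coords[:, 1:] = torch.randint(0, 512, (n_voxels, 3), device=device)
    reconst_x = SimpleNamespace(coords=coords,
                                feats=torch.rand(n_voxels, 1, device=device) * 2 - 1)  # coarse SDF values
    feat = SimpleNamespace(coords=coords,
                           feats=torch.randn(n_voxels, 32, device=device))             # latent features

    with torch.no_grad():
        meshes = refiner.run(reconst_x, feat, mc_threshold=0.0)
    print(len(meshes), meshes[0])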