# -*- coding: utf-8 -*-
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet3d import UNet3DModel
import trimesh
from tqdm import tqdm
from skimage import measure
from ...modules.utils import convert_module_to_f16, convert_module_to_f32


def adaptive_conv(inputs, weights):
    """Apply a spatially varying 3x3x3 convolution whose kernel is predicted
    separately for every voxel (zero padding at the borders)."""
    padding = (1, 1, 1, 1, 1, 1)
    padded_input = F.pad(inputs, padding, mode="constant", value=0)
    output = torch.zeros_like(inputs)
    size = inputs.shape[-1]
    for i in range(3):
        for j in range(3):
            for k in range(3):
                # weights channel i*9 + j*3 + k holds the kernel entry for offset (i, j, k).
                output = output + padded_input[:, :, i:i + size, j:j + size, k:k + size] \
                    * weights[:, i * 9 + j * 3 + k:i * 9 + j * 3 + k + 1]
    return output
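
# In effect, adaptive_conv computes for every voxel p and every channel c
#     output[b, c, p] = sum_{d in {-1, 0, 1}^3} inputs[b, c, p + d] * weights[b, idx(d), p]
# i.e. a 3x3x3 convolution whose 27-entry kernel is predicted per voxel and
# shared across all input channels.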


def adaptive_block(inputs, conv, weights_=None):
    """Predict per-voxel 3x3x3 kernels with ``conv`` and apply them three times."""
    if weights_ is not None:
        weights = conv(weights_)
    else:
        weights = conv(inputs)
    # Give every voxel's 27 kernel weights unit L1 norm before applying them.
    weights = F.normalize(weights, dim=1, p=1)
    for _ in range(3):
        inputs = adaptive_conv(inputs, weights)
    return inputs
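
# A minimal shape sketch (illustrative only, not part of the model): with a small
# predictor that maps 8 feature channels to 27 per-voxel kernel weights,
#
#     conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1), nn.ReLU(),
#                          nn.Conv3d(8, 27, 3, padding=1, bias=False))
#     x = torch.randn(1, 8, 32, 32, 32)
#     y = adaptive_block(x, conv)   # same shape as x: (1, 8, 32, 32, 32)
#
# the 27 predicted weights are L1-normalized per voxel before being applied.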


class GeoDecoder(nn.Module):
    """Point-wise MLP that maps per-voxel latent features to 8 refinement channels."""

    def __init__(self,
                 n_features: int,
                 hidden_dim: int = 32,
                 num_layers: int = 4,
                 use_sdf: bool = False,
                 activation: nn.Module = nn.ReLU):
        super().__init__()
        self.use_sdf = use_sdf
self.net = nn.Sequential(
nn.Linear(n_features, hidden_dim),
activation(),
*itertools.chain(*[[
nn.Linear(hidden_dim, hidden_dim),
activation(),
] for _ in range(num_layers - 2)]),
nn.Linear(hidden_dim, 8),
)
        # Xavier-initialize linear weights and zero all biases
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
x = self.net(x)
return x
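
# A minimal usage sketch (illustrative only): GeoDecoder acts point-wise, so a
# dense feature volume is permuted to channels-last, decoded, and permuted back:
#
#     decoder = GeoDecoder(n_features=32)
#     feats = torch.randn(1, 32, 16, 16, 16)    # (B, C, D, H, W)
#     out = decoder(feats.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
#     # out: (1, 8, 16, 16, 16)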


class Voxel_RefinerXL(nn.Module):
    """Refines a dense SDF volume (default 512^3) in overlapping 192^3 patches
    with a 3D UNet followed by adaptive-convolution blocks."""

    def __init__(self,
in_channels: int = 1,
out_channels: int = 1,
layers_per_block: int = 2,
layers_mid_block: int = 2,
patch_size: int = 192,
res: int = 512,
use_checkpoint: bool=False,
use_fp16: bool = False):
super().__init__()
self.unet3d1 = UNet3DModel(in_channels=16, out_channels=8, use_conv_out=False,
layers_per_block=layers_per_block, layers_mid_block=layers_mid_block,
                                   block_out_channels=(8, 32, 128, 512), norm_num_groups=4, use_checkpoint=use_checkpoint)
self.conv_in = nn.Conv3d(in_channels, 8, kernel_size=3, padding=1)
self.latent_mlp = GeoDecoder(32)
self.adaptive_conv1 = nn.Sequential(nn.Conv3d(8, 8, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv3d(8, 27, kernel_size=3, padding=1, bias=False))
self.adaptive_conv2 = nn.Sequential(nn.Conv3d(8, 8, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv3d(8, 27, kernel_size=3, padding=1, bias=False))
self.adaptive_conv3 = nn.Sequential(nn.Conv3d(8, 8, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv3d(8, 27, kernel_size=3, padding=1, bias=False))
self.mid_conv = nn.Conv3d(8, 8, kernel_size=3, padding=1)
self.conv_out = nn.Conv3d(8, out_channels, kernel_size=3, padding=1)
self.patch_size = patch_size
self.res = res
self.use_fp16 = use_fp16
self.dtype = torch.float16 if use_fp16 else torch.float32
if use_fp16:
self.convert_to_fp16()
def convert_to_fp16(self) -> None:
"""
Convert the torso of the model to float16.
"""
# self.blocks.apply(convert_module_to_f16)
self.apply(convert_module_to_f16)
    def run(self,
            reconst_x,
            feat,
            mc_threshold=0,
            ):
        """Densify the sparse SDF and latent features, refine the volume patch by
        patch, and extract one mesh per batch element with marching cubes."""
        # coords[..., 0] is the batch index; the remaining entries are xyz voxel indices.
        batch_size = int(reconst_x.coords[..., 0].max()) + 1
        sparse_sdf, sparse_index = reconst_x.feats, reconst_x.coords
        sparse_feat = feat.feats
        device = sparse_sdf.device
        dtype = sparse_sdf.dtype
        res = self.res
        # Scatter the sparse SDF values into one dense (res, res, res) grid per batch
        # element; unoccupied voxels keep the default value 1.
        sdfs = []
        for i in range(batch_size):
            idx = sparse_index[..., 0] == i
            sparse_sdf_i, sparse_index_i = sparse_sdf[idx].squeeze(-1), sparse_index[idx][..., 1:]
            sdf = torch.ones((res, res, res), device=device, dtype=dtype)
            sdf[sparse_index_i[..., 0], sparse_index_i[..., 1], sparse_index_i[..., 2]] = sparse_sdf_i
            sdfs.append(sdf.unsqueeze(0))
        sdfs = torch.stack(sdfs, dim=0)
        # Scatter the sparse latent features into a dense (B, C, res, res, res) volume.
        feats = torch.zeros((batch_size, sparse_feat.shape[-1], res, res, res),
                            device=device, dtype=dtype)
        feats[sparse_index[..., 0], :, sparse_index[..., 1], sparse_index[..., 2], sparse_index[..., 3]] = sparse_feat
        N = sdfs.shape[0]
        outputs = torch.ones([N, 1, res, res, res], dtype=dtype, device=device)
        stride = 160
        patch_size = self.patch_size
        step = 3
        sdfs = sdfs.to(dtype)
        feats = feats.to(dtype)
        patchs = []
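        # With res = 512, patch_size = 192 and stride = 160, the patch starts per
        # axis are 0, 160 and 320 (320 + 192 = 512), so adjacent patches overlap by
        # 192 - 160 = 32 voxels; the hard-coded indices used in the stitching below
        # (160, 192, 320, 352) follow from this layout.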
        for i in range(step):
            for j in range(step):
                for k in tqdm(range(step)):
                    sdf = sdfs[:, :, stride * i: stride * i + patch_size,
                               stride * j: stride * j + patch_size,
                               stride * k: stride * k + patch_size]
                    crop_feats = feats[:, :, stride * i: stride * i + patch_size,
                                       stride * j: stride * j + patch_size,
                                       stride * k: stride * k + patch_size]
                    # Embed the SDF patch, decode the latent features per voxel, and
                    # refine with the 3D UNet plus three adaptive-convolution blocks.
                    inputs = self.conv_in(sdf)
                    crop_feats = self.latent_mlp(crop_feats.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
                    inputs = torch.cat([inputs, crop_feats], dim=1)
                    mid_feat = self.unet3d1(inputs)
                    mid_feat = adaptive_block(mid_feat, self.adaptive_conv1)
                    mid_feat = self.mid_conv(mid_feat)
                    mid_feat = adaptive_block(mid_feat, self.adaptive_conv2)
                    final_feat = self.conv_out(mid_feat)
                    final_feat = adaptive_block(final_feat, self.adaptive_conv3, weights_=mid_feat)
                    output = torch.tanh(final_feat)
                    patchs.append(output)
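
        # Stitch the 27 refined patches back into the full volume one axis at a time.
        # In every 32-voxel overlap the two neighboring patches are blended with a
        # linear cross-fade (weights 0 -> 1 for the incoming patch, 1 -> 0 for the
        # outgoing one), which keeps the field continuous across patch borders.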
        weights = torch.linspace(0, 1, steps=32, device=device, dtype=dtype)
        lines = []
        for i in range(9):
            out1 = patchs[i * 3]
            out2 = patchs[i * 3 + 1]
            out3 = patchs[i * 3 + 2]
            line = torch.ones([N, 1, 192, 192, res], dtype=dtype, device=device) * 2
            line[:, :, :, :, :160] = out1[:, :, :, :, :160]
            line[:, :, :, :, 192:320] = out2[:, :, :, :, 32:160]
            line[:, :, :, :, 352:] = out3[:, :, :, :, 32:]
            line[:, :, :, :, 160:192] = out1[:, :, :, :, 160:] * (1 - weights.reshape(1, 1, 1, 1, -1)) + out2[:, :, :, :, :32] * weights.reshape(1, 1, 1, 1, -1)
            line[:, :, :, :, 320:352] = out2[:, :, :, :, 160:] * (1 - weights.reshape(1, 1, 1, 1, -1)) + out3[:, :, :, :, :32] * weights.reshape(1, 1, 1, 1, -1)
            lines.append(line)
        layers = []
        for i in range(3):
            line1 = lines[i * 3]
            line2 = lines[i * 3 + 1]
            line3 = lines[i * 3 + 2]
            layer = torch.ones([N, 1, 192, res, res], device=device, dtype=dtype) * 2
            layer[:, :, :, :160] = line1[:, :, :, :160]
            layer[:, :, :, 192:320] = line2[:, :, :, 32:160]
            layer[:, :, :, 352:] = line3[:, :, :, 32:]
            layer[:, :, :, 160:192] = line1[:, :, :, 160:] * (1 - weights.reshape(1, 1, 1, -1, 1)) + line2[:, :, :, :32] * weights.reshape(1, 1, 1, -1, 1)
            layer[:, :, :, 320:352] = line2[:, :, :, 160:] * (1 - weights.reshape(1, 1, 1, -1, 1)) + line3[:, :, :, :32] * weights.reshape(1, 1, 1, -1, 1)
            layers.append(layer)
        outputs[:, :, :160] = layers[0][:, :, :160]
        outputs[:, :, 192:320] = layers[1][:, :, 32:160]
        outputs[:, :, 352:] = layers[2][:, :, 32:]
        outputs[:, :, 160:192] = layers[0][:, :, 160:] * (1 - weights.reshape(1, 1, -1, 1, 1)) + layers[1][:, :, :32] * weights.reshape(1, 1, -1, 1, 1)
        outputs[:, :, 320:352] = layers[1][:, :, 160:] * (1 - weights.reshape(1, 1, -1, 1, 1)) + layers[2][:, :, :32] * weights.reshape(1, 1, -1, 1, 1)
# outputs = -outputs
        # Run marching cubes per batch element and rescale vertices from voxel
        # coordinates to the [-1, 1] cube.
        meshes = []
        for i in range(outputs.shape[0]):
            vertices, faces, _, _ = measure.marching_cubes(outputs[i, 0].cpu().numpy(), level=mc_threshold, method='lewiner')
            vertices = vertices / res * 2 - 1
            meshes.append(trimesh.Trimesh(vertices, faces))
        return meshes
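

# Rough usage sketch (assumptions, not confirmed by this file): ``reconst_x`` and
# ``feat`` are expected to be sparse tensors with matching ``coords`` of shape
# (N, 4) -- batch index followed by xyz voxel indices on the res^3 grid -- and
# ``feats`` of shape (N, 1) (SDF values) and (N, 32) (latent features) respectively.
#
#     refiner = Voxel_RefinerXL(use_fp16=True).cuda()
#     meshes = refiner.run(reconst_x, feat, mc_threshold=0)
#     meshes[0].export("refined.obj")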