Refactor the `generate3d` function in inference.py to improve 3D mesh generation. This introduces new image-processing steps for the color and XYZ data, improves triplane feature extraction, and updates the mesh-export logic to use temporary-file handling. Together these changes streamline the rendering process while maintaining compatibility with GPU devices.
d72a5f9
import numpy as np | |
import torch | |
import time | |
import nvdiffrast.torch as dr | |
from util.utils import get_tri | |
import tempfile | |
from util.renderer import Renderer | |
import os | |
# Width in pixels of one view tile in the horizontally tiled multi-view strip.
_VIEW_TILE = 256
# Number of view tiles in the strip.
_NUM_VIEWS = 6


def _reorder_views(img):
    """Cut a horizontally tiled multi-view image into a stacked batch of views.

    The last axis is split into ``_NUM_VIEWS`` tiles of ``_VIEW_TILE`` pixels.
    The last tile is moved to the front and the others follow in their
    original order, so the output view order is ``[5, 0, 1, 2, 3, 4]``.

    Args:
        img: Tensor whose last axis is the tiled width; presumably (C, H, W)
            with W == 6 * 256 -- TODO confirm against callers.

    Returns:
        Tensor stacked along a new leading axis of length ``_NUM_VIEWS``.
    """
    order = [_NUM_VIEWS - 1, *range(_NUM_VIEWS - 1)]
    tiles = [img[..., _VIEW_TILE * i:_VIEW_TILE * (i + 1)] for i in order]
    return torch.stack(tiles, dim=0)


def generate3d(model, rgb, ccm, device):
    """Reconstruct a UV-textured mesh from tiled color and XYZ view images.

    Args:
        model: Reconstruction model. Attributes read here: ``tet_grid_size``,
            ``camera_angle_num``, ``input.scale``, ``geo_type``, ``denoising``,
            ``scheduler`` (only when denoising), ``unet2``, ``decode`` and
            ``export_mesh_wt_uv``. ``model.renderer`` is (re)assigned and the
            model is switched to eval mode.
        rgb: numpy array of the color views tiled horizontally, channels last;
            it is divided by 255, so presumably 8-bit -- TODO confirm.
        ccm: numpy array of the XYZ views in the same layout as ``rgb``; its
            channels are reordered (2, 1, 0) and it is divided by 255.
        device: torch device used for inference, rasterization and export.

    Returns:
        Path (str) of the exported ``.obj`` file. It is written, together with
        the other export outputs (UV/texture/MTL), into a fresh temporary
        directory that the caller is responsible for cleaning up.
    """
    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    # Scale to [0, 1] and move channels first: (H, W, C) -> (C, H, W).
    color = (torch.from_numpy(rgb) / 255).permute(2, 0, 1)
    xyz = (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)

    color_views = _reorder_views(color)
    xyz_views = _reorder_views(xyz)
    # Channels-last batch of the color views, consumed by decode/export.
    triplane_color = color_views.permute(0, 2, 3, 1).unsqueeze(0).to(device)

    color_tri = get_tri(color_views, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz_tri = get_tri(xyz_views, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color_tri, xyz_tri], dim=1).to(device)

    model.eval()
    # Pure inference: keep both branches out of autograd (the original ran the
    # non-denoising branch with gradient tracking enabled).
    with torch.no_grad():
        if model.denoising:
            t_fixed = 20
            # randint over [20, 21) always yields 20; kept as randint so the
            # RNG stream (and thus the noise sampled below) is unchanged.
            tnew = torch.randint(t_fixed, t_fixed + 1, [triplane.shape[0]],
                                 dtype=torch.long, device=triplane.device)
            noise_new = torch.randn_like(triplane) * 0.5 + 0.5
            triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
            triplane_feature2 = model.unet2(triplane, tnew)
        else:
            triplane_feature2 = model.unet2(triplane)

        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color,  # already on `device`
        }
        verts, faces = model.decode(data_config, triplane_feature2)
        data_config['verts'] = verts[0]
        data_config['faces'] = faces

    # Imported lazily; only this step needs kiui.
    from kiui.mesh_utils import clean_mesh
    # Tensor.numpy() only works on CPU tensors, so move the data to host first
    # (the original `.cuda().numpy()` raises for CUDA tensors).
    clean_verts, clean_faces = clean_mesh(
        data_config['verts'].squeeze().detach().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().detach().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1,
    )
    data_config['verts'] = torch.from_numpy(clean_verts).to(device).contiguous()
    data_config['faces'] = torch.from_numpy(clean_faces).to(device).contiguous()

    # A fresh directory keeps every export output together and avoids leaking
    # an open NamedTemporaryFile handle plus a stray empty file at the base path.
    mesh_path_base = os.path.join(tempfile.mkdtemp(), "mesh")
    with torch.no_grad():
        # Export mesh with UV, texture, and MTL
        ctx = dr.RasterizeCudaContext(device=device)
        model.export_mesh_wt_uv(ctx, data_config, mesh_path_base, ind=0, device=device,
                                res=(1024, 1024), tri_fea_2=triplane_feature2)
    return mesh_path_base + ".obj"