Refactor the generate3d function in inference.py to improve readability and maintainability: clean up the RGB and coordinate-map conversion, simplify the noise-addition step used for denoising, and switch the mesh export to trimesh so the result is written as GLB with its UV texture intact.
dfab55e
import numpy as np
import torch
import time
import nvdiffrast.torch as dr
from util.utils import get_tri
import tempfile
from mesh import Mesh
import zipfile
from util.renderer import Renderer
import trimesh  # needed for GLB export


def generate3d(model, rgb, ccm, device):
    # Rebuild the renderer with the model's tetrahedral-grid and camera settings
    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type
    )

    # RGB and coordinate conversion: normalize both maps to [0, 1] and move
    # channels first; the coordinate map's channel order is reversed (2, 1, 0).
    color_tri = torch.from_numpy(rgb) / 255
    xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
    color = color_tri.permute(2, 0, 1)
    xyz = xyz_tri.permute(2, 0, 1)

    def get_imgs(color):
        # color: [C, H, W*6]. Split the horizontally tiled strip into six
        # 256-wide views, placing the sixth view first.
        color_list = [color[:, :, 256 * 5:256 * (1 + 5)]]
        for i in range(0, 5):
            color_list.append(color[:, :, 256 * i:256 * (1 + i)])
        return torch.stack(color_list, dim=0)  # [6, C, H, W]

    # Per-view colors in [1, 6, H, W, C] layout for the decoder
    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)

    color = get_imgs(color)
    xyz = get_imgs(xyz)

    # Build triplane features from the color and coordinate views
    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color, xyz], dim=1).to(device)

    model.eval()
    if model.denoising:
        # Perturb the triplane at a fixed timestep (t = 20), then denoise it with the UNet
        tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane, tnew)
    else:
        triplane_feature2 = model.unet2(triplane)

    # Decode the triplane features into an explicit mesh
    with torch.no_grad():
        data_config = {
            'resolution': [1024, 1024],
            'triview_color': triplane_color.to(device),
        }
        verts, faces = model.decode(data_config, triplane_feature2)
        data_config['verts'] = verts[0]
        data_config['faces'] = faces

    # Optional mesh cleanup; a single light remesh pass keeps this step fast
    from kiui.mesh_utils import clean_mesh
    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
    )
    data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
    data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()

    # Rasterization context for the UV export below
    glctx = dr.RasterizeGLContext()

    # Temporary output path (no suffix; file extensions are appended below)
    mesh_path_obj = tempfile.NamedTemporaryFile(suffix="", delete=False).name

    # Export OBJ with UV coordinates and the accompanying PNG texture
    with torch.no_grad():
        model.export_mesh_wt_uv(
            glctx, data_config, mesh_path_obj, "", device,
            res=(512, 512), tri_fea_2=triplane_feature2
        )

    # Convert the textured OBJ to .glb using trimesh
    mesh = trimesh.load(mesh_path_obj + ".obj", force='mesh')
    mesh_path_glb = mesh_path_obj + ".glb"
    mesh.export(mesh_path_glb, file_type='glb')
    print(f"✅ Exported GLB with UV texture: {mesh_path_glb}")
    return mesh_path_glb
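
For context, here is a minimal usage sketch of the refactored function. The image file names and the load_model helper are hypothetical placeholders; only generate3d(model, rgb, ccm, device) and its return value, the path of the exported .glb, come from the code above. It assumes rgb and ccm are uint8 arrays containing six 256-pixel-wide views tiled horizontally, as the slicing in get_imgs expects.

# Usage sketch (assumed driver code, not part of inference.py).
import numpy as np
from PIL import Image

from inference import generate3d

device = "cuda"
model = load_model(device)  # hypothetical helper that returns the trained model object

# Six-view RGB image and canonical coordinate map, each tiled as a 256 x 1536 strip (assumed layout)
rgb = np.asarray(Image.open("six_view_rgb.png").convert("RGB"))
ccm = np.asarray(Image.open("six_view_ccm.png").convert("RGB"))

glb_path = generate3d(model, rgb, ccm, device)  # returns the path of the exported .glb
print("GLB written to:", glb_path)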