Spaces:
mashroo
/
Runtime error

CRM / inference.py
YoussefAnso's picture
Refactor generate3d function in inference.py to improve readability and maintainability. Enhanced RGB and coordinate conversion, streamlined noise addition for denoising, and updated mesh export process to utilize trimesh for GLB format, ensuring proper handling of UV textures.
dfab55e
raw
history blame
3.14 kB
import numpy as np
import torch
import time
import nvdiffrast.torch as dr
from util.utils import get_tri
import tempfile
from mesh import Mesh
import zipfile
from util.renderer import Renderer
import trimesh # Needed for glb export
def generate3d(model, rgb, ccm, device):
    """Reconstruct a textured mesh from a six-view image strip and export it as GLB.

    The RGB strip and the coordinate-map strip are each cut into six
    256-pixel-wide tiles, rearranged with ``get_tri`` and concatenated along
    the channel axis into the network input. ``model.unet2`` produces the
    features, ``model.decode`` turns them into vertices/faces, the mesh is
    cleaned with ``kiui.mesh_utils.clean_mesh``, exported with UV texture via
    ``model.export_mesh_wt_uv`` and finally converted to GLB with trimesh.

    Args:
        model: Project model object. This function reads ``tet_grid_size``,
            ``camera_angle_num``, ``input.scale``, ``geo_type``, ``denoising``
            and ``scheduler`` from it, and calls ``unet2``, ``decode`` and
            ``export_mesh_wt_uv``.
        rgb: 3-D NumPy array whose last axis is the channel axis and whose
            second axis holds the six tiles side by side (at least 6 * 256
            columns). Presumably 0-255 valued, since it is divided by 255.
        ccm: 3-D NumPy array laid out like ``rgb``; its first three channels
            are read in reversed order (2, 1, 0).
        device: Torch device on which all tensors are placed.

    Returns:
        Path (str) of the exported ``.glb`` file.

    Side effects:
        Replaces ``model.renderer`` with a fresh ``Renderer``, switches the
        model to eval mode, prints the export path, and leaves the temporary
        files on disk (nothing is deleted here).
    """
    tile_width = 256  # width in pixels of one view inside the input strips
    view_count = 6

    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    # Scale to [0, 1] (assuming 0-255 inputs) and move channels first.
    color = (torch.from_numpy(rgb) / 255).permute(2, 0, 1)
    # The coordinate map's channels are consumed in reversed order.
    xyz = (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)

    def get_imgs(strip):
        """Cut a channel-first strip into six tiles, last tile first."""
        order = [view_count - 1, *range(view_count - 1)]
        tiles = [
            strip[:, :, tile_width * i:tile_width * (i + 1)] for i in order
        ]
        return torch.stack(tiles, dim=0)

    # Channel-last copy of the colour views, passed to the decoder below.
    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)

    # get_imgs is called again on purpose so get_tri receives its own tensor.
    color = get_imgs(color)
    xyz = get_imgs(xyz)
    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color, xyz], dim=1).to(device)

    model.eval()

    # Pure inference: run BOTH branches without autograd so the non-denoising
    # path does not build a gradient graph either.
    with torch.no_grad():
        if model.denoising:
            # randint(20, 21) always yields timestep 20; the call is kept as-is
            # so random-number consumption matches earlier behaviour.
            tnew = torch.randint(
                20, 21, [triplane.shape[0]],
                dtype=torch.long, device=triplane.device,
            )
            # Gaussian noise shifted/scaled to mean 0.5, std 0.5.
            noise_new = torch.randn_like(triplane) * 0.5 + 0.5
            triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
            triplane_feature2 = model.unet2(triplane, tnew)
        else:
            triplane_feature2 = model.unet2(triplane)

        data_config = {
            'resolution': [1024, 1024],
            'triview_color': triplane_color,
        }
        verts, faces = model.decode(data_config, triplane_feature2)
        data_config['verts'] = verts[0]
        data_config['faces'] = faces

    # Mesh cleanup with a single remesh iteration (kept low for speed).
    # Imported here, as in the original, so kiui is only needed at call time.
    from kiui.mesh_utils import clean_mesh
    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
    )
    # Put the cleaned mesh on the caller-selected device instead of whatever
    # the default CUDA device is, so it stays with the features.
    data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
    data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()

    # Rasterization context
    glctx = dr.RasterizeGLContext()

    # Reserve a unique base path; the handle is closed right away because only
    # the name is used. delete=False keeps the (empty) file on disk.
    with tempfile.NamedTemporaryFile(suffix="", delete=False) as tmp_handle:
        mesh_path_obj = tmp_handle.name

    # Export OBJ with UV and PNG
    with torch.no_grad():
        model.export_mesh_wt_uv(
            glctx, data_config, mesh_path_obj, "", device,
            res=(512, 512), tri_fea_2=triplane_feature2
        )

    # Convert to .glb using trimesh; the exporter above is expected to have
    # written "<base>.obj", which is the path loaded here.
    mesh = trimesh.load(mesh_path_obj + ".obj", force='mesh')
    mesh_path_glb = mesh_path_obj + ".glb"
    mesh.export(mesh_path_glb, file_type='glb')
    print(f"✅ Exported GLB with UV texture: {mesh_path_glb}")
    return mesh_path_glb