Spaces:
mashroo
/
Runtime error

CRM / inference.py
YoussefAnso's picture
Refactor generate3d function in inference.py to enhance 3D mesh generation. Introduced new image processing steps for color and XYZ data, improved triplane feature extraction, and updated mesh export logic to utilize temporary file handling. This streamlines the rendering process and maintains compatibility with GPU devices.
d72a5f9
raw
history blame
2.75 kB
import numpy as np
import torch
import time
import nvdiffrast.torch as dr
from util.utils import get_tri
import tempfile
from util.renderer import Renderer
import os
def generate3d(model, rgb, ccm, device):
    """Build a textured mesh from a six-view colour strip and its CCM strip.

    Both input strips are split into six 256-px-wide views. The views are packed
    into triplanes with ``get_tri`` and fed through ``model.unet2``. When
    ``model.denoising`` is set, noise at a fixed timestep is added first. The
    result is decoded to a mesh by ``model.decode``, cleaned/remeshed with
    ``kiui``, and exported with ``model.export_mesh_wt_uv``.

    Args:
        model: CRM-style model providing ``tet_grid_size``, ``camera_angle_num``,
            ``input.scale``, ``geo_type``, ``denoising``, ``scheduler``, ``unet2``,
            ``decode`` and ``export_mesh_wt_uv``.
        rgb: numpy colour image. Presumably (H, 6*256, 3) uint8, because it is
            divided by 255 and sliced into six 256-px columns. TODO confirm.
        ccm: numpy coordinate-map image with the same layout as ``rgb``. Its
            channel order is reversed before use.
        device: torch device for inference. It must be a CUDA device, because
            rasterization uses ``RasterizeCudaContext``.

    Returns:
        str: path of the exported ``.obj`` file (temp-file base path + ".obj").
    """
    model.renderer = Renderer(tet_grid_size=model.tet_grid_size,
                              camera_angle_num=model.camera_angle_num,
                              scale=model.input.scale, geo_type=model.geo_type)

    # HWC -> CHW, scaled to [0, 1].
    color = (torch.from_numpy(rgb) / 255).permute(2, 0, 1)
    # Reverse the CCM channel order (index 2, 1, 0) before scaling.
    xyz = (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)

    def get_imgs(img):
        """Split a (C, H, 6*256) strip into a (6, C, H, 256) stack, view 5 first."""
        order = [5, 0, 1, 2, 3, 4]
        return torch.stack([img[:, :, 256 * i:256 * (i + 1)] for i in order], dim=0)

    color_views = get_imgs(color)
    xyz_views = get_imgs(xyz)
    # (1, 6, H, 256, C) colour views, passed to decode as "triview_color".
    triplane_color = color_views.permute(0, 2, 3, 1).unsqueeze(0).to(device)
    color_tri = get_tri(color_views, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz_tri = get_tri(xyz_views, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color_tri, xyz_tri], dim=1).to(device)

    model.eval()
    # Pure inference: run every network call (both unet2 branches included)
    # without recording an autograd graph.
    with torch.no_grad():
        if model.denoising:
            t_fixed = 20
            # randint over [20, 21) always yields 20. It is kept as randint (rather
            # than torch.full) so RNG consumption for seeded runs stays unchanged.
            tnew = torch.randint(t_fixed, t_fixed + 1, [triplane.shape[0]],
                                 dtype=torch.long, device=triplane.device)
            noise_new = torch.randn_like(triplane) * 0.5 + 0.5
            triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
            triplane_feature2 = model.unet2(triplane, tnew)
        else:
            triplane_feature2 = model.unet2(triplane)

        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color.to(device),
        }
        verts, faces = model.decode(data_config, triplane_feature2)
        data_config['verts'] = verts[0]
        data_config['faces'] = faces

        from kiui.mesh_utils import clean_mesh
        # Tensor.numpy() only works on CPU tensors, so move to host first.
        # Calling .cuda().numpy() raises TypeError.
        verts, faces = clean_mesh(
            data_config['verts'].squeeze().detach().cpu().numpy().astype(np.float32),
            data_config['faces'].squeeze().detach().cpu().numpy().astype(np.int32),
            repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
        )
        data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
        data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()

        # Reserve a unique base path and close the handle immediately so it is not
        # leaked. The empty placeholder file stays on disk (delete=False), as before.
        with tempfile.NamedTemporaryFile(suffix="", delete=False) as tmp:
            mesh_path_base = tmp.name
        # Export mesh with UV, texture, and MTL
        ctx = dr.RasterizeCudaContext(device=device)
        model.export_mesh_wt_uv(ctx, data_config, mesh_path_base, ind=0, device=device,
                                res=(1024, 1024), tri_fea_2=triplane_feature2)
    return mesh_path_base + ".obj"