import os
import random
import tempfile

import gradio as gr
import numpy as np
import spaces
import torch
import trimesh
from gradio_litmodel3d import LitModel3D
from huggingface_hub import snapshot_download
from PIL import Image
from skimage import measure

from detailgen3d.inference_utils import generate_dense_grid_points
from detailgen3d.pipelines.pipeline_detailgen3d import DetailGen3DPipeline
# Constants
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
DTYPE = torch.bfloat16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
REPO_ID = "VAST-AI/DetailGen3D"

MARKDOWN = """
## Generating geometry details guided by a reference image with [DetailGen3D](https://detailgen3d.github.io/DetailGen3D/)
1. Upload a detailed frontal-view image and a coarse model, then click "Generate" to produce the refined result.
2. If you are satisfied with the result, download it with the "Download GLB" button.
3. Increase the CFG scale for stronger image consistency.
"""
EXAMPLES = [
    [
        "assets/image/100.png",
        "assets/model/100.glb",
        42,
        False,
    ],
]

os.makedirs(TMP_DIR, exist_ok=True)
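
# Fetch the pretrained weights from the Hugging Face Hub and load the
# pipeline once at startup, on the available device in bfloat16.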
local_dir = "pretrained_weights/DetailGen3D"
snapshot_download(repo_id=REPO_ID, local_dir=local_dir)
pipeline = DetailGen3DPipeline.from_pretrained(local_dir).to(DEVICE, dtype=DTYPE)
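

# Load a mesh, normalize it to fit the query volume, and sample a surface
# point cloud with normals that conditions the shape VAE.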
def load_mesh(mesh_path, num_pc=20480):
    mesh = trimesh.load(mesh_path, force="mesh")

    # Center the mesh and scale its longest extent to 1.9 so it sits inside
    # the [-1, 1]^3 query volume with a small margin.
    center = mesh.bounding_box.centroid
    mesh.apply_translation(-center)
    scale = max(mesh.bounding_box.extents)
    mesh.apply_scale(1.9 / scale)

    # Oversample the surface, then keep a random subset of num_pc points.
    # Note: this subsampling is not controlled by the UI seed.
    surface, face_indices = trimesh.sample.sample_surface(mesh, 1000000)
    normal = mesh.face_normals[face_indices]
    rng = np.random.default_rng()
    ind = rng.choice(surface.shape[0], num_pc, replace=False)
    surface = torch.FloatTensor(surface[ind])
    normal = torch.FloatTensor(normal[ind])

    # (1, num_pc, 6) tensor of positions and normals on the pipeline's device/dtype.
    return torch.cat([surface, normal], dim=-1).unsqueeze(0).to(DEVICE, dtype=DTYPE)
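

# One refinement pass: encode the coarse surface into a shape latent, denoise
# it conditioned on the reference image, query occupancy on a dense grid, and
# extract the refined mesh with marching cubes.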
def run_detailgen3d(pipeline, image, mesh, seed, num_inference_steps, guidance_scale):
    surface = load_mesh(mesh)
    batch_size = 1

    # Grid generation: dense query points covering the padded unit cube
    box_min = np.array([-1.005, -1.005, -1.005])
    box_max = np.array([1.005, 1.005, 1.005])
    sampled_points, grid_size, bbox_size = generate_dense_grid_points(
        bbox_min=box_min, bbox_max=box_max, octree_depth=8, indexing="ij"
    )
    sampled_points = torch.FloatTensor(sampled_points).to(DEVICE, dtype=DTYPE)
    sampled_points = sampled_points.unsqueeze(0).repeat(batch_size, 1, 1)
    # Pipeline execution: seed the RNG so diffusion sampling is reproducible
    torch.manual_seed(seed)
    sample = pipeline.vae.encode(surface).latent_dist.sample()
    occ = pipeline(
        image,
        latents=sample,
        sampled_points=sampled_points,
        guidance_scale=guidance_scale,
        noise_aug_level=0,
        num_inference_steps=num_inference_steps,
    ).samples[0]
    # Mesh processing: run marching cubes on the occupancy grid and map the
    # vertices from grid indices back to world coordinates. Cast to float32
    # first, since NumPy cannot convert bfloat16 tensors.
    grid_logits = occ.view(grid_size).float().cpu().numpy()
    vertices, faces, normals, _ = measure.marching_cubes(grid_logits, 0, method="lewiner")
    vertices = vertices / grid_size * bbox_size + box_min
    return trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
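

# Gradio callback: validate the inputs, run the refinement, and export the
# result to a temporary GLB file for preview and download.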
@spaces.GPU
def run_refinement(
    image_path: str,
    mesh_path: str,
    seed: int,
    randomize_seed: bool = False,
    num_inference_steps: int = 50,
    guidance_scale: float = 4.0,
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    try:
        # Validate inputs
        if not os.path.exists(image_path):
            raise ValueError(f"Image path {image_path} not found")
        if not os.path.exists(mesh_path):
            raise ValueError(f"Mesh path {mesh_path} not found")

        image = Image.open(image_path).convert("RGB")
        scene = run_detailgen3d(
            pipeline,
            image,
            mesh_path,
            seed,
            num_inference_steps,
            guidance_scale,
        )

        # Save temporary result; close the descriptor mkstemp leaves open
        fd, tmp_path = tempfile.mkstemp(suffix=".glb", prefix="detailgen3d_", dir=TMP_DIR)
        os.close(fd)
        scene.export(tmp_path)
        return tmp_path, tmp_path, seed
    finally:
        torch.cuda.empty_cache()


# Demo interface
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                image_input = gr.Image(
                    label="Reference Image",
                    type="filepath",
                    sources=["upload", "clipboard"],
                )
                mesh_input = gr.Model3D(
                    label="Input Model",
                    camera_position=(90, 90, 3),
                )

            with gr.Accordion("Advanced Settings", open=False):
                seed_input = gr.Slider(0, MAX_SEED, value=0, label="Seed")
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                steps_input = gr.Slider(1, 100, value=50, step=1, label="Inference Steps")
                cfg_scale = gr.Slider(0.0, 20.0, value=4.0, step=0.1, label="CFG Scale")

            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            model_output = LitModel3D(
                label="Result Preview",
                height=500,
                camera_position=(90, 90, 3),
            )
            download_btn = gr.DownloadButton(
                "Download GLB",
                interactive=False,
            )
    # Examples section
    gr.Examples(
        examples=EXAMPLES,
        inputs=[image_input, mesh_input, seed_input, randomize_seed],
        outputs=[model_output, download_btn, seed_input],
        fn=run_refinement,
        cache_examples=False,
        label="Example Inputs",
    )
# Event handling | |
run_btn.click( | |
run_refinement, | |
inputs=[image_input, mesh_input, seed_input, randomize_seed, steps_input, cfg_scale], | |
outputs=[model_output, download_btn, seed_input] | |
).then( | |
lambda: gr.DownloadButton(interactive=True), | |
outputs=[download_btn] | |
) | |

demo.launch()