# DetailGen3D / app.py — Hugging Face Space by Aluren.
# Revision e470c5b ("Update app.py", verified), 6.22 kB.
import os
import random
import tempfile
from typing import Any, List
import spaces
import gradio as gr
import numpy as np
import torch
from gradio_litmodel3d import LitModel3D
from huggingface_hub import snapshot_download
from PIL import Image
import trimesh
from skimage import measure
from detailgen3d.pipelines.pipeline_detailgen3d import DetailGen3DPipeline
from detailgen3d.inference_utils import generate_dense_grid_points
# Constants
# Upper bound for user-selectable / randomly drawn seeds (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# Scratch directory (next to this file) where exported GLB results are written.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
DTYPE = torch.bfloat16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
REPO_ID = "VAST-AI/DetailGen3D"
# Usage instructions shown at the top of the demo. The button and control
# names quoted here must match the labels used in the UI ("Generate",
# "Download GLB", "CFG Scale", "Advanced Settings").
MARKDOWN = """
## Generating geometry details guided by a reference image with [DetailGen3D](https://detailgen3d.github.io/DetailGen3D/)
1. Upload a detailed image of the frontal view and a coarse model. Then click "Generate" to produce the refined result.
2. If satisfied, download the result using the "Download GLB" button.
3. Increase the "CFG Scale" (under "Advanced Settings") for better consistency with the image.
"""
# Each example row: [reference image, coarse model, seed, randomize_seed].
EXAMPLES = [
    [
        "assets/image/100.png",
        "assets/model/100.glb",
        42,
        False,
    ]
]
# Make sure the scratch directory for exported GLB files exists.
os.makedirs(TMP_DIR, exist_ok=True)
# Download the DetailGen3D weights from the Hugging Face Hub (REPO_ID) into a
# local directory, then build the pipeline from them and move it to DEVICE in
# DTYPE. All of this runs at import time, before the UI is constructed.
local_dir = "pretrained_weights/DetailGen3D"
snapshot_download(repo_id=REPO_ID, local_dir=local_dir)
pipeline = DetailGen3DPipeline.from_pretrained(local_dir).to(DEVICE, dtype=DTYPE)
def load_mesh(mesh_path, num_pc=20480, seed=None):
    """Load a mesh file and sample an oriented point cloud from its surface.

    The mesh is centred on its bounding-box centroid and uniformly scaled so
    its longest bounding-box edge is 1.9 units, i.e. it fits inside
    [-0.95, 0.95]^3. One million surface points are sampled; ``num_pc`` of
    them are kept at random, each paired with the normal of the face it was
    sampled from.

    Args:
        mesh_path: Path to a mesh file readable by ``trimesh.load``; scenes
            are forced into a single mesh.
        num_pc: Number of points to keep (at most 1,000,000).
        seed: Optional seed for choosing which sampled points are kept.
            ``None`` keeps the original non-deterministic behaviour.

    Returns:
        A float32 tensor of shape ``(1, num_pc, 6)`` (xyz followed by the face
        normal) on ``DEVICE``.

    Raises:
        ValueError: If the mesh has no faces or zero extent.
    """
    mesh = trimesh.load(mesh_path, force="mesh")
    if len(mesh.faces) == 0:
        raise ValueError(f"Mesh {mesh_path} contains no faces")
    center = mesh.bounding_box.centroid
    mesh.apply_translation(-center)
    scale = max(mesh.bounding_box.extents)
    if scale <= 0:
        raise ValueError(f"Mesh {mesh_path} has zero extent")
    mesh.apply_scale(1.9 / scale)

    # NOTE(review): sample_surface uses trimesh's own RNG and is not seeded
    # here; pass its `seed` argument if the installed trimesh supports it.
    surface, face_indices = trimesh.sample.sample_surface(mesh, 1000000)
    normal = mesh.face_normals[face_indices]
    rng = np.random.default_rng(seed)
    ind = rng.choice(surface.shape[0], num_pc, replace=False)
    surface = torch.FloatTensor(surface[ind])
    normal = torch.FloatTensor(normal[ind])
    # Use DEVICE rather than a hard-coded .cuda() so CPU-only hosts work too.
    return torch.cat([surface, normal], dim=-1).unsqueeze(0).to(DEVICE)


@torch.no_grad()
@torch.autocast(device_type=DEVICE)
def run_detailgen3d(pipeline, image, mesh, seed, num_inference_steps, guidance_scale):
    """Refine a coarse mesh with DetailGen3D, guided by a reference image.

    Args:
        pipeline: A loaded ``DetailGen3DPipeline``.
        image: Reference image passed to the pipeline.
        mesh: Path to the coarse input mesh.
        seed: Seed for torch's RNGs and the point subsampling, so identical
            inputs and seed give repeatable results.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance strength.

    Returns:
        A ``trimesh.Trimesh`` extracted with marching cubes at level 0 of the
        pipeline's predicted grid values.
    """
    # Gradio sliders may deliver floats; numpy's default_rng requires an int.
    seed = int(seed)
    # `seed` used to be accepted but ignored, so the Seed control had no
    # effect. torch.manual_seed seeds the RNG on all devices; the VAE sample
    # and pipeline noise presumably draw from it.
    torch.manual_seed(seed)
    surface = load_mesh(mesh, seed=seed)
    batch_size = 1

    # Dense query grid over a box slightly larger than the [-0.95, 0.95]^3
    # region load_mesh normalises the input into.
    box_min = np.array([-1.005, -1.005, -1.005])
    box_max = np.array([1.005, 1.005, 1.005])
    sampled_points, grid_size, bbox_size = generate_dense_grid_points(
        bbox_min=box_min, bbox_max=box_max, octree_depth=8, indexing="ij"
    )
    sampled_points = torch.FloatTensor(sampled_points).to(DEVICE, dtype=DTYPE)
    sampled_points = sampled_points.unsqueeze(0).repeat(batch_size, 1, 1)

    # Encode the coarse surface into latents, then let the pipeline add detail.
    sample = pipeline.vae.encode(surface).latent_dist.sample()
    occ = pipeline(
        image,
        latents=sample,
        sampled_points=sampled_points,
        guidance_scale=guidance_scale,
        noise_aug_level=0,
        num_inference_steps=num_inference_steps,
    ).samples[0]

    # numpy has no bfloat16, so upcast before converting.
    grid_logits = occ.view(grid_size).float().cpu().numpy()
    vertices, faces, _normals, _ = measure.marching_cubes(grid_logits, 0, method="lewiner")
    # marching_cubes returns coordinates in index space [0, n - 1] per axis.
    # Assuming the grid has n points spanning bbox_size including both
    # endpoints (linspace layout; confirm against generate_dense_grid_points),
    # the spacing is bbox_size / (n - 1). Dividing by n slightly shrank and
    # shifted the result towards box_min.
    vertices = vertices / (np.asarray(grid_size) - 1) * bbox_size + box_min
    return trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
@spaces.GPU(duration=180)
def run_refinement(
    image_path: str,
    mesh_path: str,
    seed: int,
    randomize_seed: bool = False,
    num_inference_steps: int = 50,
    guidance_scale: float = 4.0,
):
    """Gradio handler: refine the uploaded model and export it as a GLB file.

    Args:
        image_path: Filepath of the reference image (may be ``None`` when the
            user has not uploaded one).
        mesh_path: Filepath of the coarse input model (may be ``None``).
        seed: Seed for the run; replaced by a random one when
            ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]``.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance strength.

    Returns:
        ``(glb_path, glb_path, seed)``: the preview file, the download file,
        and the seed actually used, so the UI can display it.

    Raises:
        gr.Error: If an input is missing or its file does not exist.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    try:
        # Validate inputs. gr.Error surfaces its message in the UI, and the
        # explicit emptiness checks avoid os.path.exists(None) raising an
        # opaque TypeError when an input is left blank.
        if not image_path:
            raise gr.Error("Please upload a reference image.")
        if not os.path.exists(image_path):
            raise gr.Error(f"Image path {image_path} not found")
        if not mesh_path:
            raise gr.Error("Please upload an input model.")
        if not os.path.exists(mesh_path):
            raise gr.Error(f"Mesh path {mesh_path} not found")

        # convert() returns a new image, so the source file can be closed.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")
        scene = run_detailgen3d(
            pipeline,
            image,
            mesh_path,
            seed,
            num_inference_steps,
            guidance_scale,
        )

        # Save temporary result. mkstemp returns an open descriptor; close it
        # so each request does not leak one (export reopens by path).
        fd, tmp_path = tempfile.mkstemp(suffix=".glb", prefix="detailgen3d_", dir=TMP_DIR)
        os.close(fd)
        scene.export(tmp_path)
        return tmp_path, tmp_path, seed
    finally:
        torch.cuda.empty_cache()
# Demo interface: inputs and advanced settings on the left, the 3D preview and
# download button on the right, example inputs below.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            with gr.Row():
                image_input = gr.Image(
                    label="Reference Image",
                    type="filepath",
                    sources=["upload", "clipboard"],
                )
                mesh_input = gr.Model3D(
                    label="Input Model",
                    camera_position=(90, 90, 3),
                )
            with gr.Accordion("Advanced Settings", open=False):
                seed_input = gr.Slider(0, MAX_SEED, value=0, label="Seed")
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                steps_input = gr.Slider(1, 100, value=50, step=1, label="Inference Steps")
                cfg_scale = gr.Slider(0.0, 20.0, value=4.0, step=0.1, label="CFG Scale")
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            model_output = LitModel3D(
                label="Result Preview",
                height=500,
                camera_position=(90, 90, 3),
            )
            # NOTE(review): run_refinement returns a single path for this
            # button; confirm the installed Gradio's DownloadButton accepts
            # `file_count` at all.
            download_btn = gr.DownloadButton(
                "Download GLB",
                file_count="multiple",
                interactive=False,
            )

    # Examples section (cache_examples=False: clicking only fills the inputs).
    gr.Examples(
        examples=EXAMPLES,
        inputs=[image_input, mesh_input, seed_input, randomize_seed],
        outputs=[model_output, download_btn, seed_input],
        fn=run_refinement,
        cache_examples=False,
        label="Example Inputs",
    )

    # Event handling. `.success` (not `.then`) so the download button is only
    # enabled when generation actually produced a file.
    run_btn.click(
        run_refinement,
        inputs=[image_input, mesh_input, seed_input, randomize_seed, steps_input, cfg_scale],
        outputs=[model_output, download_btn, seed_input],
    ).success(
        lambda: gr.DownloadButton(interactive=True),
        outputs=[download_btn],
    )

demo.launch()