dkatz2391's picture
Modified `text_to_3d` to explicitly return the serializable `state_dict` from `pack_state` as the first return value, so that the dictionary is available via the API. Modified `extract_glb` to accept `state_dict: dict` as its first argument instead of relying on the implicit `gr.State` object type when called via the API. Kept the Gradio UI bindings (`outputs=[output_buf, ...]`, `inputs=[output_buf, ...]`) so the UI continues to work by passing the dictionary through `output_buf`.
3447081 verified
raw
history blame
13.4 kB
# Version: Add API State Fix (2025-05-04)
# Changes:
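# - Modified `text_to_3d` to explicitly return the serializable `state_dict` from `pack_state`
#   as the first return value. This is intended to make the dictionary available via the API.
#   NOTE(review): Gradio typically omits gr.State components from the public API schema;
#   confirm that API clients actually receive this dictionary.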
# - Modified `extract_glb` to accept `state_dict: dict` as its first argument instead of
# relying on the implicit `gr.State` object type when called via API.
# - Kept Gradio UI bindings (`outputs=[output_buf, ...]`, `inputs=[output_buf, ...]`)
# so the UI continues to function by passing the dictionary through output_buf.
import gradio as gr
import spaces
import os
import shutil
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
os.environ['SPCONV_ALGO'] = 'native'
from typing import *
import torch
import numpy as np
import imageio
from easydict import EasyDict as edict
from trellis.pipelines import TrellisTextTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils
import traceback
import sys
# Upper bound of the seed slider / random seed draw: the int32 maximum.
MAX_SEED = np.iinfo(np.int32).max

# Scratch space for per-session outputs (preview video, GLB/PLY exports);
# it sits in a `tmp` folder next to this script and is created at import time.
TMP_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    'tmp',
)
os.makedirs(TMP_DIR, exist_ok=True)
def start_session(req: gr.Request):
    """Ensure the scratch directory for this browser session exists.

    The directory is ``TMP_DIR/<session_hash>``; creating it is idempotent.
    """
    os.makedirs(os.path.join(TMP_DIR, str(req.session_hash)), exist_ok=True)
def end_session(req: gr.Request):
    """Remove this session's scratch directory (registered via ``demo.unload``).

    A missing directory is ignored; any OSError raised while deleting it is
    reported on stderr rather than propagated.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # Nothing to clean up if the session never created its directory.
    if not os.path.exists(session_dir):
        return
    try:
        shutil.rmtree(session_dir)
    except OSError as err:
        print(f"Error removing tmp directory {session_dir}: {err.strerror}", file=sys.stderr)
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Snapshot a generated Gaussian and mesh as plain NumPy arrays.

    Every tensor is detached and copied to CPU before conversion, so the
    returned dict holds no autograd or device references. The Gaussian's
    ``init_params`` are stored next to its parameter arrays so the object can
    be reconstructed later (see ``unpack_state``).

    Returns:
        dict: ``{'gaussian': {**init_params, '_xyz', '_features_dc', '_scaling',
        '_rotation', '_opacity'}, 'mesh': {'vertices', 'faces'}}``.
    """
    def to_host(tensor):
        # detach -> CPU -> NumPy, so the result is plain array data.
        return tensor.detach().cpu().numpy()

    gaussian_entry = {**gs.init_params}
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian_entry[attr] = to_host(getattr(gs, attr))

    mesh_entry = {
        'vertices': to_host(mesh.vertices),
        'faces': to_host(mesh.faces),
    }
    return {'gaussian': gaussian_entry, 'mesh': mesh_entry}
def unpack_state(state_dict: dict) -> Tuple[Gaussian, edict]:
    """Rebuild a Gaussian and a lightweight mesh from a ``pack_state`` dict.

    Arrays are turned back into torch tensors on CUDA when it is available,
    otherwise on CPU.

    Returns:
        Tuple[Gaussian, edict]: the reconstructed Gaussian, and an ``edict``
        exposing ``vertices`` and ``faces`` tensors.
    """
    target = 'cuda' if torch.cuda.is_available() else 'cpu'
    saved_gs = state_dict['gaussian']

    gs = Gaussian(
        aabb=saved_gs['aabb'],
        sh_degree=saved_gs['sh_degree'],
        # Spelling kept as-is: it is both the stored key and the keyword passed here.
        mininum_kernel_size=saved_gs['mininum_kernel_size'],
        scaling_bias=saved_gs['scaling_bias'],
        opacity_bias=saved_gs['opacity_bias'],
        scaling_activation=saved_gs['scaling_activation'],
    )
    for name in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, name, torch.tensor(saved_gs[name], device=target))

    saved_mesh = state_dict['mesh']
    mesh = edict(
        vertices=torch.tensor(saved_mesh['vertices'], device=target),
        faces=torch.tensor(saved_mesh['faces'], device=target),
    )
    return gs, mesh
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Pick the seed for the next generation.

    If ``randomize_seed`` is falsy, ``seed`` is returned unchanged; otherwise a
    fresh value is drawn from ``[0, MAX_SEED)``.
    """
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)
@spaces.GPU
def text_to_3d(
    prompt: str,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    req: gr.Request,
) -> Tuple[dict, str]:
    """
    Convert a text prompt to a 3D model.

    Args:
        prompt (str): The text prompt; must contain at least one non-whitespace character.
        seed (int): The random seed.
        ss_guidance_strength (float): The guidance strength for sparse structure generation.
        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
        slat_guidance_strength (float): The guidance strength for structured latent generation.
        slat_sampling_steps (int): The number of sampling steps for structured latent generation.
        req (gr.Request): The Gradio request; its ``session_hash`` selects the
            per-session output directory under ``TMP_DIR``.

    Returns:
        dict: The serializable state produced by ``pack_state`` (NumPy arrays plus
            the Gaussian's init params). The UI keeps it in ``output_buf`` and
            passes it to ``extract_glb`` / ``extract_gaussian``.
        str: The path to the MP4 video preview of the 3D model.

    Raises:
        gr.Error: If ``prompt`` is empty or whitespace-only.
    """
    # Reject empty prompts up front instead of spending GPU time on them.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a text prompt before generating.")

    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)

    # --- Generation pipeline ---
    outputs = pipeline.run(
        prompt,
        seed=seed,
        formats=["gaussian", "mesh"],  # both are needed: preview video + packed state
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps,
            "cfg_strength": slat_guidance_strength,
        },
    )
    gaussian = outputs['gaussian'][0]
    mesh = outputs['mesh'][0]

    # Pack into plain NumPy arrays so the state can be stored in gr.State
    # without holding device tensors.
    state_dict = pack_state(gaussian, mesh)

    # --- Preview video: color render and normal render concatenated along axis 1 ---
    color_frames = render_utils.render_video(gaussian, num_frames=120)['color']
    normal_frames = render_utils.render_video(mesh, num_frames=120)['normal']
    video = [
        np.concatenate([color, normal], axis=1)
        for color, normal in zip(color_frames, normal_frames)
    ]
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, video, fps=15)

    # Drop the last references to the (possibly GPU-resident) pipeline outputs
    # so empty_cache() can actually release their memory.
    del outputs, gaussian, mesh
    torch.cuda.empty_cache()

    return state_dict, video_path
@spaces.GPU(duration=90)
def extract_glb(
    state_dict: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    Extract a GLB file from the 3D model state dictionary.

    Args:
        state_dict (dict): The serializable state returned by ``text_to_3d``
            (the ``pack_state`` format).
        mesh_simplify (float): The mesh simplification factor.
        texture_size (int): The texture resolution.
        req (gr.Request): The Gradio request; its ``session_hash`` selects the
            per-session output directory under ``TMP_DIR``.

    Returns:
        str: The path to the extracted GLB file (for the Model3D component).
        str: The same path again (for the DownloadButton).

    Raises:
        gr.Error: If no generated model state is available, e.g. when this is
            called before ``text_to_3d`` has filled the session state.
    """
    # output_buf starts empty; fail with a readable message instead of a
    # TypeError from deep inside unpack_state.
    if not state_dict:
        raise gr.Error("No generated 3D model found. Please click 'Generate' first.")

    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)

    gs, mesh = unpack_state(state_dict)

    # --- Postprocessing and export ---
    glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
    glb_path = os.path.join(user_dir, 'sample.glb')
    glb.export(glb_path)
    torch.cuda.empty_cache()

    # Same path feeds both the Model3D viewer and the download button.
    return glb_path, glb_path
@spaces.GPU
def extract_gaussian(state_dict: dict, req: gr.Request) -> Tuple[str, str]:
    """
    Extract a Gaussian (.ply) file from the 3D model state dictionary.

    Args:
        state_dict (dict): The serializable state returned by ``text_to_3d``
            (the ``pack_state`` format).
        req (gr.Request): The Gradio request; its ``session_hash`` selects the
            per-session output directory under ``TMP_DIR``.

    Returns:
        str: The path to the extracted Gaussian file (for the Model3D component).
        str: The same path again (for the DownloadButton).

    Raises:
        gr.Error: If no generated model state is available, e.g. when this is
            called before ``text_to_3d`` has filled the session state.
    """
    # output_buf starts empty; fail with a readable message instead of a
    # TypeError from deep inside unpack_state.
    if not state_dict:
        raise gr.Error("No generated 3D model found. Please click 'Generate' first.")

    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)

    # Only the Gaussian is needed here; the mesh half of the state is ignored.
    gs, _ = unpack_state(state_dict)
    gaussian_path = os.path.join(user_dir, 'sample.ply')
    gs.save_ply(gaussian_path)
    torch.cuda.empty_cache()

    # Same path feeds both the Model3D viewer and the download button.
    return gaussian_path, gaussian_path
# --- Gradio UI Definition ---
# Left column: prompt, generation/extraction settings and action buttons.
# Right column: preview video, extracted model viewer and download buttons.
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
* Type a text prompt and click "Generate" to create a 3D asset.
* If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
""")

    with gr.Row():
        with gr.Column():
            text_prompt = gr.Textbox(label="Text Prompt", lines=5)

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)

            generate_btn = gr.Button("Generate")

            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)

            with gr.Row():
                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
            gr.Markdown("""
*NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
""")

        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
            with gr.Row():
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)

    # Per-session holder for the packed state dict returned by text_to_3d;
    # empty until the first successful generation.
    output_buf = gr.State()

    # --- Session lifecycle: create / remove the per-session tmp directory ---
    demo.load(start_session)
    demo.unload(end_session)

    # --- Generate: resolve seed -> run text_to_3d -> enable extraction buttons ---
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        text_to_3d,
        inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
        outputs=[output_buf, video_output],
    ).then(
        lambda: (gr.Button(interactive=True), gr.Button(interactive=True)),
        outputs=[extract_glb_btn, extract_gs_btn],
    )

    # Clearing the preview disables extraction again.
    video_output.clear(
        lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
        outputs=[extract_glb_btn, extract_gs_btn],
    )

    # --- Extract GLB: output_buf supplies the state dict ---
    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).then(
        lambda: gr.DownloadButton(interactive=True),
        outputs=[download_glb],
    )

    # --- Extract Gaussian: output_buf supplies the state dict ---
    extract_gs_btn.click(
        extract_gaussian,
        inputs=[output_buf],
        outputs=[model_output, download_gs],
    ).then(
        lambda: gr.DownloadButton(interactive=True),
        outputs=[download_gs],
    )

    # Clearing the viewer disables BOTH download buttons, so the handler must
    # return one update per output component (a single value raises an error).
    model_output.clear(
        lambda: (gr.DownloadButton(interactive=False), gr.DownloadButton(interactive=False)),
        outputs=[download_glb, download_gs],
    )
# --- Launch the Gradio app ---
if __name__ == "__main__":
    # `pipeline` is bound here as a module-level global and read by text_to_3d,
    # so it must load successfully before the UI starts serving requests.
    try:
        pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
        if torch.cuda.is_available():
            pipeline.cuda()
        else:
            print("WARNING: CUDA not available, running on CPU (will be very slow).")
        print("✅ Trellis pipeline loaded successfully.")
    except Exception as e:
        print(f"❌ Failed to load Trellis pipeline: {e}", file=sys.stderr)
        traceback.print_exc()
        # Without a pipeline every "Generate" click would fail with a NameError,
        # so stop with a non-zero status instead of serving a broken UI.
        sys.exit(1)
    demo.launch()