# trellis-3d-api / app.py
import gradio as gr
import spaces
from gradio_litmodel3d import LitModel3D
import json
import os
import shutil
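# spconv kernel selection: 'native' skips the benchmarking pass that 'auto'
# runs at startup (per the TRELLIS README; 'auto' is faster for long runs).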
os.environ['SPCONV_ALGO'] = 'native'
from typing import Any, List, Literal, Tuple
import torch
import numpy as np
import imageio
from pathlib import Path
from easydict import EasyDict as edict
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils
from collections.abc import Sequence
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
def start_session(req: gr.Request):
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(user_dir, exist_ok=True)
def end_session(req: gr.Request):
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(user_dir, ignore_errors=True)  # the dir may already be gone
def preprocess_image(image: Image.Image) -> Image.Image:
"""
Preprocess the input image.
Args:
image (Image.Image): The input image.
Returns:
Image.Image: The preprocessed image.
"""
processed_image = pipeline.preprocess_image(image)
return processed_image
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
"""
Preprocess a list of input images.
Args:
images (List[Tuple[Image.Image, str]]): The input images.
Returns:
List[Image.Image]: The preprocessed images.
"""
images = [image[0] for image in images]
processed_images = [pipeline.preprocess_image(image) for image in images]
return processed_images
def preprocess_upload_images(file_list: List[Any]) -> List[Tuple[Image.Image, str]]:
images = []
for f in file_list:
if isinstance(f, dict):
path = f.get("path") or f.get("name")
filename = os.path.basename(path)
else: # UploadedFile / FileData
path = f.name
filename = os.path.basename(path)
img = Image.open(path).convert("RGBA").resize(
(518, 518), Image.Resampling.LANCZOS
)
images.append((img, filename))
return images
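# pack_state/unpack_state move everything to CPU numpy so the result can sit in
# gr.State between separate @spaces.GPU calls (generation and extraction run as
# separate GPU tasks).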
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
return {
'gaussian': {
**gs.init_params,
'_xyz': gs._xyz.cpu().numpy(),
'_features_dc': gs._features_dc.cpu().numpy(),
'_scaling': gs._scaling.cpu().numpy(),
'_rotation': gs._rotation.cpu().numpy(),
'_opacity': gs._opacity.cpu().numpy(),
},
'mesh': {
'vertices': mesh.vertices.cpu().numpy(),
'faces': mesh.faces.cpu().numpy(),
},
}
def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
gs = Gaussian(
aabb=state['gaussian']['aabb'],
sh_degree=state['gaussian']['sh_degree'],
mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
scaling_bias=state['gaussian']['scaling_bias'],
opacity_bias=state['gaussian']['opacity_bias'],
scaling_activation=state['gaussian']['scaling_activation'],
)
gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
mesh = edict(
vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
faces=torch.tensor(state['mesh']['faces'], device='cuda'),
)
return gs, mesh
def get_seed(randomize_seed: bool, seed: int) -> int:
"""
Get the random seed.
"""
return np.random.randint(0, MAX_SEED) if randomize_seed else seed
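# Multi-image input can arrive in three shapes depending on the caller: raw PIL
# images (Gallery with type="pil"), (image, name) tuples (examples / packed
# state), or file payloads from gr.Files uploads.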
def normalize_multiimages(multiimages: Sequence) -> List[Tuple[Image.Image, str]]:
if not multiimages:
return []
if isinstance(multiimages[0], Image.Image):
return [
(pipeline.preprocess_image(img), f"gallery_{i}.png")
for i, img in enumerate(multiimages)
]
if isinstance(multiimages[0], tuple):
return [
(pipeline.preprocess_image(img), name)
for img, name in multiimages
]
return preprocess_upload_images(multiimages)
@spaces.GPU
def image_to_3d(
image: Image.Image,
multiimages: List[Any],
is_multiimage: str,
seed: int,
ss_guidance_strength: float,
ss_sampling_steps: int,
slat_guidance_strength: float,
slat_sampling_steps: int,
multiimage_algo: Literal["multidiffusion", "stochastic"],
req: gr.Request,
) -> Tuple[dict, str]:
"""
Convert an image (or multiple images) into a 3D model and return its state and video.
Args:
image (Image.Image): The input image for single-image mode.
        multiimages (List[Any]): Multi-image inputs in any shape accepted by normalize_multiimages.
        is_multiimage (str): "true" to use multi-image generation; anything else selects single-image mode.
seed (int): Random seed for reproducibility.
ss_guidance_strength (float): Sparse structure guidance strength.
ss_sampling_steps (int): Sparse structure sampling steps.
slat_guidance_strength (float): SLAT guidance strength.
slat_sampling_steps (int): SLAT sampling steps.
multiimage_algo (str): Multi-image algorithm to use.
Returns:
        dict: The packed state of the generated 3D model.
        str: The path to the rendered preview video.
"""
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(user_dir, exist_ok=True)
is_multiimage = is_multiimage.lower() == "true"
multiimages = normalize_multiimages(multiimages)
print("[DEBUG] is_multiimage:", is_multiimage, "num_imgs:", len(multiimages))
if is_multiimage and len(multiimages) == 0:
is_multiimage = False
# Run pipeline depending on mode
if not is_multiimage:
outputs = pipeline.run(
image,
seed=seed,
formats=["gaussian", "mesh"],
preprocess_image=False,
sparse_structure_sampler_params={
"steps": ss_sampling_steps,
"cfg_strength": ss_guidance_strength,
},
slat_sampler_params={
"steps": slat_sampling_steps,
"cfg_strength": slat_guidance_strength,
},
)
else:
pil_images = [img for img, _ in multiimages]
assert all(isinstance(im, Image.Image) for im in pil_images)
outputs = pipeline.run_multi_image(
pil_images,
seed=seed,
formats=["gaussian", "mesh"],
preprocess_image=False,
sparse_structure_sampler_params={
"steps": ss_sampling_steps,
"cfg_strength": ss_guidance_strength,
},
slat_sampler_params={
"steps": slat_sampling_steps,
"cfg_strength": slat_guidance_strength,
},
mode=multiimage_algo,
)
# Render the 3D video combining color and geometry
video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
# Save the video
video_path = os.path.join(user_dir, 'sample.mp4')
imageio.mimsave(video_path, video, fps=15)
# Pack state for downstream use
state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
torch.cuda.empty_cache()
return state, video_path
@spaces.GPU(duration=90)
def extract_glb(
state: dict,
mesh_simplify: float,
texture_size: int,
req: gr.Request,
) -> Tuple[str, str]:
"""
Extract a GLB file from the 3D model.
Args:
state (dict): The state of the generated 3D model.
mesh_simplify (float): The mesh simplification factor.
texture_size (int): The texture resolution.
    Returns:
        Tuple[str, str]: The path to the extracted GLB file, returned twice (for
        the model viewer and the download button).
    """
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
gs, mesh = unpack_state(state)
glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
glb_path = os.path.join(user_dir, 'sample.glb')
glb.export(glb_path)
torch.cuda.empty_cache()
return glb_path, glb_path
@spaces.GPU
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
"""
Extract a Gaussian file from the 3D model.
Args:
state (dict): The state of the generated 3D model.
    Returns:
        Tuple[str, str]: The path to the extracted Gaussian .ply file, returned
        twice (for the model viewer and the download button).
    """
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
gs, _ = unpack_state(state)
gaussian_path = os.path.join(user_dir, 'sample.ply')
gs.save_ply(gaussian_path)
torch.cuda.empty_cache()
return gaussian_path, gaussian_path
def prepare_multi_example() -> List[Image.Image]:
    multi_case = sorted({i.split('_')[0] for i in os.listdir("assets/example_multi_image")})  # deterministic order
images = []
for case in multi_case:
_images = []
for i in range(1, 4):
img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
W, H = img.size
img = img.resize((int(W / H * 512), 512))
_images.append(np.array(img))
images.append(Image.fromarray(np.concatenate(_images, axis=1)))
return images
def split_image(image: Image.Image) -> List[Image.Image]:
"""
Split an image into multiple views.
"""
image = np.array(image)
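    # Views are assumed to be laid out horizontally, separated by fully
    # transparent column gaps: a column counts as occupied if any pixel in it
    # is opaque, and view boundaries are the off->on / on->off transitions.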
alpha = image[..., 3]
alpha = np.any(alpha>0, axis=0)
start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
images = []
for s, e in zip(start_pos, end_pos):
images.append(Image.fromarray(image[:, s:e+1]))
return [preprocess_image(image) for image in images]
def _example_to_multi(img: Image.Image):
imgs = split_image(img)
return imgs, imgs
def _files_to_gallery_and_state(file_list):
tuples = preprocess_upload_images(file_list)
gallery_imgs = [img for img, _ in tuples]
return gallery_imgs, tuples
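# One-shot API endpoint: runs generation and GLB extraction in a single call so
# API clients don't need to hold intermediate state between requests.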
@spaces.GPU(api_name="quick_generate_glb")
def quick_generate_glb(
image: Image.Image,
multiimages: List[Tuple[Image.Image, str]],
is_multiimage: str,
seed: int,
ss_guidance_strength: float,
ss_sampling_steps: int,
slat_guidance_strength: float,
slat_sampling_steps: int,
multiimage_algo: Literal["multidiffusion", "stochastic"],
mesh_simplify: float,
texture_size: int,
req: gr.Request,
) -> Tuple[str, str]:
state, _ = image_to_3d(
image=image,
multiimages=multiimages,
is_multiimage=is_multiimage,
seed=seed,
ss_guidance_strength=ss_guidance_strength,
ss_sampling_steps=ss_sampling_steps,
slat_guidance_strength=slat_guidance_strength,
slat_sampling_steps=slat_sampling_steps,
multiimage_algo=multiimage_algo,
req=req
)
return extract_glb(state, mesh_simplify=mesh_simplify, texture_size=texture_size, req=req)
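# Same one-shot pattern for the Gaussian .ply export.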
@spaces.GPU(api_name="quick_generate_gs")
def quick_generate_gs(
image: Image.Image,
multiimages: List[Tuple[Image.Image, str]],
is_multiimage: str,
seed: int,
ss_guidance_strength: float,
ss_sampling_steps: int,
slat_guidance_strength: float,
slat_sampling_steps: int,
multiimage_algo: Literal["multidiffusion", "stochastic"],
req: gr.Request,
) -> Tuple[str, str]:
state, _ = image_to_3d(
image=image,
multiimages=multiimages,
is_multiimage=is_multiimage,
seed=seed,
ss_guidance_strength=ss_guidance_strength,
ss_sampling_steps=ss_sampling_steps,
slat_guidance_strength=slat_guidance_strength,
slat_sampling_steps=slat_sampling_steps,
multiimage_algo=multiimage_algo,
req=req
)
return extract_gaussian(state, req=req)
def test_for_api_gen(image: Image.Image) -> Image.Image:
    """
    No-op endpoint for testing API generation.
    Args:
        image (Image.Image): The input image.
    Returns:
        Image.Image: The input image, returned unchanged.
    """
    return image
def update_is_multiimage(event: gr.SelectData):
    return gr.update(value="true" if event.index == 1 else "false")
def toggle_multiimage_visibility(choice: str):
show = choice.lower() == "true"
return (
gr.update(visible=show), # uploaded_api_images
gr.update(visible=show) # multiimage_prompt (Gallery)
)
with gr.Blocks(delete_cache=(600, 600)) as demo:
gr.Markdown("""
## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
    Thanks to [JeffreyXiang/TRELLIS-image-large](https://huggingface.co/JeffreyXiang/TRELLIS-image-large) for a stunning implementation of the TRELLIS 3D pipeline.
    While using it, I noticed that many users had questions about API access, so I spent some time refactoring the `image_to_3d` pipeline and adding two new endpoints:
- 🔁 `quick_generate_glb`: Directly generate and download a `.glb` 3D asset.
- 🌐 `quick_generate_gs`: Directly generate and download the Gaussian `.ply` file.
    - 🧩 Both functions are exposed as Hugging Face API endpoints and can be called via `gradio_client` or any HTTP client, as shown below.
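    For example, a minimal `gradio_client` call might look like the sketch below. The Space ID is a placeholder, the positional arguments mirror `quick_generate_glb`'s UI inputs, and `gr.State` inputs are handled per-session rather than passed explicitly; check this Space's *Use via API* page for the exact signature.
    ```python
    from gradio_client import Client, handle_file

    client = Client("<user>/trellis-3d-api")  # placeholder Space ID
    glb_path, _ = client.predict(
        handle_file("chair.png"),  # image (single-image mode)
        "false",                   # is_multiimage
        0,                         # seed
        7.5, 12,                   # sparse structure: guidance strength / steps
        3.0, 12,                   # SLAT: guidance strength / steps
        "stochastic",              # multiimage_algo
        0.95, 1024,                # mesh_simplify, texture_size
        api_name="/quick_generate_glb",
    )
    print(glb_path)  # local path to the generated .glb
    ```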
### How to Use:
- Upload an image and click **"Generate"** to create a 3D asset. If the image has an alpha channel, it will be used as a mask. Otherwise, `rembg` will automatically remove the background.
- If you're satisfied with the result, click **"Extract GLB"** or **"Extract Gaussian"** to download the 3D file.
### Features:
- ✅ Single-image and experimental multi-image generation
- ✅ `.glb` extraction with mesh simplification and texturing
- ✅ `.ply` (Gaussian) extraction
- ✅ Public API endpoints for one-click asset generation and download
Feel free to try it out and send feedback — I'm happy to keep improving it based on your suggestions!
""")
with gr.Row():
with gr.Column():
with gr.Tabs() as input_tabs:
with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
gr.Markdown("""
Input different views of the object in separate images.
                    *NOTE: this is an experimental algorithm without a specially trained model. It may not produce the best results for all inputs, especially views with differing poses or inconsistent details.*
""")
is_multiimage = gr.Textbox(value="false", visible=True, interactive=False, label="is_multiimage")
input_tabs.select(
fn=update_is_multiimage,
outputs=is_multiimage
)
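            # Inputs aimed at API clients: files uploaded here are mirrored into
            # multiimage_combined, the state object image_to_3d consumes.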
uploaded_api_images = gr.Files(file_types=["image"], label="Upload Images")
multiimage_combined = gr.State()
is_multiimage.change(
fn=toggle_multiimage_visibility,
inputs=is_multiimage,
outputs=[uploaded_api_images, multiimage_prompt],
trigger_mode="multiple"
)
with gr.Accordion(label="Generation Settings", open=False):
seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
gr.Markdown("Stage 1: Sparse Structure Generation")
with gr.Row():
ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
gr.Markdown("Stage 2: Structured Latent Generation")
with gr.Row():
slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
generate_btn = gr.Button("Generate")
with gr.Accordion(label="GLB Extraction Settings", open=False):
mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
with gr.Row():
extract_glb_btn = gr.Button("Extract GLB", interactive=False)
extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
with gr.Row():
quick_generate_glb_btn = gr.Button("Quick Generate GLB")
quick_generate_gs_btn = gr.Button("Quick Generate Gaussian")
gr.Markdown("""
            *NOTE: the Gaussian file can be very large (~50 MB); it may take a while to display and download.*
""")
with gr.Column():
video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
with gr.Row():
download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
output_buf = gr.State()
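    # output_buf holds the packed gaussian+mesh state produced by image_to_3d
    # and consumed by the extract_* handlers.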
# Example images at the bottom of the page
with gr.Row() as single_image_example:
examples = gr.Examples(
examples=[
f'assets/example_image/{image}'
for image in os.listdir("assets/example_image")
],
inputs=[image_prompt],
fn=preprocess_image,
outputs=[image_prompt],
run_on_click=True,
examples_per_page=64,
)
with gr.Row(visible=False) as multiimage_example:
examples_multi = gr.Examples(
examples=prepare_multi_example(),
inputs=[image_prompt],
fn=_example_to_multi,
outputs=[multiimage_prompt,
multiimage_combined],
run_on_click=True,
examples_per_page=8,
)
# Handlers
demo.load(start_session)
demo.unload(end_session)
    single_image_input_tab.select(
        lambda: ("false", gr.update(visible=True), gr.update(visible=False)),
        outputs=[is_multiimage, single_image_example, multiimage_example]
    )
    multiimage_input_tab.select(
        lambda: ("true", gr.update(visible=False), gr.update(visible=True)),
        outputs=[is_multiimage, single_image_example, multiimage_example]
    )
image_prompt.upload(
preprocess_image,
inputs=[image_prompt],
outputs=[image_prompt],
)
    multiimage_prompt.upload(
        fn=lambda imgs: (imgs, imgs),  # mirror the gallery into the combined state
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt, multiimage_combined],
    )
uploaded_api_images.upload(
fn=_files_to_gallery_and_state,
inputs=[uploaded_api_images],
outputs=[multiimage_prompt, multiimage_combined],
preprocess=False,
)
generate_btn.click(
get_seed,
inputs=[randomize_seed, seed],
outputs=[seed],
).then(
image_to_3d,
inputs=[
image_prompt, multiimage_combined, is_multiimage, seed,
ss_guidance_strength, ss_sampling_steps,
slat_guidance_strength, slat_sampling_steps, multiimage_algo
],
outputs=[output_buf, video_output],
).then(
        lambda: (gr.Button(interactive=True), gr.Button(interactive=True)),
outputs=[extract_glb_btn, extract_gs_btn],
)
video_output.clear(
        lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
outputs=[extract_glb_btn, extract_gs_btn],
)
extract_glb_btn.click(
extract_glb,
inputs=[output_buf, mesh_simplify, texture_size],
outputs=[model_output, download_glb],
).then(
lambda: gr.Button(interactive=True),
outputs=[download_glb],
)
extract_gs_btn.click(
extract_gaussian,
inputs=[output_buf],
outputs=[model_output, download_gs],
).then(
lambda: gr.Button(interactive=True),
outputs=[download_gs],
)
    model_output.clear(
        lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
        outputs=[download_glb, download_gs],
    )
quick_generate_glb_btn.click(
fn=quick_generate_glb,
inputs=[
image_prompt,
multiimage_combined,
is_multiimage,
seed,
ss_guidance_strength,
ss_sampling_steps,
slat_guidance_strength,
slat_sampling_steps,
multiimage_algo,
mesh_simplify,
texture_size,
],
outputs=[model_output, download_glb],
)
quick_generate_gs_btn.click(
fn=quick_generate_gs,
inputs=[
image_prompt,
multiimage_combined,
is_multiimage,
seed,
ss_guidance_strength,
ss_sampling_steps,
slat_guidance_strength,
slat_sampling_steps,
multiimage_algo,
],
outputs=[model_output, download_gs],
)
# Launch the Gradio app
if __name__ == "__main__":
    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    if torch.cuda.is_available():
        pipeline.cuda()
        print(f"CUDA available (version {torch.version.cuda}), GPUs: {torch.cuda.device_count()}")
    else:
        print("CUDA not available. Falling back to CPU.")
    try:
        # Run one dummy image through preprocessing so rembg's model is loaded
        # before the first real request.
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
    except Exception:
        pass
    demo.launch(debug=True)