import gradio as gr
from gradio.events import SelectData
import spaces
from gradio_litmodel3d import LitModel3D

import json
import os
import shutil

# Set before the trellis imports below (presumably read by its sparse-conv
# backend at import time -- keep this ordering).
os.environ['SPCONV_ALGO'] = 'native'

from typing import *
import torch
import numpy as np
import imageio
from pathlib import Path
from easydict import EasyDict as edict
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils
from collections.abc import Sequence


MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)

# Gaussian tensor attributes that are round-tripped through the UI state.
_GAUSSIAN_TENSOR_KEYS = ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity')


def _user_dir(req: gr.Request) -> str:
    """Return the per-session scratch directory path (not created here)."""
    return os.path.join(TMP_DIR, str(req.session_hash))


def start_session(req: gr.Request):
    """Create the scratch directory for a newly connected session."""
    os.makedirs(_user_dir(req), exist_ok=True)


def end_session(req: gr.Request):
    """
    Delete the scratch directory of a disconnected session.

    ``ignore_errors=True`` because the directory may not exist at unload time;
    a plain ``rmtree`` would raise in that case.
    """
    shutil.rmtree(_user_dir(req), ignore_errors=True)


def preprocess_image(image: Image.Image) -> Image.Image:
    """
    Preprocess the input image.

    Args:
        image (Image.Image): The input image.

    Returns:
        Image.Image: The preprocessed image.
    """
    processed_image = pipeline.preprocess_image(image)
    return processed_image


def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """
    Preprocess a list of input images.

    Args:
        images (List[Tuple[Image.Image, str]]): The input images.

    Returns:
        List[Image.Image]: The preprocessed images.
    """
    images = [image[0] for image in images]
    processed_images = [pipeline.preprocess_image(image) for image in images]
    return processed_images


def preprocess_upload_images(file_list: List[Any]) -> List[Tuple[Image.Image, str]]:
    """
    Load uploaded files as 518x518 RGBA images.

    Args:
        file_list: Uploaded files. Each entry is either a dict with a ``"path"``
            (or ``"name"``) key -- the raw form delivered when the event runs
            with ``preprocess=False`` -- or an object whose ``.name`` is the
            file path (UploadedFile / FileData).

    Returns:
        List of ``(image, filename)`` tuples, ``filename`` being the basename
        of the uploaded path. Entries without a usable path are skipped.
    """
    images = []
    for f in file_list or []:
        if isinstance(f, dict):
            path = f.get("path") or f.get("name")
        else:  # UploadedFile / FileData
            path = f.name
        if not path:
            continue
        # `with` closes the file handle; convert() returns a loaded copy, so the
        # image remains usable after the source is closed.
        with Image.open(path) as src:
            img = src.convert("RGBA").resize((518, 518), Image.Resampling.LANCZOS)
        images.append((img, os.path.basename(path)))
    return images


def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """
    Move a generated Gaussian and mesh to CPU numpy arrays for storage in gr.State.

    Returns:
        dict with a ``'gaussian'`` entry (the Gaussian's ``init_params`` plus its
        tensor attributes) and a ``'mesh'`` entry (``vertices`` and ``faces``).
    """
    return {
        'gaussian': {
            **gs.init_params,
            **{key: getattr(gs, key).cpu().numpy() for key in _GAUSSIAN_TENSOR_KEYS},
        },
        'mesh': {
            'vertices': mesh.vertices.cpu().numpy(),
            'faces': mesh.faces.cpu().numpy(),
        },
    }


def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
    """
    Rebuild the Gaussian and mesh from a dict produced by ``pack_state``.

    Tensors are recreated on the ``'cuda'`` device.

    Returns:
        ``(gaussian, mesh)`` where ``mesh`` is an EasyDict with ``vertices`` and
        ``faces`` tensors.
    """
    params = state['gaussian']
    gs = Gaussian(
        aabb=params['aabb'],
        sh_degree=params['sh_degree'],
        mininum_kernel_size=params['mininum_kernel_size'],
        scaling_bias=params['scaling_bias'],
        opacity_bias=params['opacity_bias'],
        scaling_activation=params['scaling_activation'],
    )
    for key in _GAUSSIAN_TENSOR_KEYS:
        setattr(gs, key, torch.tensor(params[key], device='cuda'))

    mesh = edict(
        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
    )
    return gs, mesh


def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Get the random seed.

    Returns a fresh value in ``[0, MAX_SEED)`` when ``randomize_seed`` is set,
    otherwise ``seed`` unchanged.
    """
    return np.random.randint(0, MAX_SEED) if randomize_seed else seed


def normalize_multiimages(multiimages: Sequence) -> List[Tuple[Image.Image, str]]:
    """
    Turn any supported multi-image payload into preprocessed ``(image, name)`` pairs.

    The payload type is decided from its first element:
      * PIL images (e.g. from the multi-image examples) -> named ``gallery_{i}.png``;
      * ``(image, caption)`` tuples/lists (Gallery value or upload state);
      * anything else is treated as uploaded files (see ``preprocess_upload_images``).

    Every branch passes the images through ``pipeline.preprocess_image``, because
    ``image_to_3d`` later runs the pipeline with ``preprocess_image=False``.
    """
    if not multiimages:
        return []
    first = multiimages[0]
    if isinstance(first, Image.Image):
        return [
            (pipeline.preprocess_image(img), f"gallery_{i}.png")
            for i, img in enumerate(multiimages)
        ]
    if isinstance(first, (tuple, list)):
        return [(pipeline.preprocess_image(img), name) for img, name in multiimages]
    # Raw uploads: load them, then preprocess exactly like the branches above.
    return [
        (pipeline.preprocess_image(img), name)
        for img, name in preprocess_upload_images(multiimages)
    ]


@spaces.GPU
def image_to_3d(
    image: Image.Image,
    multiimages: List[Any],
    is_multiimage: str,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    req: gr.Request,
) -> Tuple[dict, str]:
    """
    Convert an image (or multiple images) into a 3D model and return its state and video.

    Args:
        image (Image.Image): The input image for single-image mode.
        multiimages (List[Any]): Multi-image payload (see ``normalize_multiimages``).
        is_multiimage (str): ``"true"``/``"false"`` (case-insensitive); a bool is
            also accepted. Multi-image mode falls back to single-image mode when
            no multi-view images are available.
        seed (int): Random seed for reproducibility.
        ss_guidance_strength (float): Sparse structure guidance strength.
        ss_sampling_steps (int): Sparse structure sampling steps.
        slat_guidance_strength (float): SLAT guidance strength.
        slat_sampling_steps (int): SLAT sampling steps.
        multiimage_algo (str): Multi-image algorithm to use.

    Returns:
        dict: The information of the generated 3D model (see ``pack_state``).
        str: The path to the video of the 3D model.

    Raises:
        gr.Error: In single-image mode when no image was provided.
    """
    user_dir = _user_dir(req)
    os.makedirs(user_dir, exist_ok=True)

    # str() so API callers may pass a real boolean as well as "true"/"false".
    use_multi = str(is_multiimage).lower() == "true"
    # Only pay for per-view preprocessing when multi-image mode is requested.
    views = normalize_multiimages(multiimages) if use_multi else []
    print("[DEBUG] is_multiimage:", use_multi, "num_imgs:", len(views))
    if use_multi and not views:
        use_multi = False

    run_kwargs = dict(
        seed=seed,
        formats=["gaussian", "mesh"],
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps,
            "cfg_strength": slat_guidance_strength,
        },
    )

    # Run pipeline depending on mode
    if use_multi:
        pil_images = [img for img, _ in views]
        if not all(isinstance(im, Image.Image) for im in pil_images):
            raise TypeError("All multi-image inputs must be PIL images.")
        outputs = pipeline.run_multi_image(pil_images, mode=multiimage_algo, **run_kwargs)
    else:
        if image is None:
            raise gr.Error("Please provide an input image.")
        outputs = pipeline.run(image, **run_kwargs)

    # Render the 3D video combining color (left) and geometry normals (right)
    color_frames = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
    normal_frames = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
    video = [np.concatenate([c, n], axis=1) for c, n in zip(color_frames, normal_frames)]

    # Save the video
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, video, fps=15)

    # Pack state for downstream use
    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
    torch.cuda.empty_cache()
    return state, video_path


@spaces.GPU(duration=90)
def extract_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    Extract a GLB file from the 3D model.

    Args:
        state (dict): The state of the generated 3D model.
        mesh_simplify (float): The mesh simplification factor.
        texture_size (int): The texture resolution.

    Returns:
        str: The path to the extracted GLB file (returned twice: once for the
        3D viewer, once for the download button).

    Raises:
        gr.Error: If no model has been generated yet.
    """
    if state is None:
        raise gr.Error("Generate a 3D model before extracting a GLB.")
    user_dir = _user_dir(req)
    os.makedirs(user_dir, exist_ok=True)
    gs, mesh = unpack_state(state)
    glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
    glb_path = os.path.join(user_dir, 'sample.glb')
    glb.export(glb_path)
    torch.cuda.empty_cache()
    return glb_path, glb_path


@spaces.GPU
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
    """
    Extract a Gaussian file from the 3D model.

    Args:
        state (dict): The state of the generated 3D model.

    Returns:
        str: The path to the extracted Gaussian ``.ply`` file (returned twice:
        once for the 3D viewer, once for the download button).

    Raises:
        gr.Error: If no model has been generated yet.
    """
    if state is None:
        raise gr.Error("Generate a 3D model before extracting a Gaussian.")
    user_dir = _user_dir(req)
    os.makedirs(user_dir, exist_ok=True)
    gs, _ = unpack_state(state)
    gaussian_path = os.path.join(user_dir, 'sample.ply')
    gs.save_ply(gaussian_path)
    torch.cuda.empty_cache()
    return gaussian_path, gaussian_path


def prepare_multi_example() -> List[Image.Image]:
    """
    Build one horizontally concatenated preview image per multi-view example.

    Files in ``assets/example_multi_image`` are grouped by the prefix before the
    first ``_``; views ``{case}_1.png`` .. ``{case}_3.png`` are each resized to a
    height of 512 and concatenated side by side.
    """
    # sorted(): a set of strings iterates in a per-process random order, which
    # made the example order change between restarts.
    multi_case = sorted({name.split('_')[0] for name in os.listdir("assets/example_multi_image")})
    images = []
    for case in multi_case:
        views = []
        for i in range(1, 4):
            with Image.open(f'assets/example_multi_image/{case}_{i}.png') as img:
                w, h = img.size
                views.append(np.array(img.resize((int(w / h * 512), 512))))
        images.append(Image.fromarray(np.concatenate(views, axis=1)))
    return images


def split_image(image: Image.Image) -> List[Image.Image]:
    """
    Split an image into multiple views.

    Views are separated by fully transparent columns: every maximal run of
    columns containing at least one pixel with alpha > 0 becomes one view, and
    each view is passed through ``preprocess_image``.
    """
    pixels = np.array(image.convert("RGBA"))  # guarantees an alpha channel
    occupied = np.any(pixels[..., 3] > 0, axis=0)  # one flag per column
    # Pad with a transparent column on each side so runs touching the left or
    # right border still produce a matching start/end pair.
    padded = np.concatenate([[False], occupied, [False]])
    starts = np.flatnonzero(~padded[:-1] & padded[1:])  # first opaque column
    ends = np.flatnonzero(padded[:-1] & ~padded[1:])    # exclusive end column
    views = [Image.fromarray(pixels[:, s:e]) for s, e in zip(starts, ends)]
    return [preprocess_image(view) for view in views]


def _example_to_multi(img: Image.Image):
    """Split a multi-view example into views for both the gallery and the state."""
    imgs = split_image(img)
    return imgs, imgs


def _files_to_gallery_and_state(file_list):
    """Load uploaded files; return images for the gallery and tuples for the state."""
    tuples = preprocess_upload_images(file_list)
    gallery_imgs = [img for img, _ in tuples]
    return gallery_imgs, tuples


# NOTE: `api_name` is not a spaces.GPU option; endpoint names are set on the
# `.click()` registrations below.
# NOTE(review): generation + GLB extraction share one GPU allocation with the
# default duration -- confirm that is long enough.
@spaces.GPU
def quick_generate_glb(
    image: Image.Image,
    multiimages: List[Tuple[Image.Image, str]],
    is_multiimage: str,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """Generate a model and directly extract it as GLB; returns (path, path)."""
    state, _ = image_to_3d(
        image=image,
        multiimages=multiimages,
        is_multiimage=is_multiimage,
        seed=seed,
        ss_guidance_strength=ss_guidance_strength,
        ss_sampling_steps=ss_sampling_steps,
        slat_guidance_strength=slat_guidance_strength,
        slat_sampling_steps=slat_sampling_steps,
        multiimage_algo=multiimage_algo,
        req=req,
    )
    return extract_glb(state, mesh_simplify=mesh_simplify, texture_size=texture_size, req=req)


@spaces.GPU
def quick_generate_gs(
    image: Image.Image,
    multiimages: List[Tuple[Image.Image, str]],
    is_multiimage: str,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    req: gr.Request,
) -> Tuple[str, str]:
    """Generate a model and directly extract its Gaussian ``.ply``; returns (path, path)."""
    state, _ = image_to_3d(
        image=image,
        multiimages=multiimages,
        is_multiimage=is_multiimage,
        seed=seed,
        ss_guidance_strength=ss_guidance_strength,
        ss_sampling_steps=ss_sampling_steps,
        slat_guidance_strength=slat_guidance_strength,
        slat_sampling_steps=slat_sampling_steps,
        multiimage_algo=multiimage_algo,
        req=req,
    )
    return extract_gaussian(state, req=req)


def test_for_api_gen(image: Image.Image) -> Image.Image:
    """
    Return the input image unchanged (API smoke-test helper).

    Args:
        image (Image.Image): The input image.

    Returns:
        Image.Image: The same image, without any processing.
    """
    return image


def update_is_multiimage(event: gr.SelectData):
    """Set the mode textbox to "true" when the tab at index 1 is selected, else "false"."""
    # value= must be a keyword: gr.update's first positional parameter is not the value.
    return gr.update(value="true" if event.index == 1 else "false")


def toggle_multiimage_visibility(choice: str):
    """Show the multi-image inputs only when ``choice`` is "true" (case-insensitive)."""
    show = str(choice).lower() == "true"
    return (
        gr.update(visible=show),  # uploaded_api_images
        gr.update(visible=show),  # multiimage_prompt (Gallery)
    )


with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)

    Thanks to the incredible work of [JeffreyXiang/TRELLIS-image-large](https://huggingface.co/JeffreyXiang/TRELLIS-image-large) for providing such a stunning implementation of the TRELLIS 3D pipeline.

    During my usage, I noticed that many users had questions regarding API access. I've spent some time refactoring the `image_to_3d` pipeline and adding two new endpoints:

    - 🔁 `quick_generate_glb`: Directly generate and download a `.glb` 3D asset.
    - 🌐 `quick_generate_gs`: Directly generate and download the Gaussian `.ply` file.
    - 🧩 Both functions are exposed as Hugging Face API endpoints and can be called via `gradio_client` or any HTTP client.

    ### How to Use:
    - Upload an image and click **"Generate"** to create a 3D asset. If the image has an alpha channel, it will be used as a mask. Otherwise, `rembg` will automatically remove the background.
    - If you're satisfied with the result, click **"Extract GLB"** or **"Extract Gaussian"** to download the 3D file.

    ### Features:
    - ✅ Single-image and experimental multi-image generation
    - ✅ `.glb` extraction with mesh simplification and texturing
    - ✅ `.ply` (Gaussian) extraction
    - ✅ Public API endpoints for one-click asset generation and download

    Feel free to try it out and send feedback — I'm happy to keep improving it based on your suggestions!
    """)

    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
                    image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
                with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
                    multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
                    gr.Markdown("""
                        Input different views of the object in separate images.

                        *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
                    """)

            is_multiimage = gr.Textbox(value="false", visible=True, interactive=False, label="is_multiimage")
            input_tabs.select(
                fn=update_is_multiimage,
                outputs=is_multiimage,
            )

            uploaded_api_images = gr.Files(file_types=["image"], label="Upload Images")
            multiimage_combined = gr.State()

            is_multiimage.change(
                fn=toggle_multiimage_visibility,
                inputs=is_multiimage,
                outputs=[uploaded_api_images, multiimage_prompt],
                trigger_mode="multiple",
            )

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")

            generate_btn = gr.Button("Generate")

            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)

            with gr.Row():
                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
            with gr.Row():
                quick_generate_glb_btn = gr.Button("Quick Generate GLB")
                quick_generate_gs_btn = gr.Button("Quick Generate Gaussian")
            gr.Markdown("""
                *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
            """)

        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)

            with gr.Row():
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)

    output_buf = gr.State()

    # Example images at the bottom of the page
    with gr.Row() as single_image_example:
        examples = gr.Examples(
            examples=[
                f'assets/example_image/{image}'
                for image in sorted(os.listdir("assets/example_image"))
            ],
            inputs=[image_prompt],
            fn=preprocess_image,
            outputs=[image_prompt],
            run_on_click=True,
            examples_per_page=64,
        )
    with gr.Row(visible=False) as multiimage_example:
        examples_multi = gr.Examples(
            examples=prepare_multi_example(),
            inputs=[image_prompt],
            fn=_example_to_multi,
            outputs=[multiimage_prompt, multiimage_combined],
            run_on_click=True,
            examples_per_page=8,
        )

    # Handlers
    demo.load(start_session)
    demo.unload(end_session)

    # gr.update replaces gr.Row.update, which is not part of the Gradio 4+ API.
    single_image_input_tab.select(
        lambda: ("false", gr.update(visible=True), gr.update(visible=False)),
        outputs=[is_multiimage, single_image_example, multiimage_example],
    )
    multiimage_input_tab.select(
        lambda: ("true", gr.update(visible=False), gr.update(visible=True)),
        outputs=[is_multiimage, single_image_example, multiimage_example],
    )

    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )
    # One value per output: the gallery keeps its images and the state mirrors them.
    multiimage_prompt.upload(
        fn=lambda imgs: (imgs, imgs),
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt, multiimage_combined],
    )
    uploaded_api_images.upload(
        fn=_files_to_gallery_and_state,
        inputs=[uploaded_api_images],
        outputs=[multiimage_prompt, multiimage_combined],
        preprocess=False,
    )

    # Registered once: a second generate_btn.click used to rerun image_to_3d on
    # every click.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[
            image_prompt,
            multiimage_combined,
            is_multiimage,
            seed,
            ss_guidance_strength,
            ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
            multiimage_algo,
        ],
        outputs=[output_buf, video_output],
    ).then(
        lambda: (gr.Button(interactive=True), gr.Button(interactive=True)),
        outputs=[extract_glb_btn, extract_gs_btn],
    )

    video_output.clear(
        lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
        outputs=[extract_glb_btn, extract_gs_btn],
    )

    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )

    extract_gs_btn.click(
        extract_gaussian,
        inputs=[output_buf],
        outputs=[model_output, download_gs],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_gs],
    )

    # Clearing the viewer invalidates both downloads.
    model_output.clear(
        lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
        outputs=[download_glb, download_gs],
    )

    quick_generate_glb_btn.click(
        fn=quick_generate_glb,
        inputs=[
            image_prompt,
            multiimage_combined,
            is_multiimage,
            seed,
            ss_guidance_strength,
            ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
            multiimage_algo,
            mesh_simplify,
            texture_size,
        ],
        outputs=[model_output, download_glb],
        api_name="quick_generate_glb",
    )

    quick_generate_gs_btn.click(
        fn=quick_generate_gs,
        inputs=[
            image_prompt,
            multiimage_combined,
            is_multiimage,
            seed,
            ss_guidance_strength,
            ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
            multiimage_algo,
        ],
        outputs=[model_output, download_gs],
        api_name="quick_generate_gs",
    )


# Launch the Gradio app
if __name__ == "__main__":
    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    if torch.cuda.is_available():
        pipeline.cuda()
        print("CUDA is available. Using GPU.")
    else:
        print("CUDA not available. Falling back to CPU.")
    try:
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))  # Preload rembg
    except Exception as exc:
        # Best-effort warm-up only; the app still starts if it fails.
        print(f"rembg preload failed: {exc}")

    print(f"CUDA Available: {torch.cuda.is_available()}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    demo.launch(debug=True)