TRELLIS_TextTo3D_Try2 / trellis_fastAPI_integration.py
dkatz2391's picture
Upload trellis_fastAPI_integration.py
64a7809 verified
raw
history blame
5.97 kB
# trellis_fastAPI_integration.py
# Version: 1.0.0
# a.1 Imports and Initial Setup
import logging
import os
import shutil
import threading
import uuid

import numpy as np
import torch
import uvicorn
from easydict import EasyDict as edict  # Assuming EasyDict might be needed if state used
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field

# Assuming these are available or installed correctly in the environment
from trellis.utils import postprocessing_utils
# We get the pipeline object passed in, so no direct import needed here
# Set up logging
# NOTE(review): basicConfig configures the *root* logger as an import-time side
# effect of this module (it does nothing if root logging already has handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI app: the routes below are registered on it, and start_api_thread
# stores the pipeline object on api_app.state before the server runs.
api_app = FastAPI()
# --- Temporary Directory --- (Consistent with appTrellis.py)
# Per-call output directories are created under <directory of this file>/tmp.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
# b.1 Request/Response Models
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate-sync``.

    Numeric fields are range-checked by pydantic, so an out-of-range value is
    rejected up front with HTTP 422 rather than failing later inside GLB
    extraction (which the endpoint reports as HTTP 500).
    """

    prompt: str
    seed: int = 0  # Default seed
    # Passed to postprocessing_utils.to_glb(simplify=...). Presumably the
    # fraction of faces removed during simplification, where 0 disables it,
    # hence the [0, 1) bound. TODO confirm against trellis.utils.
    mesh_simplify: float = Field(0.95, ge=0.0, lt=1.0)  # Default simplify factor
    # Passed to postprocessing_utils.to_glb(texture_size=...); must be positive.
    texture_size: int = Field(1024, gt=0)  # Default texture size
    # Add other generation parameters if needed (e.g., guidance, steps)
# c.1 API Endpoint for Synchronous Generation
# Serializes generations started through this API.
#
# The endpoint used to call the blocking pipeline directly inside `async def`.
# That stalled uvicorn's event loop and, as a side effect, ran one request at
# a time. The heavy work now runs in a worker thread, and this lock keeps the
# one-at-a-time behaviour.
#
# It is a threading.Lock (not asyncio.Lock) because it is acquired inside the
# worker thread, not on the event loop.
#
# NOTE(review): this does not coordinate with other users of the same
# pipeline object elsewhere in the process (e.g. a UI). Confirm whether that
# matters.
_API_GENERATION_LOCK = threading.Lock()


def _generate_glb(pipeline, request_data, user_dir):
    """Run text-to-3D generation and export the result as a GLB in *user_dir*.

    Blocking; generate_sync_api calls it from a worker thread. The whole
    pipeline run plus GLB export happens while holding _API_GENERATION_LOCK.

    Args:
        pipeline: The pipeline object stored on ``api_app.state.pipeline``.
        request_data: The validated GenerateRequest for this call.
        user_dir: Existing directory that ``generated_sync.glb`` is written to.

    Returns:
        Absolute path of the exported GLB file.
    """
    # Sampler settings are not exposed through GenerateRequest yet.
    ss_sampling_steps = 25  # Example default
    ss_guidance_strength = 7.5  # Example default
    slat_sampling_steps = 25  # Example default
    slat_guidance_strength = 7.5  # Example default

    with _API_GENERATION_LOCK:
        try:
            # --- Stage 1: Run the text-to-3D pipeline ---
            logger.info("API running pipeline...")
            outputs = pipeline.run(
                request_data.prompt,
                seed=request_data.seed,
                formats=["gaussian", "mesh"],
                sparse_structure_sampler_params={
                    "steps": ss_sampling_steps,
                    "cfg_strength": ss_guidance_strength,
                },
                slat_sampler_params={
                    "steps": slat_sampling_steps,
                    "cfg_strength": slat_guidance_strength,
                },
            )
            gs = outputs['gaussian'][0]  # Get the Gaussian representation
            mesh = outputs['mesh'][0]  # Get the Mesh representation
            logger.info("API pipeline finished.")
            # Free PyTorch's cached, unused CUDA memory before GLB extraction.
            torch.cuda.empty_cache()

            # --- Stage 2: Extract GLB ---
            logger.info("API extracting GLB...")
            glb = postprocessing_utils.to_glb(
                gs,
                mesh,
                simplify=request_data.mesh_simplify,
                texture_size=request_data.texture_size,
                verbose=False,
            )
            glb_path = os.path.join(user_dir, 'generated_sync.glb')
            glb.export(glb_path)
            logger.info(f"API GLB exported to: {glb_path}")
        finally:
            # Also runs when generation fails, not only on success.
            torch.cuda.empty_cache()
    return os.path.abspath(glb_path)


@api_app.post("/generate-sync")
async def generate_sync_api(request_data: GenerateRequest):
    """API endpoint to synchronously generate a model and return the GLB path.

    "Synchronous" from the client's point of view: the response is sent only
    after generation finishes. The blocking work runs in a worker thread, so
    the event loop keeps serving other requests in the meantime.

    Returns:
        ``{"status": "success", "glb_path": <absolute path of the GLB>}``.

    Raises:
        HTTPException: 503 if no pipeline is stored on ``api_app.state``.
            500 if generation or export fails; in that case the directory
            created for this call is removed.
    """
    # getattr with a default: reading api_app.state.pipeline directly raises
    # AttributeError when it was never set, which would become a 500, not a 503.
    pipeline = getattr(api_app.state, "pipeline", None)
    if pipeline is None:
        logger.error("API Error: Pipeline not initialized or passed correctly")
        raise HTTPException(status_code=503, detail="Pipeline not ready")

    logger.info(f"API /generate-sync received prompt: {request_data.prompt}")
    user_dir = None  # Only set once this call owns the directory (for cleanup)
    try:
        # uuid4 names make collisions practically impossible. Creating the
        # directory without exist_ok means two calls can never share (and
        # clean up) the same directory.
        call_dir = os.path.join(TMP_DIR, f"api_sync_{uuid.uuid4().hex}")
        os.makedirs(call_dir)
        user_dir = call_dir
        logger.info(f"API using temp dir: {user_dir}")
        absolute_glb_path = await run_in_threadpool(
            _generate_glb, pipeline, request_data, user_dir
        )
    except Exception as e:
        logger.exception(f"API /generate-sync error: {str(e)}")
        # Clean up temp dir on error if it exists and was created by this call
        if user_dir and os.path.exists(user_dir):
            try:
                shutil.rmtree(user_dir)
                logger.info(f"API cleaned up failed directory: {user_dir}")
            except Exception as cleanup_e:
                logger.error(f"API Error cleaning up dir {user_dir}: {cleanup_e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e

    # Return the absolute path within the container.
    # This path needs to be accessible via the /file= route from outside.
    # Note: user_dir is deliberately NOT cleaned up on success, because the
    # file must stay available for download by the calling server.
    # A separate cleanup mechanism might be needed eventually.
    logger.info(f"API returning absolute path: {absolute_glb_path}")
    return {"status": "success", "glb_path": absolute_glb_path}
# d.1 API Server Setup Functions
def run_api(host="0.0.0.0", port=8000):
    """Run the FastAPI server; blocks until uvicorn exits.

    Args:
        host: Interface to bind. The default binds all interfaces, as before.
        port: TCP port to listen on. Defaults to 8000. Pass another port if
            something else in the process (e.g. Gradio) already uses it.

    Returns immediately, after logging an error, if no pipeline has been
    stored on ``api_app.state.pipeline``.
    """
    # Equivalent to the hasattr(...) + "is None" check, in one lookup.
    if getattr(api_app.state, "pipeline", None) is None:
        logger.error("Cannot start API server: Pipeline object not found in app state.")
        return
    uvicorn.run(api_app, host=host, port=port)
def start_api_thread(pipeline_object):
    """Start the API server in a background daemon thread.

    Args:
        pipeline_object: The initialized TrellisTextTo3DPipeline object. It is
            stored on ``api_app.state.pipeline`` for the request handlers.

    Returns:
        The started ``threading.Thread`` whose target is ``run_api``.
    """
    # Store the passed pipeline object in the app's state before the server
    # thread starts, so run_api and the handlers can see it.
    api_app.state.pipeline = pipeline_object
    if pipeline_object is None:
        # run_api refuses to start without a pipeline, so the thread below
        # exits immediately. Say so instead of claiming the server started.
        logger.warning("start_api_thread called without a pipeline; the API server will not start")
    api_thread = threading.Thread(target=run_api, name="trellis-fastapi", daemon=True)
    api_thread.start()
    if pipeline_object is not None:
        logger.info("Started Trellis FastAPI integration server thread on port 8000")
    return api_thread