# trellis_fastAPI_integration.py
# Version: 1.0.0

# a.1 Imports and Initial Setup
import os
import shutil
import threading
import uvicorn
import logging
import numpy as np
import torch

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Assuming these are available or installed correctly in the environment
from trellis.utils import postprocessing_utils
# We get the pipeline object passed in, so no direct import needed here

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app
api_app = FastAPI()

# --- Temporary Directory --- (Consistent with appTrellis.py)
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)

# b.1 Request/Response Models
class GenerateRequest(BaseModel):
    prompt: str
    seed: int = 0 # Default seed
    mesh_simplify: float = 0.95 # Default simplify factor
    texture_size: int = 1024 # Default texture size
    # Add other generation parameters if needed (e.g., guidance, steps)
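
# Example request body for the /generate-sync endpoint (illustrative values;
# field names match GenerateRequest above, omitted fields use the defaults):
#   {
#     "prompt": "a small wooden chair",
#     "seed": 42,
#     "mesh_simplify": 0.95,
#     "texture_size": 1024
#   }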

# c.1 API Endpoint for Synchronous Generation
@api_app.post("/generate-sync")
async def generate_sync_api(request_data: GenerateRequest):
    """API endpoint to synchronously generate a model and return the GLB path."""
    logger.info("API /generate-sync endpoint hit.") # Log when endpoint is called
    # Access the pipeline object stored in app state; fall back to None so a
    # missing attribute returns 503 instead of an unhandled AttributeError
    pipeline = getattr(api_app.state, "pipeline", None)
    if pipeline is None:
        logger.error("API Error: Pipeline not initialized or passed correctly")
        raise HTTPException(status_code=503, detail="Pipeline not ready")

    prompt = request_data.prompt
    seed = request_data.seed
    mesh_simplify = request_data.mesh_simplify
    texture_size = request_data.texture_size
    # Extract other params if added to GenerateRequest
    ss_sampling_steps = 25 # Example default
    ss_guidance_strength = 7.5 # Example default
    slat_sampling_steps = 25 # Example default
    slat_guidance_strength = 7.5 # Example default

    logger.info(f"API /generate-sync received prompt: {prompt}")
    user_dir = None # Define user_dir outside try for cleanup

    try:
        # --- Determine a unique temporary directory for this API call ---
        # Use a short random suffix for the API call directory name
        # (np.random.randint is not guaranteed collision-free across calls)
        api_call_hash = f"api_sync_{np.random.randint(100000)}"
        user_dir = os.path.join(TMP_DIR, api_call_hash)
        os.makedirs(user_dir, exist_ok=True)
        logger.info(f"API using temp dir: {user_dir}")

        # --- Stage 1: Run the text-to-3D pipeline ---
        logger.info("API running pipeline...")
        # Ensure pipeline is run with appropriate parameters
        outputs = pipeline.run(
            prompt,
            seed=seed,
            formats=["gaussian", "mesh"],
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
        )
        gs = outputs['gaussian'][0] # Get the Gaussian representation
        mesh = outputs['mesh'][0] # Get the Mesh representation
        logger.info("API pipeline finished.")
        torch.cuda.empty_cache()

        # --- Stage 2: Extract GLB ---
        logger.info("API extracting GLB...")
        # Use the postprocessing utility
        glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
        glb_path = os.path.join(user_dir, 'generated_sync.glb')
        glb.export(glb_path)
        logger.info(f"API GLB exported to: {glb_path}")
        torch.cuda.empty_cache()

        # Return the absolute path within the container
        # This path needs to be accessible via the /file= route from outside
        absolute_glb_path = os.path.abspath(glb_path)
        logger.info(f"API returning absolute path: {absolute_glb_path}")
        return {"status": "success", "glb_path": absolute_glb_path}

    except Exception as e:
        logger.error(f"API /generate-sync error: {str(e)}", exc_info=True)
        # Clean up temp dir on error if it exists and was created
        if user_dir and os.path.exists(user_dir):
            try:
                shutil.rmtree(user_dir)
                logger.info(f"API cleaned up failed directory: {user_dir}")
            except Exception as cleanup_e:
                logger.error(f"API Error cleaning up dir {user_dir}: {cleanup_e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
    # Note: We don't automatically clean up the user_dir on success,
    # as the file needs to be accessible for download by the calling server.
    # A separate cleanup mechanism might be needed eventually; one possible
    # approach is sketched below.
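
# Illustrative cleanup sketch (an assumption, not part of the original design):
# a best-effort helper that removes stale api_sync_* temp directories left by
# previous /generate-sync calls. The max_age_seconds default and the idea of
# calling this periodically (or at the start of each request) are assumptions.
import time  # used only by the sketch below; would normally sit with the top-level imports

def cleanup_stale_api_dirs(max_age_seconds: int = 3600):
    """Best-effort removal of api_sync_* temp directories older than max_age_seconds."""
    now = time.time()
    for name in os.listdir(TMP_DIR):
        if not name.startswith("api_sync_"):
            continue
        path = os.path.join(TMP_DIR, name)
        try:
            if os.path.isdir(path) and now - os.path.getmtime(path) > max_age_seconds:
                shutil.rmtree(path)
                logger.info(f"API cleaned up stale directory: {path}")
        except Exception as cleanup_e:
            logger.error(f"API Error cleaning up stale dir {path}: {cleanup_e}")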

# d.1 API Server Setup Functions
def run_api():
    """Run the FastAPI server."""
    logger.info("FastAPI Integration: run_api function called.")
    # Ensure pipeline is available in app state before starting
    if not hasattr(api_app.state, 'pipeline') or api_app.state.pipeline is None:
         logger.error("FastAPI Integration: Cannot start API server - Pipeline object not found in app state.")
         return
    logger.info("FastAPI Integration: Pipeline object found in state. Attempting to start Uvicorn...")
    # Run on port 8000 - ensure this doesn't conflict if Gradio also tries this port
    try:
        uvicorn.run(api_app, host="0.0.0.0", port=8000)
        logger.info("FastAPI Integration: Uvicorn server stopped.") # Logged when server exits cleanly
    except Exception as e:
        logger.error(f"FastAPI Integration: Uvicorn server failed to run or crashed: {e}", exc_info=True)

def start_api_thread(pipeline_object):
    """Start the API server in a background thread

    Args:
        pipeline_object: The initialized TrellisTextTo3DPipeline object
    """
    logger.info("FastAPI Integration: start_api_thread called.")
    # Store the passed pipeline object in the app's state
    if pipeline_object is None:
        logger.error("FastAPI Integration: start_api_thread received a None pipeline_object. Aborting thread start.")
        return None
    try:
        api_app.state.pipeline = pipeline_object
        logger.info("FastAPI Integration: Pipeline object successfully stored in app state.")
    except Exception as e:
        logger.error(f"FastAPI Integration: Failed to store pipeline object in app state: {e}", exc_info=True)
        return None

    logger.info("FastAPI Integration: Creating API thread...")
    api_thread = threading.Thread(target=run_api, daemon=True)

    logger.info("FastAPI Integration: Attempting to start API thread...")
    try:
        api_thread.start()
        logger.info("FastAPI Integration: API thread started (start() method called).")
    except Exception as e:
        logger.error(f"FastAPI Integration: Failed to start API thread: {e}", exc_info=True)
        return None # Indicate thread failed to start

    logger.info("Started Trellis FastAPI integration server thread function finished.") # Confirms this function completed
    return api_thread
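
# Illustrative usage (a sketch, not part of this module): the caller (e.g.
# appTrellis.py) is expected to construct the pipeline and pass it in. The
# import path below and the requests-based client call are assumptions made
# for illustration only.
#
#     from trellis.pipelines import TrellisTextTo3DPipeline
#     from trellis_fastAPI_integration import start_api_thread
#
#     pipeline = TrellisTextTo3DPipeline.from_pretrained(...)  # model id elided
#     api_thread = start_api_thread(pipeline)
#
# Once the server thread is up, a client can call the endpoint, for example:
#
#     import requests
#     resp = requests.post(
#         "http://localhost:8000/generate-sync",
#         json={"prompt": "a small wooden chair", "seed": 42},
#     )
#     print(resp.json()["glb_path"])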