dkatz2391 commited on
Commit
c38bb95
·
verified ·
1 Parent(s): 0f79bfb

Update trellis_fastAPI_integration.py

Browse files
Files changed (1) hide show
  1. trellis_fastAPI_integration.py +168 -143
trellis_fastAPI_integration.py CHANGED
@@ -1,143 +1,168 @@
1
- # trellis_fastAPI_integration.py
2
- # Version: 1.0.0
3
-
4
- # a.1 Imports and Initial Setup
5
- import os
6
- import shutil
7
- import threading
8
- import uvicorn
9
- import logging
10
- import numpy as np
11
- import torch
12
-
13
- from fastapi import FastAPI, HTTPException
14
- from pydantic import BaseModel
15
- from easydict import EasyDict as edict # Assuming EasyDict might be needed if state used
16
-
17
- # Assuming these are available or installed correctly in the environment
18
- from trellis.utils import postprocessing_utils
19
- # We get the pipeline object passed in, so no direct import needed here
20
-
21
- # Set up logging
22
- logging.basicConfig(level=logging.INFO)
23
- logger = logging.getLogger(__name__)
24
-
25
- # FastAPI app
26
- api_app = FastAPI()
27
-
28
- # --- Temporary Directory --- (Consistent with appTrellis.py)
29
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
30
- os.makedirs(TMP_DIR, exist_ok=True)
31
-
32
- # b.1 Request/Response Models
33
- class GenerateRequest(BaseModel):
34
- prompt: str
35
- seed: int = 0 # Default seed
36
- mesh_simplify: float = 0.95 # Default simplify factor
37
- texture_size: int = 1024 # Default texture size
38
- # Add other generation parameters if needed (e.g., guidance, steps)
39
-
40
- # c.1 API Endpoint for Synchronous Generation
41
- @api_app.post("/generate-sync")
42
- async def generate_sync_api(request_data: GenerateRequest):
43
- """API endpoint to synchronously generate a model and return the GLB path."""
44
- # Access the pipeline object stored in app state
45
- pipeline = api_app.state.pipeline
46
- if pipeline is None:
47
- logger.error("API Error: Pipeline not initialized or passed correctly")
48
- raise HTTPException(status_code=503, detail="Pipeline not ready")
49
-
50
- prompt = request_data.prompt
51
- seed = request_data.seed
52
- mesh_simplify = request_data.mesh_simplify
53
- texture_size = request_data.texture_size
54
- # Extract other params if added to GenerateRequest
55
- ss_sampling_steps = 25 # Example default
56
- ss_guidance_strength = 7.5 # Example default
57
- slat_sampling_steps = 25 # Example default
58
- slat_guidance_strength = 7.5 # Example default
59
-
60
- logger.info(f"API /generate-sync received prompt: {prompt}")
61
- user_dir = None # Define user_dir outside try for cleanup
62
-
63
- try:
64
- # --- Determine a unique temporary directory for this API call ---
65
- # Using a simpler random hash name for the API call directory
66
- api_call_hash = f"api_sync_{np.random.randint(100000)}"
67
- user_dir = os.path.join(TMP_DIR, api_call_hash)
68
- os.makedirs(user_dir, exist_ok=True)
69
- logger.info(f"API using temp dir: {user_dir}")
70
-
71
- # --- Stage 1: Run the text-to-3D pipeline ---
72
- logger.info("API running pipeline...")
73
- # Ensure pipeline is run with appropriate parameters
74
- outputs = pipeline.run(
75
- prompt,
76
- seed=seed,
77
- formats=["gaussian", "mesh"],
78
- sparse_structure_sampler_params={
79
- "steps": ss_sampling_steps,
80
- "cfg_strength": ss_guidance_strength,
81
- },
82
- slat_sampler_params={
83
- "steps": slat_sampling_steps,
84
- "cfg_strength": slat_guidance_strength,
85
- },
86
- )
87
- gs = outputs['gaussian'][0] # Get the Gaussian representation
88
- mesh = outputs['mesh'][0] # Get the Mesh representation
89
- logger.info("API pipeline finished.")
90
- torch.cuda.empty_cache()
91
-
92
- # --- Stage 2: Extract GLB ---
93
- logger.info("API extracting GLB...")
94
- # Use the postprocessing utility
95
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
96
- glb_path = os.path.join(user_dir, 'generated_sync.glb')
97
- glb.export(glb_path)
98
- logger.info(f"API GLB exported to: {glb_path}")
99
- torch.cuda.empty_cache()
100
-
101
- # Return the absolute path within the container
102
- # This path needs to be accessible via the /file= route from outside
103
- absolute_glb_path = os.path.abspath(glb_path)
104
- logger.info(f"API returning absolute path: {absolute_glb_path}")
105
- return {"status": "success", "glb_path": absolute_glb_path}
106
-
107
- except Exception as e:
108
- logger.error(f"API /generate-sync error: {str(e)}", exc_info=True)
109
- # Clean up temp dir on error if it exists and was created
110
- if user_dir and os.path.exists(user_dir):
111
- try:
112
- shutil.rmtree(user_dir)
113
- logger.info(f"API cleaned up failed directory: {user_dir}")
114
- except Exception as cleanup_e:
115
- logger.error(f"API Error cleaning up dir {user_dir}: {cleanup_e}")
116
- raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
117
- # Note: We don't automatically clean up the user_dir on success,
118
- # as the file needs to be accessible for download by the calling server.
119
- # A separate cleanup mechanism might be needed eventually.
120
-
121
- # d.1 API Server Setup Functions
122
- def run_api():
123
- """Run the FastAPI server."""
124
- # Ensure pipeline is available in app state before starting
125
- if not hasattr(api_app.state, 'pipeline') or api_app.state.pipeline is None:
126
- logger.error("Cannot start API server: Pipeline object not found in app state.")
127
- return
128
- # Run on port 8000 - ensure this doesn't conflict if Gradio also tries this port
129
- uvicorn.run(api_app, host="0.0.0.0", port=8000)
130
-
131
- def start_api_thread(pipeline_object):
132
- """Start the API server in a background thread
133
-
134
- Args:
135
- pipeline_object: The initialized TrellisTextTo3DPipeline object
136
- """
137
- # Store the passed pipeline object in the app's state
138
- api_app.state.pipeline = pipeline_object
139
-
140
- api_thread = threading.Thread(target=run_api, daemon=True)
141
- api_thread.start()
142
- logger.info("Started Trellis FastAPI integration server thread on port 8000")
143
- return api_thread
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # trellis_fastAPI_integration.py
2
+ # Version: 1.0.0
3
+
4
+ # a.1 Imports and Initial Setup
5
+ import os
6
+ import shutil
7
+ import threading
8
+ import uvicorn
9
+ import logging
10
+ import numpy as np
11
+ import torch
12
+
13
+ from fastapi import FastAPI, HTTPException
14
+ from pydantic import BaseModel
15
+ from easydict import EasyDict as edict # Assuming EasyDict might be needed if state used
16
+
17
+ # Assuming these are available or installed correctly in the environment
18
+ from trellis.utils import postprocessing_utils
19
+ # We get the pipeline object passed in, so no direct import needed here
20
+
21
+ # Set up logging
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+ # FastAPI app
26
+ api_app = FastAPI()
27
+
28
+ # --- Temporary Directory --- (Consistent with appTrellis.py)
29
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
30
+ os.makedirs(TMP_DIR, exist_ok=True)
31
+
32
+ # b.1 Request/Response Models
33
+ class GenerateRequest(BaseModel):
34
+ prompt: str
35
+ seed: int = 0 # Default seed
36
+ mesh_simplify: float = 0.95 # Default simplify factor
37
+ texture_size: int = 1024 # Default texture size
38
+ # Add other generation parameters if needed (e.g., guidance, steps)
39
+
40
+ # c.1 API Endpoint for Synchronous Generation
41
+ @api_app.post("/generate-sync")
42
+ async def generate_sync_api(request_data: GenerateRequest):
43
+ """API endpoint to synchronously generate a model and return the GLB path."""
44
+ logger.info("API /generate-sync endpoint hit.") # Log when endpoint is called
45
+ # Access the pipeline object stored in app state
46
+ pipeline = api_app.state.pipeline
47
+ if pipeline is None:
48
+ logger.error("API Error: Pipeline not initialized or passed correctly")
49
+ raise HTTPException(status_code=503, detail="Pipeline not ready")
50
+
51
+ prompt = request_data.prompt
52
+ seed = request_data.seed
53
+ mesh_simplify = request_data.mesh_simplify
54
+ texture_size = request_data.texture_size
55
+ # Extract other params if added to GenerateRequest
56
+ ss_sampling_steps = 25 # Example default
57
+ ss_guidance_strength = 7.5 # Example default
58
+ slat_sampling_steps = 25 # Example default
59
+ slat_guidance_strength = 7.5 # Example default
60
+
61
+ logger.info(f"API /generate-sync received prompt: {prompt}")
62
+ user_dir = None # Define user_dir outside try for cleanup
63
+
64
+ try:
65
+ # --- Determine a unique temporary directory for this API call ---
66
+ # Using a simpler random hash name for the API call directory
67
+ api_call_hash = f"api_sync_{np.random.randint(100000)}"
68
+ user_dir = os.path.join(TMP_DIR, api_call_hash)
69
+ os.makedirs(user_dir, exist_ok=True)
70
+ logger.info(f"API using temp dir: {user_dir}")
71
+
72
+ # --- Stage 1: Run the text-to-3D pipeline ---
73
+ logger.info("API running pipeline...")
74
+ # Ensure pipeline is run with appropriate parameters
75
+ outputs = pipeline.run(
76
+ prompt,
77
+ seed=seed,
78
+ formats=["gaussian", "mesh"],
79
+ sparse_structure_sampler_params={
80
+ "steps": ss_sampling_steps,
81
+ "cfg_strength": ss_guidance_strength,
82
+ },
83
+ slat_sampler_params={
84
+ "steps": slat_sampling_steps,
85
+ "cfg_strength": slat_guidance_strength,
86
+ },
87
+ )
88
+ gs = outputs['gaussian'][0] # Get the Gaussian representation
89
+ mesh = outputs['mesh'][0] # Get the Mesh representation
90
+ logger.info("API pipeline finished.")
91
+ torch.cuda.empty_cache()
92
+
93
+ # --- Stage 2: Extract GLB ---
94
+ logger.info("API extracting GLB...")
95
+ # Use the postprocessing utility
96
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
97
+ glb_path = os.path.join(user_dir, 'generated_sync.glb')
98
+ glb.export(glb_path)
99
+ logger.info(f"API GLB exported to: {glb_path}")
100
+ torch.cuda.empty_cache()
101
+
102
+ # Return the absolute path within the container
103
+ # This path needs to be accessible via the /file= route from outside
104
+ absolute_glb_path = os.path.abspath(glb_path)
105
+ logger.info(f"API returning absolute path: {absolute_glb_path}")
106
+ return {"status": "success", "glb_path": absolute_glb_path}
107
+
108
+ except Exception as e:
109
+ logger.error(f"API /generate-sync error: {str(e)}", exc_info=True)
110
+ # Clean up temp dir on error if it exists and was created
111
+ if user_dir and os.path.exists(user_dir):
112
+ try:
113
+ shutil.rmtree(user_dir)
114
+ logger.info(f"API cleaned up failed directory: {user_dir}")
115
+ except Exception as cleanup_e:
116
+ logger.error(f"API Error cleaning up dir {user_dir}: {cleanup_e}")
117
+ raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
118
+ # Note: We don't automatically clean up the user_dir on success,
119
+ # as the file needs to be accessible for download by the calling server.
120
+ # A separate cleanup mechanism might be needed eventually.
121
+
122
+ # d.1 API Server Setup Functions
123
+ def run_api():
124
+ """Run the FastAPI server."""
125
+ logger.info("FastAPI Integration: run_api function called.")
126
+ # Ensure pipeline is available in app state before starting
127
+ if not hasattr(api_app.state, 'pipeline') or api_app.state.pipeline is None:
128
+ logger.error("FastAPI Integration: Cannot start API server - Pipeline object not found in app state.")
129
+ return
130
+ logger.info("FastAPI Integration: Pipeline object found in state. Attempting to start Uvicorn...")
131
+ # Run on port 8000 - ensure this doesn't conflict if Gradio also tries this port
132
+ try:
133
+ uvicorn.run(api_app, host="0.0.0.0", port=8000)
134
+ logger.info("FastAPI Integration: Uvicorn server stopped.") # Logged when server exits cleanly
135
+ except Exception as e:
136
+ logger.error(f"FastAPI Integration: Uvicorn server failed to run or crashed: {e}", exc_info=True)
137
+
138
+ def start_api_thread(pipeline_object):
139
+ """Start the API server in a background thread
140
+
141
+ Args:
142
+ pipeline_object: The initialized TrellisTextTo3DPipeline object
143
+ """
144
+ logger.info("FastAPI Integration: start_api_thread called.")
145
+ # Store the passed pipeline object in the app's state
146
+ if pipeline_object is None:
147
+ logger.error("FastAPI Integration: start_api_thread received a None pipeline_object. Aborting thread start.")
148
+ return None
149
+ try:
150
+ api_app.state.pipeline = pipeline_object
151
+ logger.info("FastAPI Integration: Pipeline object successfully stored in app state.")
152
+ except Exception as e:
153
+ logger.error(f"FastAPI Integration: Failed to store pipeline object in app state: {e}", exc_info=True)
154
+ return None
155
+
156
+ logger.info("FastAPI Integration: Creating API thread...")
157
+ api_thread = threading.Thread(target=run_api, daemon=True)
158
+
159
+ logger.info("FastAPI Integration: Attempting to start API thread...")
160
+ try:
161
+ api_thread.start()
162
+ logger.info("FastAPI Integration: API thread started (start() method called).")
163
+ except Exception as e:
164
+ logger.error(f"FastAPI Integration: Failed to start API thread: {e}", exc_info=True)
165
+ return None # Indicate thread failed to start
166
+
167
+ logger.info("Started Trellis FastAPI integration server thread function finished.") # Confirms this function completed
168
+ return api_thread