Spaces:
Running
on
Zero
Running
on
Zero
mimic fast api speearet file inegration file
Browse files
app.py
CHANGED
@@ -17,13 +17,9 @@ from trellis.utils import render_utils, postprocessing_utils
|
|
17 |
import traceback
|
18 |
import sys
|
19 |
|
20 |
-
# --- FastAPI
|
21 |
-
import
|
22 |
-
import uvicorn
|
23 |
import logging
|
24 |
-
from fastapi import FastAPI, HTTPException
|
25 |
-
from pydantic import BaseModel
|
26 |
-
|
27 |
|
28 |
MAX_SEED = np.iinfo(np.int32).max
|
29 |
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
@@ -43,7 +39,11 @@ def start_session(req: gr.Request):
|
|
43 |
|
44 |
def end_session(req: gr.Request):
|
45 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
46 |
-
|
|
|
|
|
|
|
|
|
47 |
|
48 |
|
49 |
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
|
@@ -63,7 +63,7 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
|
|
63 |
}
|
64 |
|
65 |
|
66 |
-
def unpack_state(state: dict) -> Tuple[Gaussian, edict
|
67 |
gs = Gaussian(
|
68 |
aabb=state['gaussian']['aabb'],
|
69 |
sh_degree=state['gaussian']['sh_degree'],
|
@@ -119,10 +119,16 @@ def text_to_3d(
|
|
119 |
str: The path to the video of the 3D model.
|
120 |
"""
|
121 |
# --- Determine user_dir robustly ---
|
122 |
-
session_hash_str = str(req.session_hash) if hasattr(req, 'session_hash') and req.session_hash else f"
|
123 |
user_dir = os.path.join(TMP_DIR, session_hash_str)
|
124 |
os.makedirs(user_dir, exist_ok=True) # Ensure directory exists
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
outputs = pipeline.run(
|
127 |
prompt,
|
128 |
seed=seed,
|
@@ -143,9 +149,7 @@ def text_to_3d(
|
|
143 |
try:
|
144 |
imageio.mimsave(video_path, video, fps=15) # Now the directory should exist
|
145 |
except FileNotFoundError:
|
146 |
-
|
147 |
-
# Decide if we should raise or return an error state?
|
148 |
-
# Returning a dummy path might hide the error, so let's raise for now
|
149 |
raise
|
150 |
state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
|
151 |
torch.cuda.empty_cache()
|
@@ -171,7 +175,7 @@ def extract_glb(
|
|
171 |
str: The path to the extracted GLB file.
|
172 |
"""
|
173 |
# --- Determine user_dir robustly ---
|
174 |
-
session_hash_str = str(req.session_hash) if hasattr(req, 'session_hash') and req.session_hash else f"
|
175 |
user_dir = os.path.join(TMP_DIR, session_hash_str)
|
176 |
os.makedirs(user_dir, exist_ok=True) # Ensure directory exists
|
177 |
|
@@ -181,7 +185,7 @@ def extract_glb(
|
|
181 |
try:
|
182 |
glb.export(glb_path) # Now the directory should exist
|
183 |
except FileNotFoundError:
|
184 |
-
|
185 |
raise
|
186 |
torch.cuda.empty_cache()
|
187 |
return glb_path, glb_path
|
@@ -199,7 +203,7 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
|
|
199 |
str: The path to the extracted Gaussian file.
|
200 |
"""
|
201 |
# --- Determine user_dir robustly ---
|
202 |
-
session_hash_str = str(req.session_hash) if hasattr(req, 'session_hash') and req.session_hash else f"
|
203 |
user_dir = os.path.join(TMP_DIR, session_hash_str)
|
204 |
os.makedirs(user_dir, exist_ok=True) # Ensure directory exists
|
205 |
|
@@ -208,84 +212,12 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
|
|
208 |
try:
|
209 |
gs.save_ply(gaussian_path) # Now the directory should exist
|
210 |
except FileNotFoundError:
|
211 |
-
|
212 |
raise
|
213 |
torch.cuda.empty_cache()
|
214 |
return gaussian_path, gaussian_path
|
215 |
|
216 |
|
217 |
-
# --- FastAPI App Setup ---
|
218 |
-
api_app = FastAPI()
|
219 |
-
|
220 |
-
class GenerateRequest(BaseModel):
|
221 |
-
prompt: str
|
222 |
-
seed: int = 0 # Default seed
|
223 |
-
mesh_simplify: float = 0.95 # Default simplify factor
|
224 |
-
texture_size: int = 1024 # Default texture size
|
225 |
-
# Add other generation parameters if needed
|
226 |
-
|
227 |
-
@api_app.post("/api/generate-sync")
|
228 |
-
async def generate_sync_api(request_data: GenerateRequest):
|
229 |
-
global pipeline # Access the globally initialized pipeline
|
230 |
-
if pipeline is None:
|
231 |
-
logger.error("API Error: Pipeline not initialized")
|
232 |
-
raise HTTPException(status_code=503, detail="Pipeline not ready")
|
233 |
-
|
234 |
-
prompt = request_data.prompt
|
235 |
-
seed = request_data.seed
|
236 |
-
mesh_simplify = request_data.mesh_simplify
|
237 |
-
texture_size = request_data.texture_size
|
238 |
-
# Extract other params if added to GenerateRequest
|
239 |
-
|
240 |
-
logger.info(f"API /generate-sync received prompt: {prompt}")
|
241 |
-
|
242 |
-
try:
|
243 |
-
# --- Determine a unique temporary directory for this API call ---
|
244 |
-
api_call_hash = f"api_sync_{np.random.randint(100000)}"
|
245 |
-
user_dir = os.path.join(TMP_DIR, api_call_hash)
|
246 |
-
os.makedirs(user_dir, exist_ok=True)
|
247 |
-
logger.info(f"API using temp dir: {user_dir}")
|
248 |
-
|
249 |
-
# --- Stage 1: Run the text-to-3D pipeline ---
|
250 |
-
logger.info("API running pipeline...")
|
251 |
-
# Use default values for parameters not exposed in the simple API for now
|
252 |
-
outputs = pipeline.run(
|
253 |
-
prompt,
|
254 |
-
seed=seed,
|
255 |
-
formats=["gaussian", "mesh"],
|
256 |
-
sparse_structure_sampler_params={"steps": 25, "cfg_strength": 7.5},
|
257 |
-
slat_sampler_params={"steps": 25, "cfg_strength": 7.5},
|
258 |
-
)
|
259 |
-
gs = outputs['gaussian'][0]
|
260 |
-
mesh = outputs['mesh'][0]
|
261 |
-
logger.info("API pipeline finished.")
|
262 |
-
torch.cuda.empty_cache()
|
263 |
-
|
264 |
-
# --- Stage 2: Extract GLB ---
|
265 |
-
logger.info("API extracting GLB...")
|
266 |
-
glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
|
267 |
-
glb_path = os.path.join(user_dir, 'generated_sync.glb')
|
268 |
-
glb.export(glb_path)
|
269 |
-
logger.info(f"API GLB exported to: {glb_path}")
|
270 |
-
torch.cuda.empty_cache()
|
271 |
-
|
272 |
-
# Return the absolute path within the container
|
273 |
-
return {"status": "success", "glb_path": os.path.abspath(glb_path)}
|
274 |
-
|
275 |
-
except Exception as e:
|
276 |
-
logger.error(f"API /generate-sync error: {str(e)}", exc_info=True)
|
277 |
-
# Clean up temp dir on error if it exists
|
278 |
-
if os.path.exists(user_dir):
|
279 |
-
try:
|
280 |
-
shutil.rmtree(user_dir)
|
281 |
-
except Exception as cleanup_e:
|
282 |
-
logger.error(f"API Error cleaning up dir {user_dir}: {cleanup_e}")
|
283 |
-
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
284 |
-
# Note: We don't automatically clean up the user_dir on success,
|
285 |
-
# as the file needs to be accessible for download by the calling server.
|
286 |
-
# A separate cleanup mechanism might be needed eventually.
|
287 |
-
|
288 |
-
|
289 |
# --- Gradio Blocks Definition ---
|
290 |
with gr.Blocks(delete_cache=(600, 600)) as demo:
|
291 |
gr.Markdown("""
|
@@ -379,29 +311,16 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
|
|
379 |
)
|
380 |
|
381 |
|
382 |
-
# --- Functions to Run FastAPI in Background ---
|
383 |
-
def run_api():
|
384 |
-
"""Run the FastAPI server."""
|
385 |
-
uvicorn.run(api_app, host="0.0.0.0", port=8000) # Run on port 8000
|
386 |
-
|
387 |
-
def start_api_thread():
|
388 |
-
"""Start the API server in a background thread."""
|
389 |
-
api_thread = threading.Thread(target=run_api, daemon=True)
|
390 |
-
api_thread.start()
|
391 |
-
logger.info("Started FastAPI server thread on port 8000")
|
392 |
-
return api_thread
|
393 |
-
|
394 |
-
|
395 |
# Launch the Gradio app and FastAPI server
|
396 |
if __name__ == "__main__":
|
397 |
logger.info("Initializing Trellis Pipeline...")
|
398 |
-
# Make pipeline global so API endpoint can access it
|
399 |
pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
|
400 |
pipeline.cuda()
|
401 |
logger.info("Trellis Pipeline Initialized.")
|
402 |
|
403 |
-
# Start the background API server
|
404 |
-
start_api_thread()
|
405 |
|
406 |
# Launch the Gradio interface (blocking call)
|
407 |
logger.info("Launching Gradio Demo...")
|
|
|
17 |
import traceback
|
18 |
import sys
|
19 |
|
20 |
+
# --- Import the FastAPI integration module ---
|
21 |
+
import trellis_fastAPI_integration
|
|
|
22 |
import logging
|
|
|
|
|
|
|
23 |
|
24 |
MAX_SEED = np.iinfo(np.int32).max
|
25 |
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
|
|
39 |
|
40 |
def end_session(req: gr.Request):
|
41 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
42 |
+
try:
|
43 |
+
if os.path.exists(user_dir):
|
44 |
+
shutil.rmtree(user_dir)
|
45 |
+
except OSError as e:
|
46 |
+
logger.warning(f"Warning: Could not remove temp session dir {user_dir}: {e}")
|
47 |
|
48 |
|
49 |
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
|
|
|
63 |
}
|
64 |
|
65 |
|
66 |
+
def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
|
67 |
gs = Gaussian(
|
68 |
aabb=state['gaussian']['aabb'],
|
69 |
sh_degree=state['gaussian']['sh_degree'],
|
|
|
119 |
str: The path to the video of the 3D model.
|
120 |
"""
|
121 |
# --- Determine user_dir robustly ---
|
122 |
+
session_hash_str = str(req.session_hash) if hasattr(req, 'session_hash') and req.session_hash else f"gradio_call_{np.random.randint(10000)}"
|
123 |
user_dir = os.path.join(TMP_DIR, session_hash_str)
|
124 |
os.makedirs(user_dir, exist_ok=True) # Ensure directory exists
|
125 |
|
126 |
+
# Use the global pipeline initialized later
|
127 |
+
if pipeline is None:
|
128 |
+
logger.error("Gradio Error: Pipeline not initialized")
|
129 |
+
# Handle error appropriately for Gradio - maybe return None or raise gr.Error?
|
130 |
+
return {}, None
|
131 |
+
|
132 |
outputs = pipeline.run(
|
133 |
prompt,
|
134 |
seed=seed,
|
|
|
149 |
try:
|
150 |
imageio.mimsave(video_path, video, fps=15) # Now the directory should exist
|
151 |
except FileNotFoundError:
|
152 |
+
logger.error(f"ERROR: Directory {user_dir} still not found before mimsave!", exc_info=True)
|
|
|
|
|
153 |
raise
|
154 |
state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
|
155 |
torch.cuda.empty_cache()
|
|
|
175 |
str: The path to the extracted GLB file.
|
176 |
"""
|
177 |
# --- Determine user_dir robustly ---
|
178 |
+
session_hash_str = str(req.session_hash) if hasattr(req, 'session_hash') and req.session_hash else f"gradio_call_{np.random.randint(10000)}"
|
179 |
user_dir = os.path.join(TMP_DIR, session_hash_str)
|
180 |
os.makedirs(user_dir, exist_ok=True) # Ensure directory exists
|
181 |
|
|
|
185 |
try:
|
186 |
glb.export(glb_path) # Now the directory should exist
|
187 |
except FileNotFoundError:
|
188 |
+
logger.error(f"ERROR: Directory {user_dir} still not found before glb.export!", exc_info=True)
|
189 |
raise
|
190 |
torch.cuda.empty_cache()
|
191 |
return glb_path, glb_path
|
|
|
203 |
str: The path to the extracted Gaussian file.
|
204 |
"""
|
205 |
# --- Determine user_dir robustly ---
|
206 |
+
session_hash_str = str(req.session_hash) if hasattr(req, 'session_hash') and req.session_hash else f"gradio_call_{np.random.randint(10000)}"
|
207 |
user_dir = os.path.join(TMP_DIR, session_hash_str)
|
208 |
os.makedirs(user_dir, exist_ok=True) # Ensure directory exists
|
209 |
|
|
|
212 |
try:
|
213 |
gs.save_ply(gaussian_path) # Now the directory should exist
|
214 |
except FileNotFoundError:
|
215 |
+
logger.error(f"ERROR: Directory {user_dir} still not found before gs.save_ply!", exc_info=True)
|
216 |
raise
|
217 |
torch.cuda.empty_cache()
|
218 |
return gaussian_path, gaussian_path
|
219 |
|
220 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
# --- Gradio Blocks Definition ---
|
222 |
with gr.Blocks(delete_cache=(600, 600)) as demo:
|
223 |
gr.Markdown("""
|
|
|
311 |
)
|
312 |
|
313 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
# Launch the Gradio app and FastAPI server
|
315 |
if __name__ == "__main__":
|
316 |
logger.info("Initializing Trellis Pipeline...")
|
317 |
+
# Make pipeline global so Gradio functions and API endpoint can access it
|
318 |
pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
|
319 |
pipeline.cuda()
|
320 |
logger.info("Trellis Pipeline Initialized.")
|
321 |
|
322 |
+
# Start the background API server using the integration module
|
323 |
+
trellis_fastAPI_integration.start_api_thread(pipeline)
|
324 |
|
325 |
# Launch the Gradio interface (blocking call)
|
326 |
logger.info("Launching Gradio Demo...")
|