dkatz2391 committed
Commit ea9eafd · verified · 1 Parent(s): 96595f0

Mimic FastAPI separate-file integration

Files changed (1):
  app.py (+23, -104)
app.py CHANGED
@@ -17,13 +17,9 @@ from trellis.utils import render_utils, postprocessing_utils
 import traceback
 import sys
 
-# --- FastAPI / Threading Imports ---
-import threading
-import uvicorn
+# --- Import the FastAPI integration module ---
+import trellis_fastAPI_integration
 import logging
-from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
-
 
 MAX_SEED = np.iinfo(np.int32).max
 TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
@@ -43,7 +39,11 @@ def start_session(req: gr.Request):
 
 def end_session(req: gr.Request):
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-    shutil.rmtree(user_dir)
+    try:
+        if os.path.exists(user_dir):
+            shutil.rmtree(user_dir)
+    except OSError as e:
+        logger.warning(f"Warning: Could not remove temp session dir {user_dir}: {e}")
 
 
 def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
@@ -63,7 +63,7 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
     }
 
 
-def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
+def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
     gs = Gaussian(
         aabb=state['gaussian']['aabb'],
         sh_degree=state['gaussian']['sh_degree'],
@@ -119,10 +119,16 @@ def text_to_3d(
         str: The path to the video of the 3D model.
     """
     # --- Determine user_dir robustly ---
-    session_hash_str = str(req.session_hash) if hasattr(req, 'session_hash') and req.session_hash else f"api_call_{np.random.randint(10000)}"
+    session_hash_str = str(req.session_hash) if hasattr(req, 'session_hash') and req.session_hash else f"gradio_call_{np.random.randint(10000)}"
     user_dir = os.path.join(TMP_DIR, session_hash_str)
     os.makedirs(user_dir, exist_ok=True) # Ensure directory exists
 
+    # Use the global pipeline initialized later
+    if pipeline is None:
+        logger.error("Gradio Error: Pipeline not initialized")
+        # Handle error appropriately for Gradio - maybe return None or raise gr.Error?
+        return {}, None
+
     outputs = pipeline.run(
         prompt,
         seed=seed,
@@ -143,9 +149,7 @@ def text_to_3d(
     try:
         imageio.mimsave(video_path, video, fps=15) # Now the directory should exist
     except FileNotFoundError:
-        print(f"ERROR: Directory {user_dir} still not found before mimsave!", file=sys.stderr)
-        # Decide if we should raise or return an error state?
-        # Returning a dummy path might hide the error, so let's raise for now
+        logger.error(f"ERROR: Directory {user_dir} still not found before mimsave!", exc_info=True)
         raise
     state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
     torch.cuda.empty_cache()
@@ -171,7 +175,7 @@ def extract_glb(
         str: The path to the extracted GLB file.
     """
     # --- Determine user_dir robustly ---
-    session_hash_str = str(req.session_hash) if hasattr(req, 'session_hash') and req.session_hash else f"api_call_{np.random.randint(10000)}"
+    session_hash_str = str(req.session_hash) if hasattr(req, 'session_hash') and req.session_hash else f"gradio_call_{np.random.randint(10000)}"
     user_dir = os.path.join(TMP_DIR, session_hash_str)
     os.makedirs(user_dir, exist_ok=True) # Ensure directory exists
 
@@ -181,7 +185,7 @@ def extract_glb(
     try:
         glb.export(glb_path) # Now the directory should exist
     except FileNotFoundError:
-        print(f"ERROR: Directory {user_dir} still not found before glb.export!", file=sys.stderr)
+        logger.error(f"ERROR: Directory {user_dir} still not found before glb.export!", exc_info=True)
         raise
     torch.cuda.empty_cache()
     return glb_path, glb_path
@@ -199,7 +203,7 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
         str: The path to the extracted Gaussian file.
     """
     # --- Determine user_dir robustly ---
-    session_hash_str = str(req.session_hash) if hasattr(req, 'session_hash') and req.session_hash else f"api_call_{np.random.randint(10000)}"
+    session_hash_str = str(req.session_hash) if hasattr(req, 'session_hash') and req.session_hash else f"gradio_call_{np.random.randint(10000)}"
     user_dir = os.path.join(TMP_DIR, session_hash_str)
     os.makedirs(user_dir, exist_ok=True) # Ensure directory exists
 
@@ -208,84 +212,12 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
     try:
         gs.save_ply(gaussian_path) # Now the directory should exist
     except FileNotFoundError:
-        print(f"ERROR: Directory {user_dir} still not found before gs.save_ply!", file=sys.stderr)
+        logger.error(f"ERROR: Directory {user_dir} still not found before gs.save_ply!", exc_info=True)
         raise
     torch.cuda.empty_cache()
    return gaussian_path, gaussian_path
 
 
-# --- FastAPI App Setup ---
-api_app = FastAPI()
-
-class GenerateRequest(BaseModel):
-    prompt: str
-    seed: int = 0 # Default seed
-    mesh_simplify: float = 0.95 # Default simplify factor
-    texture_size: int = 1024 # Default texture size
-    # Add other generation parameters if needed
-
-@api_app.post("/api/generate-sync")
-async def generate_sync_api(request_data: GenerateRequest):
-    global pipeline # Access the globally initialized pipeline
-    if pipeline is None:
-        logger.error("API Error: Pipeline not initialized")
-        raise HTTPException(status_code=503, detail="Pipeline not ready")
-
-    prompt = request_data.prompt
-    seed = request_data.seed
-    mesh_simplify = request_data.mesh_simplify
-    texture_size = request_data.texture_size
-    # Extract other params if added to GenerateRequest
-
-    logger.info(f"API /generate-sync received prompt: {prompt}")
-
-    try:
-        # --- Determine a unique temporary directory for this API call ---
-        api_call_hash = f"api_sync_{np.random.randint(100000)}"
-        user_dir = os.path.join(TMP_DIR, api_call_hash)
-        os.makedirs(user_dir, exist_ok=True)
-        logger.info(f"API using temp dir: {user_dir}")
-
-        # --- Stage 1: Run the text-to-3D pipeline ---
-        logger.info("API running pipeline...")
-        # Use default values for parameters not exposed in the simple API for now
-        outputs = pipeline.run(
-            prompt,
-            seed=seed,
-            formats=["gaussian", "mesh"],
-            sparse_structure_sampler_params={"steps": 25, "cfg_strength": 7.5},
-            slat_sampler_params={"steps": 25, "cfg_strength": 7.5},
-        )
-        gs = outputs['gaussian'][0]
-        mesh = outputs['mesh'][0]
-        logger.info("API pipeline finished.")
-        torch.cuda.empty_cache()
-
-        # --- Stage 2: Extract GLB ---
-        logger.info("API extracting GLB...")
-        glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
-        glb_path = os.path.join(user_dir, 'generated_sync.glb')
-        glb.export(glb_path)
-        logger.info(f"API GLB exported to: {glb_path}")
-        torch.cuda.empty_cache()
-
-        # Return the absolute path within the container
-        return {"status": "success", "glb_path": os.path.abspath(glb_path)}
-
-    except Exception as e:
-        logger.error(f"API /generate-sync error: {str(e)}", exc_info=True)
-        # Clean up temp dir on error if it exists
-        if os.path.exists(user_dir):
-            try:
-                shutil.rmtree(user_dir)
-            except Exception as cleanup_e:
-                logger.error(f"API Error cleaning up dir {user_dir}: {cleanup_e}")
-        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
-    # Note: We don't automatically clean up the user_dir on success,
-    # as the file needs to be accessible for download by the calling server.
-    # A separate cleanup mechanism might be needed eventually.
-
-
 # --- Gradio Blocks Definition ---
 with gr.Blocks(delete_cache=(600, 600)) as demo:
     gr.Markdown("""
@@ -379,29 +311,16 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
     )
 
 
-# --- Functions to Run FastAPI in Background ---
-def run_api():
-    """Run the FastAPI server."""
-    uvicorn.run(api_app, host="0.0.0.0", port=8000) # Run on port 8000
-
-def start_api_thread():
-    """Start the API server in a background thread."""
-    api_thread = threading.Thread(target=run_api, daemon=True)
-    api_thread.start()
-    logger.info("Started FastAPI server thread on port 8000")
-    return api_thread
-
-
 # Launch the Gradio app and FastAPI server
 if __name__ == "__main__":
     logger.info("Initializing Trellis Pipeline...")
-    # Make pipeline global so API endpoint can access it
+    # Make pipeline global so Gradio functions and API endpoint can access it
     pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
    pipeline.cuda()
     logger.info("Trellis Pipeline Initialized.")
 
-    # Start the background API server
-    start_api_thread()
+    # Start the background API server using the integration module
+    trellis_fastAPI_integration.start_api_thread(pipeline)
 
     # Launch the Gradio interface (blocking call)
     logger.info("Launching Gradio Demo...")