dkatz2391 commited on
Commit
bcd68a1
·
verified ·
1 Parent(s): 7b1f05d

revert back to without fast api

Browse files
Files changed (1) hide show
  1. app.py +12 -97
app.py CHANGED
@@ -17,22 +17,11 @@ from trellis.utils import render_utils, postprocessing_utils
17
  import traceback
18
  import sys
19
 
20
- # --- Import the FastAPI integration module ---
21
- import trellis_fastAPI_integration
22
- import logging
23
 
24
  MAX_SEED = np.iinfo(np.int32).max
25
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
26
  os.makedirs(TMP_DIR, exist_ok=True)
27
 
28
- # --- Global Pipeline Variable ---
29
- pipeline = None
30
-
31
- # --- Logging Setup ---
32
- logging.basicConfig(level=logging.INFO)
33
- logger = logging.getLogger(__name__)
34
-
35
- logger.info("Trellis App: Script starting.")
36
 
37
  def start_session(req: gr.Request):
38
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
@@ -41,11 +30,7 @@ def start_session(req: gr.Request):
41
 
42
  def end_session(req: gr.Request):
43
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
44
- try:
45
- if os.path.exists(user_dir):
46
- shutil.rmtree(user_dir)
47
- except OSError as e:
48
- logger.warning(f"Warning: Could not remove temp session dir {user_dir}: {e}")
49
 
50
 
51
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
@@ -65,7 +50,7 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
65
  }
66
 
67
 
68
- def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
69
  gs = Gaussian(
70
  aabb=state['gaussian']['aabb'],
71
  sh_degree=state['gaussian']['sh_degree'],
@@ -107,7 +92,6 @@ def text_to_3d(
107
  ) -> Tuple[dict, str]:
108
  """
109
  Convert an text prompt to a 3D model.
110
-
111
  Args:
112
  prompt (str): The text prompt.
113
  seed (int): The random seed.
@@ -115,22 +99,11 @@ def text_to_3d(
115
  ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
116
  slat_guidance_strength (float): The guidance strength for structured latent generation.
117
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
118
-
119
  Returns:
120
  dict: The information of the generated 3D model.
121
  str: The path to the video of the 3D model.
122
  """
123
- # --- Determine user_dir robustly ---
124
- session_hash_str = str(req.session_hash) if hasattr(req, 'session_hash') and req.session_hash else f"gradio_call_{np.random.randint(10000)}"
125
- user_dir = os.path.join(TMP_DIR, session_hash_str)
126
- os.makedirs(user_dir, exist_ok=True) # Ensure directory exists
127
-
128
- # Use the global pipeline initialized later
129
- if pipeline is None:
130
- logger.error("Gradio Error: Pipeline not initialized")
131
- # Handle error appropriately for Gradio - maybe return None or raise gr.Error?
132
- return {}, None
133
-
134
  outputs = pipeline.run(
135
  prompt,
136
  seed=seed,
@@ -148,11 +121,7 @@ def text_to_3d(
148
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
149
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
150
  video_path = os.path.join(user_dir, 'sample.mp4')
151
- try:
152
- imageio.mimsave(video_path, video, fps=15) # Now the directory should exist
153
- except FileNotFoundError:
154
- logger.error(f"ERROR: Directory {user_dir} still not found before mimsave!", exc_info=True)
155
- raise
156
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
157
  torch.cuda.empty_cache()
158
  return state, video_path
@@ -167,28 +136,18 @@ def extract_glb(
167
  ) -> Tuple[str, str]:
168
  """
169
  Extract a GLB file from the 3D model.
170
-
171
  Args:
172
  state (dict): The state of the generated 3D model.
173
  mesh_simplify (float): The mesh simplification factor.
174
  texture_size (int): The texture resolution.
175
-
176
  Returns:
177
  str: The path to the extracted GLB file.
178
  """
179
- # --- Determine user_dir robustly ---
180
- session_hash_str = str(req.session_hash) if hasattr(req, 'session_hash') and req.session_hash else f"gradio_call_{np.random.randint(10000)}"
181
- user_dir = os.path.join(TMP_DIR, session_hash_str)
182
- os.makedirs(user_dir, exist_ok=True) # Ensure directory exists
183
-
184
  gs, mesh = unpack_state(state)
185
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
186
  glb_path = os.path.join(user_dir, 'sample.glb')
187
- try:
188
- glb.export(glb_path) # Now the directory should exist
189
- except FileNotFoundError:
190
- logger.error(f"ERROR: Directory {user_dir} still not found before glb.export!", exc_info=True)
191
- raise
192
  torch.cuda.empty_cache()
193
  return glb_path, glb_path
194
 
@@ -197,30 +156,19 @@ def extract_glb(
197
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
198
  """
199
  Extract a Gaussian file from the 3D model.
200
-
201
  Args:
202
  state (dict): The state of the generated 3D model.
203
-
204
  Returns:
205
  str: The path to the extracted Gaussian file.
206
  """
207
- # --- Determine user_dir robustly ---
208
- session_hash_str = str(req.session_hash) if hasattr(req, 'session_hash') and req.session_hash else f"gradio_call_{np.random.randint(10000)}"
209
- user_dir = os.path.join(TMP_DIR, session_hash_str)
210
- os.makedirs(user_dir, exist_ok=True) # Ensure directory exists
211
-
212
  gs, _ = unpack_state(state)
213
  gaussian_path = os.path.join(user_dir, 'sample.ply')
214
- try:
215
- gs.save_ply(gaussian_path) # Now the directory should exist
216
- except FileNotFoundError:
217
- logger.error(f"ERROR: Directory {user_dir} still not found before gs.save_ply!", exc_info=True)
218
- raise
219
  torch.cuda.empty_cache()
220
  return gaussian_path, gaussian_path
221
 
222
 
223
- # --- Gradio Blocks Definition ---
224
  with gr.Blocks(delete_cache=(600, 600)) as demo:
225
  gr.Markdown("""
226
  ## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
@@ -313,41 +261,8 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
313
  )
314
 
315
 
316
- # Launch the Gradio app and FastAPI server
317
  if __name__ == "__main__":
318
- logger.info("Trellis App: Initializing Trellis Pipeline...")
319
- try:
320
- # Make pipeline global so Gradio functions and API endpoint can access it
321
- pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
322
- pipeline.cuda()
323
- logger.info("Trellis App: Trellis Pipeline Initialized successfully.")
324
- except Exception as e:
325
- logger.error(f"Trellis App: FATAL ERROR initializing pipeline: {e}", exc_info=True)
326
- pipeline = None # Ensure pipeline is None if initialization failed
327
- # Optionally exit if pipeline is critical
328
- # import sys
329
- # sys.exit("Pipeline initialization failed.")
330
-
331
- # Start the background API server using the integration module only if pipeline loaded
332
- if pipeline:
333
- logger.info("Trellis App: Attempting to start FastAPI server thread...")
334
- try:
335
- api_thread = trellis_fastAPI_integration.start_api_thread(pipeline)
336
- if api_thread and api_thread.is_alive():
337
- logger.info("Trellis App: FastAPI server thread started successfully (is_alive check passed).")
338
- elif api_thread:
339
- logger.warning("Trellis App: FastAPI server thread was created but is not alive shortly after starting.")
340
- else:
341
- logger.error("Trellis App: start_api_thread returned None, thread not created.")
342
- except Exception as e:
343
- logger.error(f"Trellis App: Error occurred during start_api_thread call: {e}", exc_info=True)
344
- else:
345
- logger.error("Trellis App: Skipping FastAPI server start because pipeline failed to initialize.")
346
-
347
- # Launch the Gradio interface (blocking call)
348
- logger.info("Trellis App: Launching Gradio Demo...")
349
- try:
350
- demo.launch()
351
- logger.info("Trellis App: Gradio Demo launched.") # This might not be reached if launch blocks indefinitely
352
- except Exception as e:
353
- logger.error(f"Trellis App: Error launching Gradio demo: {e}", exc_info=True)
 
17
  import traceback
18
  import sys
19
 
 
 
 
20
 
21
  MAX_SEED = np.iinfo(np.int32).max
22
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
23
  os.makedirs(TMP_DIR, exist_ok=True)
24
 
 
 
 
 
 
 
 
 
25
 
26
  def start_session(req: gr.Request):
27
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
30
 
31
  def end_session(req: gr.Request):
32
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
33
+ shutil.rmtree(user_dir)
 
 
 
 
34
 
35
 
36
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
 
50
  }
51
 
52
 
53
+ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
54
  gs = Gaussian(
55
  aabb=state['gaussian']['aabb'],
56
  sh_degree=state['gaussian']['sh_degree'],
 
92
  ) -> Tuple[dict, str]:
93
  """
94
  Convert an text prompt to a 3D model.
 
95
  Args:
96
  prompt (str): The text prompt.
97
  seed (int): The random seed.
 
99
  ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
100
  slat_guidance_strength (float): The guidance strength for structured latent generation.
101
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
 
102
  Returns:
103
  dict: The information of the generated 3D model.
104
  str: The path to the video of the 3D model.
105
  """
106
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
 
 
 
 
 
 
 
 
 
107
  outputs = pipeline.run(
108
  prompt,
109
  seed=seed,
 
121
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
122
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
123
  video_path = os.path.join(user_dir, 'sample.mp4')
124
+ imageio.mimsave(video_path, video, fps=15)
 
 
 
 
125
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
126
  torch.cuda.empty_cache()
127
  return state, video_path
 
136
  ) -> Tuple[str, str]:
137
  """
138
  Extract a GLB file from the 3D model.
 
139
  Args:
140
  state (dict): The state of the generated 3D model.
141
  mesh_simplify (float): The mesh simplification factor.
142
  texture_size (int): The texture resolution.
 
143
  Returns:
144
  str: The path to the extracted GLB file.
145
  """
146
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
 
 
 
147
  gs, mesh = unpack_state(state)
148
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
149
  glb_path = os.path.join(user_dir, 'sample.glb')
150
+ glb.export(glb_path)
 
 
 
 
151
  torch.cuda.empty_cache()
152
  return glb_path, glb_path
153
 
 
156
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
157
  """
158
  Extract a Gaussian file from the 3D model.
 
159
  Args:
160
  state (dict): The state of the generated 3D model.
 
161
  Returns:
162
  str: The path to the extracted Gaussian file.
163
  """
164
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
 
 
 
165
  gs, _ = unpack_state(state)
166
  gaussian_path = os.path.join(user_dir, 'sample.ply')
167
+ gs.save_ply(gaussian_path)
 
 
 
 
168
  torch.cuda.empty_cache()
169
  return gaussian_path, gaussian_path
170
 
171
 
 
172
  with gr.Blocks(delete_cache=(600, 600)) as demo:
173
  gr.Markdown("""
174
  ## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
 
261
  )
262
 
263
 
264
+ # Launch the Gradio app
265
  if __name__ == "__main__":
266
+ pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
267
+ pipeline.cuda()
268
+ demo.launch()