dkatz2391 commited on
Commit
e6539f9
·
verified ·
1 Parent(s): 3612d72

revert bck to memory error

Browse files
Files changed (1) hide show
  1. app.py +17 -56
app.py CHANGED
@@ -13,7 +13,6 @@ from easydict import EasyDict as edict
13
  from trellis.pipelines import TrellisTextTo3DPipeline
14
  from trellis.representations import Gaussian, MeshExtractResult
15
  from trellis.utils import render_utils, postprocessing_utils
16
- import joblib # Added for saving/loading state
17
 
18
  import traceback
19
  import sys
@@ -90,7 +89,7 @@ def text_to_3d(
90
  slat_guidance_strength: float,
91
  slat_sampling_steps: int,
92
  req: gr.Request,
93
- ) -> Tuple[str, str, str]:
94
  """
95
  Convert an text prompt to a 3D model.
96
  Args:
@@ -101,9 +100,9 @@ def text_to_3d(
101
  slat_guidance_strength (float): The guidance strength for structured latent generation.
102
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
103
  Returns:
104
- str: Path to the saved state file.
105
- str: Path to the generated video.
106
- str: Path to the saved state file (for internal buffer).
107
  """
108
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
109
  os.makedirs(user_dir, exist_ok=True)
@@ -126,70 +125,34 @@ def text_to_3d(
126
  video_path = os.path.join(user_dir, 'sample.mp4')
127
  imageio.mimsave(video_path, video, fps=15)
128
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
129
-
130
- # Save state to file
131
- state_file_path = os.path.join(user_dir, f'state_{seed}.joblib')
132
- try:
133
- joblib.dump(state, state_file_path)
134
- print(f"[Trellis] State saved to {state_file_path}")
135
- except Exception as e:
136
- print(f"Error saving state to {state_file_path}: {e}")
137
- # Decide how to handle error - maybe return None or raise?
138
- # For now, let's allow it to proceed but log the error
139
- state_file_path = None # Indicate failure
140
-
141
  torch.cuda.empty_cache()
142
- # Return state file path for API, video path for Video, and state path again for internal buffer
143
- # Return None for path if saving failed
144
- return state_file_path, video_path, state_file_path
145
 
146
 
147
  @spaces.GPU(duration=90)
148
  def extract_glb(
149
- state_file_path: str, # Changed input from state: dict
150
  mesh_simplify: float,
151
  texture_size: int,
152
  req: gr.Request,
153
  ) -> Tuple[str, str]:
154
  """
155
- Extract a GLB file from the 3D model state file.
156
  Args:
157
- state_file_path (str): Path to the file containing the state.
158
  mesh_simplify (float): The mesh simplification factor.
159
  texture_size (int): The texture resolution.
160
  Returns:
161
  str: The path to the extracted GLB file.
162
- str: The path to the extracted GLB file (for download button).
163
  """
164
- if not state_file_path or not os.path.exists(state_file_path):
165
- print(f"Error: State file path invalid or file not found: {state_file_path}")
166
- # Return dummy paths or raise an error
167
- return None, None
168
-
169
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
170
  os.makedirs(user_dir, exist_ok=True)
171
-
172
- # Load state from file
173
- try:
174
- state = joblib.load(state_file_path)
175
- print(f"[Trellis] State loaded from {state_file_path}")
176
- except Exception as e:
177
- print(f"Error loading state from {state_file_path}: {e}")
178
- return None, None
179
-
180
  gs, mesh = unpack_state(state)
181
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
182
  glb_path = os.path.join(user_dir, 'sample.glb')
183
  glb.export(glb_path)
184
  torch.cuda.empty_cache()
185
-
186
- # Optional: Clean up the state file after use
187
- try:
188
- os.remove(state_file_path)
189
- print(f"[Trellis] Cleaned up state file: {state_file_path}")
190
- except OSError as e:
191
- print(f"Error removing state file {state_file_path}: {e.strerror}")
192
-
193
  return glb_path, glb_path
194
 
195
 
@@ -215,8 +178,8 @@ output_buf = gr.State()
215
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
216
  model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
217
 
218
- # Change hidden JSON to hidden Textbox for the state file path
219
- state_output_path_textbox = gr.Textbox(visible=False, label="State File Path Output")
220
 
221
  with gr.Blocks(delete_cache=(600, 600)) as demo:
222
  gr.Markdown("""
@@ -275,8 +238,8 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
275
  ).then(
276
  text_to_3d,
277
  inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
278
- # Output state path to hidden Textbox, video to Video, state path to internal buffer
279
- outputs=[state_output_path_textbox, video_output, output_buf],
280
  ).then(
281
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
282
  outputs=[extract_glb_btn, extract_gs_btn],
@@ -289,7 +252,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
289
 
290
  extract_glb_btn.click(
291
  extract_glb,
292
- # Input state path from internal buffer (assuming it holds the path now)
293
  inputs=[output_buf, mesh_simplify, texture_size],
294
  outputs=[model_output, download_glb],
295
  ).then(
@@ -299,8 +261,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
299
 
300
  extract_gs_btn.click(
301
  extract_gaussian,
302
- # This likely needs adjustment too if it relies on output_buf holding the state dict
303
- inputs=[output_buf],
304
  outputs=[model_output, download_gs],
305
  ).then(
306
  lambda: gr.Button(interactive=True),
@@ -344,11 +305,11 @@ api_text_to_3d = gr.Interface(
344
  # --- API-only endpoint for GLB extraction ---
345
  # Explicitly defines state input as JSON for server calls.
346
  api_extract_glb = gr.Interface(
347
- fn=lambda state_file_path, mesh_simplify, texture_size: extract_glb(
348
- state_file_path, mesh_simplify, texture_size, gr.Request()
349
  ),
350
  inputs=[
351
- gr.Textbox(label="State File Path"), # Expect state file path as string
352
  gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01),
353
  gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
354
  ],
 
13
  from trellis.pipelines import TrellisTextTo3DPipeline
14
  from trellis.representations import Gaussian, MeshExtractResult
15
  from trellis.utils import render_utils, postprocessing_utils
 
16
 
17
  import traceback
18
  import sys
 
89
  slat_guidance_strength: float,
90
  slat_sampling_steps: int,
91
  req: gr.Request,
92
+ ) -> Tuple[dict, str, dict]:
93
  """
94
  Convert an text prompt to a 3D model.
95
  Args:
 
100
  slat_guidance_strength (float): The guidance strength for structured latent generation.
101
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
102
  Returns:
103
+ dict: The information of the generated 3D model.
104
+ str: The path to the video of the 3D model.
105
+ dict: The state of the generated 3D model.
106
  """
107
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
108
  os.makedirs(user_dir, exist_ok=True)
 
125
  video_path = os.path.join(user_dir, 'sample.mp4')
126
  imageio.mimsave(video_path, video, fps=15)
127
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
 
 
 
 
 
 
 
 
 
 
 
 
128
  torch.cuda.empty_cache()
129
+ # Return state for JSON, video path for Video, and state again for internal buffer
130
+ return state, video_path, state
 
131
 
132
 
133
  @spaces.GPU(duration=90)
134
  def extract_glb(
135
+ state: dict,
136
  mesh_simplify: float,
137
  texture_size: int,
138
  req: gr.Request,
139
  ) -> Tuple[str, str]:
140
  """
141
+ Extract a GLB file from the 3D model.
142
  Args:
143
+ state (dict): The state of the generated 3D model.
144
  mesh_simplify (float): The mesh simplification factor.
145
  texture_size (int): The texture resolution.
146
  Returns:
147
  str: The path to the extracted GLB file.
 
148
  """
 
 
 
 
 
149
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
150
  os.makedirs(user_dir, exist_ok=True)
 
 
 
 
 
 
 
 
 
151
  gs, mesh = unpack_state(state)
152
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
153
  glb_path = os.path.join(user_dir, 'sample.glb')
154
  glb.export(glb_path)
155
  torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
156
  return glb_path, glb_path
157
 
158
 
 
178
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
179
  model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
180
 
181
+ # Add a hidden JSON output for the state object for API calls
182
+ state_output_json = gr.JSON(visible=False, label="State JSON Output")
183
 
184
  with gr.Blocks(delete_cache=(600, 600)) as demo:
185
  gr.Markdown("""
 
238
  ).then(
239
  text_to_3d,
240
  inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
241
+ # Output state to hidden JSON first, then video to visible component, then state to internal buffer
242
+ outputs=[state_output_json, video_output, output_buf],
243
  ).then(
244
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
245
  outputs=[extract_glb_btn, extract_gs_btn],
 
252
 
253
  extract_glb_btn.click(
254
  extract_glb,
 
255
  inputs=[output_buf, mesh_simplify, texture_size],
256
  outputs=[model_output, download_glb],
257
  ).then(
 
261
 
262
  extract_gs_btn.click(
263
  extract_gaussian,
264
+ inputs=[output_buf],
 
265
  outputs=[model_output, download_gs],
266
  ).then(
267
  lambda: gr.Button(interactive=True),
 
305
  # --- API-only endpoint for GLB extraction ---
306
  # Explicitly defines state input as JSON for server calls.
307
  api_extract_glb = gr.Interface(
308
+ fn=lambda state, mesh_simplify, texture_size: extract_glb(
309
+ state, mesh_simplify, texture_size, gr.Request()
310
  ),
311
  inputs=[
312
+ gr.JSON(label="State Object"), # Expect state as JSON
313
  gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01),
314
  gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
315
  ],