Spaces:
Running
on
Zero
Running
on
Zero
revert bck to memory error
Browse files
app.py
CHANGED
@@ -13,7 +13,6 @@ from easydict import EasyDict as edict
|
|
13 |
from trellis.pipelines import TrellisTextTo3DPipeline
|
14 |
from trellis.representations import Gaussian, MeshExtractResult
|
15 |
from trellis.utils import render_utils, postprocessing_utils
|
16 |
-
import joblib # Added for saving/loading state
|
17 |
|
18 |
import traceback
|
19 |
import sys
|
@@ -90,7 +89,7 @@ def text_to_3d(
|
|
90 |
slat_guidance_strength: float,
|
91 |
slat_sampling_steps: int,
|
92 |
req: gr.Request,
|
93 |
-
) -> Tuple[
|
94 |
"""
|
95 |
Convert an text prompt to a 3D model.
|
96 |
Args:
|
@@ -101,9 +100,9 @@ def text_to_3d(
|
|
101 |
slat_guidance_strength (float): The guidance strength for structured latent generation.
|
102 |
slat_sampling_steps (int): The number of sampling steps for structured latent generation.
|
103 |
Returns:
|
104 |
-
|
105 |
-
str:
|
106 |
-
|
107 |
"""
|
108 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
109 |
os.makedirs(user_dir, exist_ok=True)
|
@@ -126,70 +125,34 @@ def text_to_3d(
|
|
126 |
video_path = os.path.join(user_dir, 'sample.mp4')
|
127 |
imageio.mimsave(video_path, video, fps=15)
|
128 |
state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
|
129 |
-
|
130 |
-
# Save state to file
|
131 |
-
state_file_path = os.path.join(user_dir, f'state_{seed}.joblib')
|
132 |
-
try:
|
133 |
-
joblib.dump(state, state_file_path)
|
134 |
-
print(f"[Trellis] State saved to {state_file_path}")
|
135 |
-
except Exception as e:
|
136 |
-
print(f"Error saving state to {state_file_path}: {e}")
|
137 |
-
# Decide how to handle error - maybe return None or raise?
|
138 |
-
# For now, let's allow it to proceed but log the error
|
139 |
-
state_file_path = None # Indicate failure
|
140 |
-
|
141 |
torch.cuda.empty_cache()
|
142 |
-
# Return state
|
143 |
-
|
144 |
-
return state_file_path, video_path, state_file_path
|
145 |
|
146 |
|
147 |
@spaces.GPU(duration=90)
|
148 |
def extract_glb(
|
149 |
-
|
150 |
mesh_simplify: float,
|
151 |
texture_size: int,
|
152 |
req: gr.Request,
|
153 |
) -> Tuple[str, str]:
|
154 |
"""
|
155 |
-
Extract a GLB file from the 3D model
|
156 |
Args:
|
157 |
-
|
158 |
mesh_simplify (float): The mesh simplification factor.
|
159 |
texture_size (int): The texture resolution.
|
160 |
Returns:
|
161 |
str: The path to the extracted GLB file.
|
162 |
-
str: The path to the extracted GLB file (for download button).
|
163 |
"""
|
164 |
-
if not state_file_path or not os.path.exists(state_file_path):
|
165 |
-
print(f"Error: State file path invalid or file not found: {state_file_path}")
|
166 |
-
# Return dummy paths or raise an error
|
167 |
-
return None, None
|
168 |
-
|
169 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
170 |
os.makedirs(user_dir, exist_ok=True)
|
171 |
-
|
172 |
-
# Load state from file
|
173 |
-
try:
|
174 |
-
state = joblib.load(state_file_path)
|
175 |
-
print(f"[Trellis] State loaded from {state_file_path}")
|
176 |
-
except Exception as e:
|
177 |
-
print(f"Error loading state from {state_file_path}: {e}")
|
178 |
-
return None, None
|
179 |
-
|
180 |
gs, mesh = unpack_state(state)
|
181 |
glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
|
182 |
glb_path = os.path.join(user_dir, 'sample.glb')
|
183 |
glb.export(glb_path)
|
184 |
torch.cuda.empty_cache()
|
185 |
-
|
186 |
-
# Optional: Clean up the state file after use
|
187 |
-
try:
|
188 |
-
os.remove(state_file_path)
|
189 |
-
print(f"[Trellis] Cleaned up state file: {state_file_path}")
|
190 |
-
except OSError as e:
|
191 |
-
print(f"Error removing state file {state_file_path}: {e.strerror}")
|
192 |
-
|
193 |
return glb_path, glb_path
|
194 |
|
195 |
|
@@ -215,8 +178,8 @@ output_buf = gr.State()
|
|
215 |
video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
|
216 |
model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
|
217 |
|
218 |
-
#
|
219 |
-
|
220 |
|
221 |
with gr.Blocks(delete_cache=(600, 600)) as demo:
|
222 |
gr.Markdown("""
|
@@ -275,8 +238,8 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
|
|
275 |
).then(
|
276 |
text_to_3d,
|
277 |
inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
|
278 |
-
# Output state
|
279 |
-
outputs=[
|
280 |
).then(
|
281 |
lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
|
282 |
outputs=[extract_glb_btn, extract_gs_btn],
|
@@ -289,7 +252,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
|
|
289 |
|
290 |
extract_glb_btn.click(
|
291 |
extract_glb,
|
292 |
-
# Input state path from internal buffer (assuming it holds the path now)
|
293 |
inputs=[output_buf, mesh_simplify, texture_size],
|
294 |
outputs=[model_output, download_glb],
|
295 |
).then(
|
@@ -299,8 +261,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
|
|
299 |
|
300 |
extract_gs_btn.click(
|
301 |
extract_gaussian,
|
302 |
-
|
303 |
-
inputs=[output_buf],
|
304 |
outputs=[model_output, download_gs],
|
305 |
).then(
|
306 |
lambda: gr.Button(interactive=True),
|
@@ -344,11 +305,11 @@ api_text_to_3d = gr.Interface(
|
|
344 |
# --- API-only endpoint for GLB extraction ---
|
345 |
# Explicitly defines state input as JSON for server calls.
|
346 |
api_extract_glb = gr.Interface(
|
347 |
-
fn=lambda
|
348 |
-
|
349 |
),
|
350 |
inputs=[
|
351 |
-
gr.
|
352 |
gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01),
|
353 |
gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
|
354 |
],
|
|
|
13 |
from trellis.pipelines import TrellisTextTo3DPipeline
|
14 |
from trellis.representations import Gaussian, MeshExtractResult
|
15 |
from trellis.utils import render_utils, postprocessing_utils
|
|
|
16 |
|
17 |
import traceback
|
18 |
import sys
|
|
|
89 |
slat_guidance_strength: float,
|
90 |
slat_sampling_steps: int,
|
91 |
req: gr.Request,
|
92 |
+
) -> Tuple[dict, str, dict]:
|
93 |
"""
|
94 |
Convert an text prompt to a 3D model.
|
95 |
Args:
|
|
|
100 |
slat_guidance_strength (float): The guidance strength for structured latent generation.
|
101 |
slat_sampling_steps (int): The number of sampling steps for structured latent generation.
|
102 |
Returns:
|
103 |
+
dict: The information of the generated 3D model.
|
104 |
+
str: The path to the video of the 3D model.
|
105 |
+
dict: The state of the generated 3D model.
|
106 |
"""
|
107 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
108 |
os.makedirs(user_dir, exist_ok=True)
|
|
|
125 |
video_path = os.path.join(user_dir, 'sample.mp4')
|
126 |
imageio.mimsave(video_path, video, fps=15)
|
127 |
state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
torch.cuda.empty_cache()
|
129 |
+
# Return state for JSON, video path for Video, and state again for internal buffer
|
130 |
+
return state, video_path, state
|
|
|
131 |
|
132 |
|
133 |
@spaces.GPU(duration=90)
|
134 |
def extract_glb(
|
135 |
+
state: dict,
|
136 |
mesh_simplify: float,
|
137 |
texture_size: int,
|
138 |
req: gr.Request,
|
139 |
) -> Tuple[str, str]:
|
140 |
"""
|
141 |
+
Extract a GLB file from the 3D model.
|
142 |
Args:
|
143 |
+
state (dict): The state of the generated 3D model.
|
144 |
mesh_simplify (float): The mesh simplification factor.
|
145 |
texture_size (int): The texture resolution.
|
146 |
Returns:
|
147 |
str: The path to the extracted GLB file.
|
|
|
148 |
"""
|
|
|
|
|
|
|
|
|
|
|
149 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
150 |
os.makedirs(user_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
gs, mesh = unpack_state(state)
|
152 |
glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
|
153 |
glb_path = os.path.join(user_dir, 'sample.glb')
|
154 |
glb.export(glb_path)
|
155 |
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
return glb_path, glb_path
|
157 |
|
158 |
|
|
|
178 |
video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
|
179 |
model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
|
180 |
|
181 |
+
# Add a hidden JSON output for the state object for API calls
|
182 |
+
state_output_json = gr.JSON(visible=False, label="State JSON Output")
|
183 |
|
184 |
with gr.Blocks(delete_cache=(600, 600)) as demo:
|
185 |
gr.Markdown("""
|
|
|
238 |
).then(
|
239 |
text_to_3d,
|
240 |
inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
|
241 |
+
# Output state to hidden JSON first, then video to visible component, then state to internal buffer
|
242 |
+
outputs=[state_output_json, video_output, output_buf],
|
243 |
).then(
|
244 |
lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
|
245 |
outputs=[extract_glb_btn, extract_gs_btn],
|
|
|
252 |
|
253 |
extract_glb_btn.click(
|
254 |
extract_glb,
|
|
|
255 |
inputs=[output_buf, mesh_simplify, texture_size],
|
256 |
outputs=[model_output, download_glb],
|
257 |
).then(
|
|
|
261 |
|
262 |
extract_gs_btn.click(
|
263 |
extract_gaussian,
|
264 |
+
inputs=[output_buf],
|
|
|
265 |
outputs=[model_output, download_gs],
|
266 |
).then(
|
267 |
lambda: gr.Button(interactive=True),
|
|
|
305 |
# --- API-only endpoint for GLB extraction ---
|
306 |
# Explicitly defines state input as JSON for server calls.
|
307 |
api_extract_glb = gr.Interface(
|
308 |
+
fn=lambda state, mesh_simplify, texture_size: extract_glb(
|
309 |
+
state, mesh_simplify, texture_size, gr.Request()
|
310 |
),
|
311 |
inputs=[
|
312 |
+
gr.JSON(label="State Object"), # Expect state as JSON
|
313 |
gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01),
|
314 |
gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
|
315 |
],
|