hysts HF Staff commited on
Commit
583ab5f
·
1 Parent(s): ba57f56

Remove temp dir and gr.State

Browse files
Files changed (1) hide show
  1. app.py +50 -105
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import os
2
  import pathlib
3
  import shlex
4
- import shutil
5
  import subprocess
 
6
 
7
  os.environ["SPCONV_ALGO"] = "native"
8
 
@@ -29,25 +29,12 @@ from trellis.representations import Gaussian, MeshExtractResult
29
  from trellis.utils import postprocessing_utils, render_utils
30
 
31
  MAX_SEED = np.iinfo(np.int32).max
32
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
33
- os.makedirs(TMP_DIR, exist_ok=True)
34
-
35
 
36
  pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
37
  pipeline.cuda()
38
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
39
 
40
 
41
- def start_session(req: gr.Request):
42
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
43
- os.makedirs(user_dir, exist_ok=True)
44
-
45
-
46
- def end_session(req: gr.Request):
47
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
48
- shutil.rmtree(user_dir)
49
-
50
-
51
  def preprocess_image(image: Image.Image) -> Image.Image:
52
  """Preprocess the input image.
53
 
@@ -73,24 +60,26 @@ def preprocess_images(images: list[tuple[Image.Image, str]]) -> list[Image.Image
73
  return [pipeline.preprocess_image(image) for image in images]
74
 
75
 
76
- def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
77
- return {
78
  "gaussian": {
79
  **gs.init_params,
80
- "_xyz": gs._xyz.cpu().numpy(),
81
- "_features_dc": gs._features_dc.cpu().numpy(),
82
- "_scaling": gs._scaling.cpu().numpy(),
83
- "_rotation": gs._rotation.cpu().numpy(),
84
- "_opacity": gs._opacity.cpu().numpy(),
85
  },
86
  "mesh": {
87
- "vertices": mesh.vertices.cpu().numpy(),
88
- "faces": mesh.faces.cpu().numpy(),
89
  },
90
  }
 
91
 
92
 
93
- def unpack_state(state: dict) -> tuple[Gaussian, EasyDict, str]:
 
94
  gs = Gaussian(
95
  aabb=state["gaussian"]["aabb"],
96
  sh_degree=state["gaussian"]["sh_degree"],
@@ -99,15 +88,15 @@ def unpack_state(state: dict) -> tuple[Gaussian, EasyDict, str]:
99
  opacity_bias=state["gaussian"]["opacity_bias"],
100
  scaling_activation=state["gaussian"]["scaling_activation"],
101
  )
102
- gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device="cuda")
103
- gs._features_dc = torch.tensor(state["gaussian"]["_features_dc"], device="cuda")
104
- gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device="cuda")
105
- gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device="cuda")
106
- gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device="cuda")
107
 
108
  mesh = EasyDict(
109
- vertices=torch.tensor(state["mesh"]["vertices"], device="cuda"),
110
- faces=torch.tensor(state["mesh"]["faces"], device="cuda"),
111
  )
112
 
113
  return gs, mesh
@@ -126,8 +115,7 @@ def image_to_3d(
126
  ss_sampling_steps: int,
127
  slat_guidance_strength: float,
128
  slat_sampling_steps: int,
129
- req: gr.Request,
130
- ) -> tuple[dict, str]:
131
  """Convert an image to a 3D model.
132
 
133
  Args:
@@ -139,11 +127,9 @@ def image_to_3d(
139
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
140
 
141
  Returns:
142
- dict: The information of the generated 3D model.
143
  str: The path to the video of the 3D model.
144
  """
145
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
146
-
147
  outputs = pipeline.run(
148
  image,
149
  seed=seed,
@@ -162,69 +148,55 @@ def image_to_3d(
162
  video = render_utils.render_video(outputs["gaussian"][0], num_frames=120)["color"]
163
  video_geo = render_utils.render_video(outputs["mesh"][0], num_frames=120)["normal"]
164
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
165
- video_path = os.path.join(user_dir, "sample.mp4")
166
- imageio.mimsave(video_path, video, fps=15)
167
- state = pack_state(outputs["gaussian"][0], outputs["mesh"][0])
168
- torch.cuda.empty_cache()
169
- return state, video_path
 
 
 
 
170
 
171
 
172
  @spaces.GPU(duration=90)
173
  def extract_glb(
174
- state: dict,
175
  mesh_simplify: float,
176
  texture_size: int,
177
- req: gr.Request,
178
- ) -> tuple[str, str]:
179
  """Extract a GLB file from the 3D model.
180
 
181
  Args:
182
- state (dict): The state of the generated 3D model.
183
  mesh_simplify (float): The mesh simplification factor.
184
  texture_size (int): The texture resolution.
185
 
186
  Returns:
187
  str: The path to the extracted GLB file.
188
  """
189
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
190
- gs, mesh = unpack_state(state)
191
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
192
- glb_path = os.path.join(user_dir, "sample.glb")
193
- glb.export(glb_path)
194
  torch.cuda.empty_cache()
195
- return glb_path, glb_path
 
 
196
 
197
 
198
  @spaces.GPU
199
- def extract_gaussian(state: dict, req: gr.Request) -> tuple[str, str]:
200
  """Extract a Gaussian file from the 3D model.
201
 
202
  Args:
203
- state (dict): The state of the generated 3D model.
204
 
205
  Returns:
206
  str: The path to the extracted Gaussian file.
207
  """
208
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
209
- gs, _ = unpack_state(state)
210
- gaussian_path = os.path.join(user_dir, "sample.ply")
211
- gs.save_ply(gaussian_path)
212
- torch.cuda.empty_cache()
213
- return gaussian_path, gaussian_path
214
-
215
-
216
- def prepare_multi_example() -> list[Image.Image]:
217
- multi_case = list(set([i.split("_")[0] for i in os.listdir("assets/example_multi_image")]))
218
- images = []
219
- for case in multi_case:
220
- _images = []
221
- for i in range(1, 4):
222
- img = Image.open(f"assets/example_multi_image/{case}_{i}.png")
223
- W, H = img.size
224
- img = img.resize((int(W / H * 512), 512))
225
- _images.append(np.array(img))
226
- images.append(Image.fromarray(np.concatenate(_images, axis=1)))
227
- return images
228
 
229
 
230
  with gr.Blocks(delete_cache=(600, 600)) as demo:
@@ -279,11 +251,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
279
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
280
  model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
281
 
282
- with gr.Row():
283
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
284
- download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
285
-
286
- output_buf = gr.State()
287
 
288
  examples = gr.Examples(
289
  examples=sorted(pathlib.Path("assets/example_image").glob("*.png")),
@@ -294,10 +262,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
294
  examples_per_page=64,
295
  )
296
 
297
- # Handlers
298
- demo.load(start_session)
299
- demo.unload(end_session)
300
-
301
  image_prompt.upload(
302
  fn=preprocess_image,
303
  inputs=image_prompt,
@@ -318,40 +282,21 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
318
  slat_guidance_strength,
319
  slat_sampling_steps,
320
  ],
321
- outputs=[output_buf, video_output],
322
  ).then(
323
  fn=lambda: (gr.Button(interactive=True), gr.Button(interactive=True)),
324
  outputs=[extract_glb_btn, extract_gs_btn],
 
325
  )
326
 
327
  video_output.clear(
328
  fn=lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
329
  outputs=[extract_glb_btn, extract_gs_btn],
 
330
  )
331
 
332
- extract_glb_btn.click(
333
- fn=extract_glb,
334
- inputs=[output_buf, mesh_simplify, texture_size],
335
- outputs=[model_output, download_glb],
336
- ).then(
337
- fn=lambda: gr.Button(interactive=True),
338
- outputs=[download_glb],
339
- )
340
-
341
- extract_gs_btn.click(
342
- fn=extract_gaussian,
343
- inputs=[output_buf],
344
- outputs=[model_output, download_gs],
345
- ).then(
346
- fn=lambda: gr.Button(interactive=True),
347
- outputs=[download_gs],
348
- )
349
-
350
- model_output.clear(
351
- fn=lambda: gr.Button(interactive=False),
352
- outputs=[download_glb],
353
- )
354
-
355
 
356
  if __name__ == "__main__":
357
  demo.launch(mcp_server=True)
 
1
  import os
2
  import pathlib
3
  import shlex
 
4
  import subprocess
5
+ import tempfile
6
 
7
  os.environ["SPCONV_ALGO"] = "native"
8
 
 
29
  from trellis.utils import postprocessing_utils, render_utils
30
 
31
  MAX_SEED = np.iinfo(np.int32).max
 
 
 
32
 
33
  pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
34
  pipeline.cuda()
35
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
36
 
37
 
 
 
 
 
 
 
 
 
 
 
38
  def preprocess_image(image: Image.Image) -> Image.Image:
39
  """Preprocess the input image.
40
 
 
60
  return [pipeline.preprocess_image(image) for image in images]
61
 
62
 
63
+ def save_state_to_file(gs: Gaussian, mesh: MeshExtractResult, output_path: str) -> None:
64
+ state = {
65
  "gaussian": {
66
  **gs.init_params,
67
+ "_xyz": gs._xyz,
68
+ "_features_dc": gs._features_dc,
69
+ "_scaling": gs._scaling,
70
+ "_rotation": gs._rotation,
71
+ "_opacity": gs._opacity,
72
  },
73
  "mesh": {
74
+ "vertices": mesh.vertices,
75
+ "faces": mesh.faces,
76
  },
77
  }
78
+ torch.save(state, output_path)
79
 
80
 
81
+ def load_state_from_file(state_path: str) -> tuple[Gaussian, EasyDict]:
82
+ state = torch.load(state_path)
83
  gs = Gaussian(
84
  aabb=state["gaussian"]["aabb"],
85
  sh_degree=state["gaussian"]["sh_degree"],
 
88
  opacity_bias=state["gaussian"]["opacity_bias"],
89
  scaling_activation=state["gaussian"]["scaling_activation"],
90
  )
91
+ gs._xyz = state["gaussian"]["_xyz"]
92
+ gs._features_dc = state["gaussian"]["_features_dc"]
93
+ gs._scaling = state["gaussian"]["_scaling"]
94
+ gs._rotation = state["gaussian"]["_rotation"]
95
+ gs._opacity = state["gaussian"]["_opacity"]
96
 
97
  mesh = EasyDict(
98
+ vertices=state["mesh"]["vertices"],
99
+ faces=state["mesh"]["faces"],
100
  )
101
 
102
  return gs, mesh
 
115
  ss_sampling_steps: int,
116
  slat_guidance_strength: float,
117
  slat_sampling_steps: int,
118
+ ) -> tuple[str, str]:
 
119
  """Convert an image to a 3D model.
120
 
121
  Args:
 
127
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
128
 
129
  Returns:
130
+ str: The path to the pickle file that contains the state of the generated 3D model.
131
  str: The path to the video of the 3D model.
132
  """
 
 
133
  outputs = pipeline.run(
134
  image,
135
  seed=seed,
 
148
  video = render_utils.render_video(outputs["gaussian"][0], num_frames=120)["color"]
149
  video_geo = render_utils.render_video(outputs["mesh"][0], num_frames=120)["normal"]
150
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
151
+
152
+ with (
153
+ tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as state_file,
154
+ tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as video_file,
155
+ ):
156
+ save_state_to_file(outputs["gaussian"][0], outputs["mesh"][0], state_file.name)
157
+ torch.cuda.empty_cache()
158
+ imageio.mimsave(video_file.name, video, fps=15)
159
+ return state_file.name, video_file.name
160
 
161
 
162
  @spaces.GPU(duration=90)
163
  def extract_glb(
164
+ state_path: str,
165
  mesh_simplify: float,
166
  texture_size: int,
167
+ ) -> str:
 
168
  """Extract a GLB file from the 3D model.
169
 
170
  Args:
171
+ state_path (str): The path to the pickle file that contains the state of the generated 3D model.
172
  mesh_simplify (float): The mesh simplification factor.
173
  texture_size (int): The texture resolution.
174
 
175
  Returns:
176
  str: The path to the extracted GLB file.
177
  """
178
+ gs, mesh = load_state_from_file(state_path)
 
179
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
 
 
180
  torch.cuda.empty_cache()
181
+ with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as glb_file:
182
+ glb.export(glb_file.name)
183
+ return glb_file.name
184
 
185
 
186
  @spaces.GPU
187
+ def extract_gaussian(state_path: str) -> str:
188
  """Extract a Gaussian file from the 3D model.
189
 
190
  Args:
191
+ state_path (str): The path to the pickle file that contains the state of the generated 3D model.
192
 
193
  Returns:
194
  str: The path to the extracted Gaussian file.
195
  """
196
+ gs, _ = load_state_from_file(state_path)
197
+ with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as gaussian_file:
198
+ gs.save_ply(gaussian_file.name)
199
+ return gaussian_file.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
 
202
  with gr.Blocks(delete_cache=(600, 600)) as demo:
 
251
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
252
  model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
253
 
254
+ state_file_path = gr.Textbox(visible=False)
 
 
 
 
255
 
256
  examples = gr.Examples(
257
  examples=sorted(pathlib.Path("assets/example_image").glob("*.png")),
 
262
  examples_per_page=64,
263
  )
264
 
 
 
 
 
265
  image_prompt.upload(
266
  fn=preprocess_image,
267
  inputs=image_prompt,
 
282
  slat_guidance_strength,
283
  slat_sampling_steps,
284
  ],
285
+ outputs=[state_file_path, video_output],
286
  ).then(
287
  fn=lambda: (gr.Button(interactive=True), gr.Button(interactive=True)),
288
  outputs=[extract_glb_btn, extract_gs_btn],
289
+ api_name=False,
290
  )
291
 
292
  video_output.clear(
293
  fn=lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
294
  outputs=[extract_glb_btn, extract_gs_btn],
295
+ api_name=False,
296
  )
297
 
298
+ extract_glb_btn.click(fn=extract_glb, inputs=[state_file_path, mesh_simplify, texture_size], outputs=model_output)
299
+ extract_gs_btn.click(fn=extract_gaussian, inputs=state_file_path, outputs=model_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
  if __name__ == "__main__":
302
  demo.launch(mcp_server=True)