hysts (HF Staff) committed
Commit 4dd5dd7 · Parent: 645df58

Remove multi-image functionality

Files changed (1):
  app.py  +55 -135
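This commit strips the multi-image tab, gallery, `split_image` helper, and `multiimage_algo` option from app.py, leaving `image_to_3d` with a single-image signature (one `Image.Image` plus seed and sampler settings), as the diff below shows. For context, here is a minimal sketch of driving the simplified endpoint remotely with `gradio_client`; the Space id and `api_name` are illustrative assumptions, not values from this repo — check the Space's "Use via API" panel for the real ones.

# Minimal sketch: calling the simplified single-image endpoint with gradio_client
# (assumes a recent gradio_client that provides handle_file). The Space id and
# api_name below are hypothetical placeholders, not taken from this commit.
from gradio_client import Client, handle_file

client = Client("user/TRELLIS")        # hypothetical Space id
result = client.predict(
    handle_file("object.png"),         # single RGBA input image
    0,                                 # seed
    7.5,                               # ss_guidance_strength (UI default)
    12,                                # ss_sampling_steps (UI default)
    3.0,                               # slat_guidance_strength (UI default)
    12,                                # slat_sampling_steps (UI default)
    api_name="/image_to_3d",           # hypothetical endpoint name
)
print(result)                          # typically the path to the rendered preview video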
app.py CHANGED
@@ -2,7 +2,6 @@ import os
 import shlex
 import shutil
 import subprocess
-from typing import Literal
 
 os.environ["SPCONV_ALGO"] = "native"
 
@@ -121,65 +120,44 @@ def get_seed(randomize_seed: bool, seed: int) -> int:
 @spaces.GPU
 def image_to_3d(
     image: Image.Image,
-    multiimages: list[tuple[Image.Image, str]],
-    is_multiimage: bool,
     seed: int,
     ss_guidance_strength: float,
     ss_sampling_steps: int,
     slat_guidance_strength: float,
     slat_sampling_steps: int,
-    multiimage_algo: Literal["multidiffusion", "stochastic"],
     req: gr.Request,
 ) -> tuple[dict, str]:
     """Convert an image to a 3D model.
 
     Args:
         image (Image.Image): The input image.
-        multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
-        is_multiimage (bool): Whether is in multi-image mode.
         seed (int): The random seed.
         ss_guidance_strength (float): The guidance strength for sparse structure generation.
         ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
         slat_guidance_strength (float): The guidance strength for structured latent generation.
         slat_sampling_steps (int): The number of sampling steps for structured latent generation.
-        multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
 
     Returns:
         dict: The information of the generated 3D model.
         str: The path to the video of the 3D model.
     """
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-    if not is_multiimage:
-        outputs = pipeline.run(
-            image,
-            seed=seed,
-            formats=["gaussian", "mesh"],
-            preprocess_image=False,
-            sparse_structure_sampler_params={
-                "steps": ss_sampling_steps,
-                "cfg_strength": ss_guidance_strength,
-            },
-            slat_sampler_params={
-                "steps": slat_sampling_steps,
-                "cfg_strength": slat_guidance_strength,
-            },
-        )
-    else:
-        outputs = pipeline.run_multi_image(
-            [image[0] for image in multiimages],
-            seed=seed,
-            formats=["gaussian", "mesh"],
-            preprocess_image=False,
-            sparse_structure_sampler_params={
-                "steps": ss_sampling_steps,
-                "cfg_strength": ss_guidance_strength,
-            },
-            slat_sampler_params={
-                "steps": slat_sampling_steps,
-                "cfg_strength": slat_guidance_strength,
-            },
-            mode=multiimage_algo,
-        )
+
+    outputs = pipeline.run(
+        image,
+        seed=seed,
+        formats=["gaussian", "mesh"],
+        preprocess_image=False,
+        sparse_structure_sampler_params={
+            "steps": ss_sampling_steps,
+            "cfg_strength": ss_guidance_strength,
+        },
+        slat_sampler_params={
+            "steps": slat_sampling_steps,
+            "cfg_strength": slat_guidance_strength,
+        },
+    )
+
     video = render_utils.render_video(outputs["gaussian"][0], num_frames=120)["color"]
     video_geo = render_utils.render_video(outputs["mesh"][0], num_frames=120)["normal"]
     video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
@@ -248,19 +226,6 @@ def prepare_multi_example() -> list[Image.Image]:
     return images
 
 
-def split_image(image: Image.Image) -> list[Image.Image]:
-    """Split an image into multiple views."""
-    image = np.array(image)
-    alpha = image[..., 3]
-    alpha = np.any(alpha > 0, axis=0)
-    start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
-    end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
-    images = []
-    for s, e in zip(start_pos, end_pos, strict=False):
-        images.append(Image.fromarray(image[:, s : e + 1]))
-    return [preprocess_image(image) for image in images]
-
-
 with gr.Blocks(delete_cache=(600, 600)) as demo:
     gr.Markdown("""
     ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
@@ -272,51 +237,35 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
 
     with gr.Row():
         with gr.Column():
-            with gr.Tabs() as input_tabs:
-                with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
-                    image_prompt = gr.Image(
-                        label="Image Prompt",
-                        format="png",
-                        image_mode="RGBA",
-                        type="pil",
-                        height=300,
-                    )
-                with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
-                    multiimage_prompt = gr.Gallery(
-                        label="Image Prompt",
-                        format="png",
-                        type="pil",
-                        height=300,
-                        columns=3,
-                    )
-                    gr.Markdown("""
-                    Input different views of the object in separate images.
-
-                    *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
-                    """)
+            image_prompt = gr.Image(
+                label="Image Prompt",
+                format="png",
+                image_mode="RGBA",
+                type="pil",
+                height=300,
+            )
 
             with gr.Accordion(label="Generation Settings", open=False):
-                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
+                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                 gr.Markdown("Stage 1: Sparse Structure Generation")
                 with gr.Row():
-                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
-                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+                    ss_guidance_strength = gr.Slider(
+                        label="Guidance Strength", minimum=0.0, maximum=10.0, step=0.1, value=7.5
+                    )
+                    ss_sampling_steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=50, step=1, value=12)
                 gr.Markdown("Stage 2: Structured Latent Generation")
                 with gr.Row():
-                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
-                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
-                multiimage_algo = gr.Radio(
-                    ["stochastic", "multidiffusion"],
-                    label="Multi-image Algorithm",
-                    value="stochastic",
-                )
+                    slat_guidance_strength = gr.Slider(
+                        label="Guidance Strength", minimum=0.0, maximum=10.0, step=0.1, value=3.0
+                    )
+                    slat_sampling_steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=50, step=1, value=12)
 
             generate_btn = gr.Button("Generate")
 
             with gr.Accordion(label="GLB Extraction Settings", open=False):
-                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
-                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
+                mesh_simplify = gr.Slider(label="Simplify", minimum=0.9, maximum=0.98, step=0.01, value=0.95)
+                texture_size = gr.Slider(label="Texture Size", minimum=512, maximum=2048, step=512, value=1024)
 
             with gr.Row():
                 extract_glb_btn = gr.Button("Extract GLB", interactive=False)
@@ -333,101 +282,72 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
                 download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
                 download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
 
-    is_multiimage = gr.State(False)  # noqa: FBT003
     output_buf = gr.State()
 
-    # Example images at the bottom of the page
-    with gr.Row() as single_image_example:
-        examples = gr.Examples(
-            examples=[f"assets/example_image/{image}" for image in os.listdir("assets/example_image")],
-            inputs=[image_prompt],
-            fn=preprocess_image,
-            outputs=[image_prompt],
-            run_on_click=True,
-            examples_per_page=64,
-        )
-    with gr.Row(visible=False) as multiimage_example:
-        examples_multi = gr.Examples(
-            examples=prepare_multi_example(),
-            inputs=[image_prompt],
-            fn=split_image,
-            outputs=[multiimage_prompt],
-            run_on_click=True,
-            examples_per_page=8,
-        )
+    examples = gr.Examples(
+        examples=[f"assets/example_image/{image}" for image in os.listdir("assets/example_image")],
+        inputs=[image_prompt],
+        fn=preprocess_image,
+        outputs=[image_prompt],
+        run_on_click=True,
+        examples_per_page=64,
+    )
 
     # Handlers
    demo.load(start_session)
    demo.unload(end_session)
 
-    single_image_input_tab.select(
-        lambda: (False, gr.Row.update(visible=True), gr.Row.update(visible=False)),
-        outputs=[is_multiimage, single_image_example, multiimage_example],
-    )
-    multiimage_input_tab.select(
-        lambda: (True, gr.Row.update(visible=False), gr.Row.update(visible=True)),
-        outputs=[is_multiimage, single_image_example, multiimage_example],
-    )
-
     image_prompt.upload(
-        preprocess_image,
-        inputs=[image_prompt],
-        outputs=[image_prompt],
-    )
-    multiimage_prompt.upload(
-        preprocess_images,
-        inputs=[multiimage_prompt],
-        outputs=[multiimage_prompt],
+        fn=preprocess_image,
+        inputs=image_prompt,
+        outputs=image_prompt,
     )
 
     generate_btn.click(
-        get_seed,
+        fn=get_seed,
         inputs=[randomize_seed, seed],
-        outputs=[seed],
+        outputs=seed,
     ).then(
-        image_to_3d,
+        fn=image_to_3d,
         inputs=[
             image_prompt,
-            multiimage_prompt,
-            is_multiimage,
             seed,
             ss_guidance_strength,
             ss_sampling_steps,
             slat_guidance_strength,
             slat_sampling_steps,
-            multiimage_algo,
         ],
         outputs=[output_buf, video_output],
     ).then(
-        lambda: (gr.Button(interactive=True), gr.Button(interactive=True)),
+        fn=lambda: (gr.Button(interactive=True), gr.Button(interactive=True)),
        outputs=[extract_glb_btn, extract_gs_btn],
    )
 
    video_output.clear(
-        lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
+        fn=lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
        outputs=[extract_glb_btn, extract_gs_btn],
    )
 
    extract_glb_btn.click(
-        extract_glb,
+        fn=extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).then(
-        lambda: gr.Button(interactive=True),
+        fn=lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )
 
    extract_gs_btn.click(
-        extract_gaussian,
+        fn=extract_gaussian,
        inputs=[output_buf],
        outputs=[model_output, download_gs],
    ).then(
-        lambda: gr.Button(interactive=True),
+        fn=lambda: gr.Button(interactive=True),
        outputs=[download_gs],
    )
 
    model_output.clear(
-        lambda: gr.Button(interactive=False),
+        fn=lambda: gr.Button(interactive=False),
        outputs=[download_glb],
    )
 
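For reference, the single-image generation path that remains after this commit condenses to the sketch below, assembled from the added and context lines in the diff. `generate_preview` is a hypothetical helper name; the pipeline object and `render_utils` module are assumed to be the ones app.py already constructs, and the sampler settings use the UI defaults from the diff (12 steps and guidance 7.5 / 3.0).

# Condensed sketch of the retained image_to_3d path (not part of the commit).
import numpy as np
from PIL import Image

def generate_preview(pipeline, render_utils, image: Image.Image, seed: int = 0) -> list[np.ndarray]:
    """Run the single-image TRELLIS path and return side-by-side preview frames."""
    outputs = pipeline.run(
        image,
        seed=seed,
        formats=["gaussian", "mesh"],
        preprocess_image=False,
        sparse_structure_sampler_params={"steps": 12, "cfg_strength": 7.5},
        slat_sampler_params={"steps": 12, "cfg_strength": 3.0},
    )
    # Color render of the Gaussians next to the mesh normal render, 120 frames each.
    color = render_utils.render_video(outputs["gaussian"][0], num_frames=120)["color"]
    normal = render_utils.render_video(outputs["mesh"][0], num_frames=120)["normal"]
    return [np.concatenate([c, n], axis=1) for c, n in zip(color, normal)]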