Steven18 commited on
Commit
6672d5c
·
1 Parent(s): bae3b7a

add one click generate glb

Browse files
Files changed (1) hide show
  1. app.py +30 -6
app.py CHANGED
@@ -119,7 +119,7 @@ def image_to_3d(
119
  slat_sampling_steps: int,
120
  multiimage_algo: Literal["multidiffusion", "stochastic"],
121
  req: gr.Request,
122
- ) -> Tuple[dict, str, str]:
123
  """
124
  Convert an image (or multiple images) into a 3D model and return its state and video.
125
 
@@ -137,7 +137,6 @@ def image_to_3d(
137
  Returns:
138
  dict: The information of the generated 3D model.
139
  str: The path to the video of the 3D model.
140
- str: serialized JSON of state
141
 
142
  """
143
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
@@ -188,7 +187,7 @@ def image_to_3d(
188
  # Pack state for downstream use
189
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
190
  torch.cuda.empty_cache()
191
- return state, video_path, json.dumps(state)
192
 
193
 
194
 
@@ -321,11 +320,15 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
321
 
322
  with gr.Row():
323
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
324
- download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
 
 
 
 
 
325
 
326
  is_multiimage = gr.State(False)
327
  output_buf = gr.State()
328
- state_textbox = gr.Textbox(visible=False, label="Serialized State")
329
 
330
  # Example images at the bottom of the page
331
  with gr.Row() as single_image_example:
@@ -385,7 +388,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
385
  ss_guidance_strength, ss_sampling_steps,
386
  slat_guidance_strength, slat_sampling_steps, multiimage_algo
387
  ],
388
- outputs=[output_buf, video_output, state_textbox], # multi output
389
  ).then(
390
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
391
  outputs=[extract_glb_btn, extract_gs_btn],
@@ -418,6 +421,27 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
418
  lambda: gr.Button(interactive=False),
419
  outputs=[download_glb],
420
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
 
422
 
423
  # Launch the Gradio app
 
119
  slat_sampling_steps: int,
120
  multiimage_algo: Literal["multidiffusion", "stochastic"],
121
  req: gr.Request,
122
+ ) -> Tuple[dict, str]:
123
  """
124
  Convert an image (or multiple images) into a 3D model and return its state and video.
125
 
 
137
  Returns:
138
  dict: The information of the generated 3D model.
139
  str: The path to the video of the 3D model.
 
140
 
141
  """
142
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
187
  # Pack state for downstream use
188
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
189
  torch.cuda.empty_cache()
190
+ return state, video_path
191
 
192
 
193
 
 
320
 
321
  with gr.Row():
322
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
323
+ download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
324
+
325
+ with gr.Accordion("Quick GLB from Image", open=False):
326
+ generate_glb_btn = gr.Button("Upload and Generate GLB Automatically")
327
+ quick_video = gr.Video(label="Quick 3D Preview", autoplay=True, loop=True)
328
+ quick_glb_download = gr.DownloadButton(label="Download GLB", interactive=False)
329
 
330
  is_multiimage = gr.State(False)
331
  output_buf = gr.State()
 
332
 
333
  # Example images at the bottom of the page
334
  with gr.Row() as single_image_example:
 
388
  ss_guidance_strength, ss_sampling_steps,
389
  slat_guidance_strength, slat_sampling_steps, multiimage_algo
390
  ],
391
+ outputs=[output_buf, video_output],
392
  ).then(
393
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
394
  outputs=[extract_glb_btn, extract_gs_btn],
 
421
  lambda: gr.Button(interactive=False),
422
  outputs=[download_glb],
423
  )
424
+
425
+ generate_glb_btn.click(
426
+ lambda: get_seed(True, 0),
427
+ outputs=[seed]
428
+ ).then(
429
+ image_to_3d,
430
+ inputs=[
431
+ image_prompt,
432
+ gr.State([]),
433
+ gr.State(False),
434
+ seed,
435
+ gr.State(7.5), gr.State(12),
436
+ gr.State(3.0), gr.State(12),
437
+ gr.State("stochastic")
438
+ ],
439
+ outputs=[output_buf, quick_video],
440
+ ).then(
441
+ extract_glb,
442
+ inputs=[output_buf, mesh_simplify, texture_size],
443
+ outputs=[model_output, quick_glb_download]
444
+ )
445
 
446
 
447
  # Launch the Gradio app