Steven18 commited on
Commit
af2f852
·
1 Parent(s): f040f37

change ismultiimages logic and add file upload function

Browse files
Files changed (1) hide show
  1. app.py +69 -66
app.py CHANGED
@@ -22,55 +22,6 @@ MAX_SEED = np.iinfo(np.int32).max
22
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
23
  os.makedirs(TMP_DIR, exist_ok=True)
24
 
25
- def to_pil_list(
26
- multiimages: List[
27
- Union[
28
- Image.Image,
29
- Tuple[Image.Image, str],
30
- gr.File,
31
- Tuple[gr.File, str],
32
- str, # fallback: plain path
33
- Path
34
- ]
35
- ]
36
- ) -> List[Image.Image]:
37
- """
38
- Convert a heterogeneous `multiimages` list into a homogeneous
39
- `List[Image.Image]`.
40
-
41
- Accepts elements in any of the following forms:
42
- • PIL.Image
43
- • (PIL.Image, caption)
44
- • gr.File (gr.File.name is the temp‑file path)
45
- • (gr.File, caption)
46
- • str / pathlib.Path (direct file path)
47
-
48
- Returns:
49
- List[Image.Image] -- guaranteed PIL images
50
- """
51
- pil_imgs: List[Image.Image] = []
52
-
53
- for item in multiimages:
54
- # Unpack tuple/list, keep first element
55
- if isinstance(item, (tuple, list)):
56
- item = item[0]
57
-
58
- if isinstance(item, Image.Image): # already PIL
59
- pil_imgs.append(item)
60
-
61
- elif hasattr(item, "name"): # gr.File
62
- pil_imgs.append(Image.open(item.name))
63
-
64
- elif isinstance(item, (str, Path)): # file path
65
- pil_imgs.append(Image.open(item))
66
-
67
- else:
68
- raise TypeError(
69
- f"Unsupported element in multiimages: {type(item)}"
70
- )
71
-
72
- return pil_imgs
73
-
74
  def start_session(req: gr.Request):
75
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
76
  os.makedirs(user_dir, exist_ok=True)
@@ -109,6 +60,16 @@ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image
109
  processed_images = [pipeline.preprocess_image(image) for image in images]
110
  return processed_images
111
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
114
  return {
@@ -160,7 +121,7 @@ def get_seed(randomize_seed: bool, seed: int) -> int:
160
  @spaces.GPU
161
  def image_to_3d(
162
  image: Image.Image,
163
- multiimages: List[Tuple[Image.Image, str]],
164
  is_multiimage: str,
165
  seed: int,
166
  ss_guidance_strength: float,
@@ -193,6 +154,9 @@ def image_to_3d(
193
  os.makedirs(user_dir, exist_ok=True)
194
  is_multiimage = is_multiimage.lower() == "true"
195
 
 
 
 
196
  # Run pipeline depending on mode
197
  if not is_multiimage:
198
  outputs = pipeline.run(
@@ -210,7 +174,7 @@ def image_to_3d(
210
  },
211
  )
212
  else:
213
- pil_images = to_pil_list(multiimages)
214
  outputs = pipeline.run_multi_image(
215
  pil_images,
216
  seed=seed,
@@ -386,8 +350,14 @@ def test_for_api_gen(image: Image.Image) -> Image.Image:
386
  """
387
  return image
388
 
389
- def update_is_multiimage(event: SelectData):
390
- return "true" if event.index == 1 else "false"
 
 
 
 
 
 
391
 
392
 
393
  with gr.Blocks(delete_cache=(600, 600)) as demo:
@@ -428,17 +398,20 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
428
  *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
429
  """)
430
 
431
- is_multiimage = gr.Radio(
432
- choices=["true", "false"],
433
- value="false",
434
- label="Use multi-image mode",
435
- visible=True
436
- )
437
 
438
  input_tabs.select(
439
  fn=update_is_multiimage,
440
  outputs=is_multiimage
441
  )
 
 
 
 
 
 
 
 
442
 
443
  with gr.Accordion(label="Generation Settings", open=False):
444
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
@@ -466,7 +439,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
466
  with gr.Row():
467
  quick_generate_glb_btn = gr.Button("Quick Generate GLB")
468
  quick_generate_gs_btn = gr.Button("Quick Generate Gaussian")
469
-
470
  gr.Markdown("""
471
  *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
472
  """)
@@ -499,7 +472,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
499
  examples=prepare_multi_example(),
500
  inputs=[image_prompt],
501
  fn=split_image,
502
- outputs=[multiimage_prompt],
503
  run_on_click=True,
504
  examples_per_page=8,
505
  )
@@ -522,12 +495,24 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
522
  inputs=[image_prompt],
523
  outputs=[image_prompt],
524
  )
 
 
 
 
 
525
  multiimage_prompt.upload(
526
- preprocess_images,
527
  inputs=[multiimage_prompt],
528
- outputs=[multiimage_prompt],
 
 
 
 
 
 
529
  )
530
 
 
531
  generate_btn.click(
532
  get_seed,
533
  inputs=[randomize_seed, seed],
@@ -535,7 +520,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
535
  ).then(
536
  image_to_3d,
537
  inputs=[
538
- image_prompt, multiimage_prompt, is_multiimage, seed,
539
  ss_guidance_strength, ss_sampling_steps,
540
  slat_guidance_strength, slat_sampling_steps, multiimage_algo
541
  ],
@@ -577,7 +562,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
577
  fn=quick_generate_glb,
578
  inputs=[
579
  image_prompt,
580
- multiimage_prompt,
581
  is_multiimage,
582
  seed,
583
  ss_guidance_strength,
@@ -595,7 +580,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
595
  fn=quick_generate_gs,
596
  inputs=[
597
  image_prompt,
598
- multiimage_prompt,
599
  is_multiimage,
600
  seed,
601
  ss_guidance_strength,
@@ -606,6 +591,24 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
606
  ],
607
  outputs=[model_output, download_gs],
608
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
609
 
610
 
611
 
 
22
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
23
  os.makedirs(TMP_DIR, exist_ok=True)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def start_session(req: gr.Request):
26
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
27
  os.makedirs(user_dir, exist_ok=True)
 
60
  processed_images = [pipeline.preprocess_image(image) for image in images]
61
  return processed_images
62
 
63
+ def preprocess_upload_images(file_list: List[Any]) -> List[Tuple[Image.Image, str]]:
64
+ """
65
+ Resize all input images to 518x518 and return (image, filename) pairs.
66
+ """
67
+ images = [
68
+ (Image.open(f.name).convert("RGBA").resize((518, 518), Image.Resampling.LANCZOS), f.name)
69
+ for f in file_list
70
+ ]
71
+ return images
72
+
73
 
74
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
75
  return {
 
121
  @spaces.GPU
122
  def image_to_3d(
123
  image: Image.Image,
124
+ multiimages: Union[List[Tuple[Image.Image, str]], List[Any]],
125
  is_multiimage: str,
126
  seed: int,
127
  ss_guidance_strength: float,
 
154
  os.makedirs(user_dir, exist_ok=True)
155
  is_multiimage = is_multiimage.lower() == "true"
156
 
157
+ if multiimages and not isinstance(multiimages[0], tuple):
158
+ multiimages = preprocess_upload_images(multiimages)
159
+
160
  # Run pipeline depending on mode
161
  if not is_multiimage:
162
  outputs = pipeline.run(
 
174
  },
175
  )
176
  else:
177
+ pil_images = [d[0] for d in multiimages]
178
  outputs = pipeline.run_multi_image(
179
  pil_images,
180
  seed=seed,
 
350
  """
351
  return image
352
 
353
+ def update_is_multiimage(event: gr.SelectData):
354
+ return gr.update("true" if event.index == 1 else "false")
355
+
356
+ def toggle_multiimage_visibility(choice: str):
357
+ if choice == "true":
358
+ return gr.update(visible=True), gr.update(visible=False)
359
+ else:
360
+ return gr.update(visible=False), gr.update(visible=False)
361
 
362
 
363
  with gr.Blocks(delete_cache=(600, 600)) as demo:
 
398
  *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
399
  """)
400
 
401
+ is_multiimage = gr.Textbox(value="false", visible=True, interactive=False, label="is_multiimage")
 
 
 
 
 
402
 
403
  input_tabs.select(
404
  fn=update_is_multiimage,
405
  outputs=is_multiimage
406
  )
407
+ uploaded_api_images = gr.Files(file_types=["image"], label="Upload Images")
408
+ multiimage_combined = gr.State()
409
+
410
+ is_multiimage.change(
411
+ fn=toggle_multiimage_visibility,
412
+ inputs=is_multiimage,
413
+ outputs=[uploaded_api_images, multiimage_prompt]
414
+ )
415
 
416
  with gr.Accordion(label="Generation Settings", open=False):
417
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
 
439
  with gr.Row():
440
  quick_generate_glb_btn = gr.Button("Quick Generate GLB")
441
  quick_generate_gs_btn = gr.Button("Quick Generate Gaussian")
442
+
443
  gr.Markdown("""
444
  *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
445
  """)
 
472
  examples=prepare_multi_example(),
473
  inputs=[image_prompt],
474
  fn=split_image,
475
+ outputs=[multiimage_combined],
476
  run_on_click=True,
477
  examples_per_page=8,
478
  )
 
495
  inputs=[image_prompt],
496
  outputs=[image_prompt],
497
  )
498
+ # multiimage_prompt.upload(
499
+ # preprocess_images,
500
+ # inputs=[multiimage_prompt],
501
+ # outputs=[multiimage_prompt],
502
+ # )
503
  multiimage_prompt.upload(
504
+ fn=preprocess_images,
505
  inputs=[multiimage_prompt],
506
+ outputs=[multiimage_combined],
507
+ )
508
+ uploaded_api_images.upload(
509
+ fn=preprocess_upload_images,
510
+ inputs=[uploaded_api_images],
511
+ outputs=[multiimage_combined],
512
+ preprocess=False,
513
  )
514
 
515
+
516
  generate_btn.click(
517
  get_seed,
518
  inputs=[randomize_seed, seed],
 
520
  ).then(
521
  image_to_3d,
522
  inputs=[
523
+ image_prompt, multiimage_combined, is_multiimage, seed,
524
  ss_guidance_strength, ss_sampling_steps,
525
  slat_guidance_strength, slat_sampling_steps, multiimage_algo
526
  ],
 
562
  fn=quick_generate_glb,
563
  inputs=[
564
  image_prompt,
565
+ multiimage_combined,
566
  is_multiimage,
567
  seed,
568
  ss_guidance_strength,
 
580
  fn=quick_generate_gs,
581
  inputs=[
582
  image_prompt,
583
+ multiimage_combined,
584
  is_multiimage,
585
  seed,
586
  ss_guidance_strength,
 
591
  ],
592
  outputs=[model_output, download_gs],
593
  )
594
+ generate_btn.click(
595
+ fn=image_to_3d,
596
+ inputs=[
597
+ image_prompt, # image: Image.Image
598
+ multiimage_combined, # multiimages: List[UploadedFile] or List[Tuple[Image, str]]
599
+ is_multiimage, # is_multiimage: str
600
+ seed,
601
+ ss_guidance_strength,
602
+ ss_sampling_steps,
603
+ slat_guidance_strength,
604
+ slat_sampling_steps,
605
+ multiimage_algo,
606
+ ],
607
+ outputs=[
608
+ output_buf,
609
+ video_output
610
+ ]
611
+ )
612
 
613
 
614