Steven18 commited on
Commit
4aa90c9
·
1 Parent(s): 1fb27ce
Files changed (1) hide show
  1. app.py +37 -8
app.py CHANGED
@@ -16,6 +16,7 @@ from PIL import Image
16
  from trellis.pipelines import TrellisImageTo3DPipeline
17
  from trellis.representations import Gaussian, MeshExtractResult
18
  from trellis.utils import render_utils, postprocessing_utils
 
19
 
20
 
21
  MAX_SEED = np.iinfo(np.int32).max
@@ -117,11 +118,29 @@ def get_seed(randomize_seed: bool, seed: int) -> int:
117
  """
118
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  @spaces.GPU
122
  def image_to_3d(
123
  image: Image.Image,
124
- multiimages: Union[List[Tuple[Image.Image, str]], List[Any]],
125
  is_multiimage: str,
126
  seed: int,
127
  ss_guidance_strength: float,
@@ -154,8 +173,7 @@ def image_to_3d(
154
  os.makedirs(user_dir, exist_ok=True)
155
  is_multiimage = is_multiimage.lower() == "true"
156
 
157
- if multiimages and not isinstance(multiimages[0], tuple):
158
- multiimages = preprocess_upload_images(multiimages)
159
 
160
  # Run pipeline depending on mode
161
  if not is_multiimage:
@@ -174,7 +192,7 @@ def image_to_3d(
174
  },
175
  )
176
  else:
177
- pil_images = [d[0] for d in multiimages]
178
  outputs = pipeline.run_multi_image(
179
  pil_images,
180
  seed=seed,
@@ -281,6 +299,15 @@ def split_image(image: Image.Image) -> List[Image.Image]:
281
  images.append(Image.fromarray(image[:, s:e+1]))
282
  return [preprocess_image(image) for image in images]
283
 
 
 
 
 
 
 
 
 
 
284
  @spaces.GPU(api_name="quick_generate_glb")
285
  def quick_generate_glb(
286
  image: Image.Image,
@@ -473,8 +500,9 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
473
  examples_multi = gr.Examples(
474
  examples=prepare_multi_example(),
475
  inputs=[image_prompt],
476
- fn=split_image,
477
- outputs=[multiimage_combined],
 
478
  run_on_click=True,
479
  examples_per_page=8,
480
  )
@@ -508,9 +536,10 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
508
  outputs=[multiimage_combined],
509
  )
510
  uploaded_api_images.upload(
511
- fn=preprocess_upload_images,
512
  inputs=[uploaded_api_images],
513
- outputs=[multiimage_combined],
 
514
  preprocess=False,
515
  )
516
 
 
16
  from trellis.pipelines import TrellisImageTo3DPipeline
17
  from trellis.representations import Gaussian, MeshExtractResult
18
  from trellis.utils import render_utils, postprocessing_utils
19
+ from collections.abc import Sequence
20
 
21
 
22
  MAX_SEED = np.iinfo(np.int32).max
 
118
  """
119
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
120
 
121
+ def normalize_multiimages(multiimages: Sequence) -> List[Tuple[Image.Image, str]]:
122
+ if not multiimages:
123
+ return []
124
+
125
+ if isinstance(multiimages[0], Image.Image):
126
+ return [
127
+ (pipeline.preprocess_image(img), f"gallery_{i}.png")
128
+ for i, img in enumerate(multiimages)
129
+ ]
130
+
131
+ if isinstance(multiimages[0], tuple):
132
+ return [
133
+ (pipeline.preprocess_image(img), name)
134
+ for img, name in multiimages
135
+ ]
136
+
137
+ return preprocess_upload_images(multiimages)
138
+
139
 
140
  @spaces.GPU
141
  def image_to_3d(
142
  image: Image.Image,
143
+ multiimages: List[Any],
144
  is_multiimage: str,
145
  seed: int,
146
  ss_guidance_strength: float,
 
173
  os.makedirs(user_dir, exist_ok=True)
174
  is_multiimage = is_multiimage.lower() == "true"
175
 
176
+ multiimages = normalize_multiimages(multiimages)
 
177
 
178
  # Run pipeline depending on mode
179
  if not is_multiimage:
 
192
  },
193
  )
194
  else:
195
+ pil_images = [img for img, _ in multiimages]
196
  outputs = pipeline.run_multi_image(
197
  pil_images,
198
  seed=seed,
 
299
  images.append(Image.fromarray(image[:, s:e+1]))
300
  return [preprocess_image(image) for image in images]
301
 
302
+ def _example_to_multi(img: Image.Image):
303
+ imgs = split_image(img)
304
+ return imgs, imgs
305
+
306
+ def _files_to_gallery_and_state(file_list):
307
+ tuples = preprocess_upload_images(file_list)
308
+ gallery_imgs = [img for img, _ in tuples]
309
+ return gallery_imgs, tuples
310
+
311
  @spaces.GPU(api_name="quick_generate_glb")
312
  def quick_generate_glb(
313
  image: Image.Image,
 
500
  examples_multi = gr.Examples(
501
  examples=prepare_multi_example(),
502
  inputs=[image_prompt],
503
+ fn=_example_to_multi,
504
+ outputs=[multiimage_prompt,
505
+ multiimage_combined],
506
  run_on_click=True,
507
  examples_per_page=8,
508
  )
 
536
  outputs=[multiimage_combined],
537
  )
538
  uploaded_api_images.upload(
539
+ fn=_files_to_gallery_and_state,
540
  inputs=[uploaded_api_images],
541
+ outputs=[multiimage_prompt,
542
+ multiimage_combined],
543
  preprocess=False,
544
  )
545