mcp-compatible

#2 opened by victor (HF Staff)
Files changed (2)
  1. README.md +1 -1
  2. app.py +67 -81
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🏢
 colorFrom: indigo
 colorTo: blue
 sdk: gradio
-sdk_version: 5.34.2
+sdk_version: 4.44.1
 app_file: app.py
 pinned: false
 license: mit
app.py CHANGED
@@ -33,17 +33,13 @@ def end_session(req: gr.Request):
 
 def preprocess_image(image: Image.Image) -> Image.Image:
     """
-    Preprocess the input image for 3D generation.
-
-    This function is called when a user uploads an image or selects an example.
-    It applies background removal and other preprocessing steps necessary for
-    optimal 3D model generation.
+    Preprocess the input image.
 
     Args:
-        image (Image.Image): The input image from the user
+        image (Image.Image): The input image.
 
     Returns:
-        Image.Image: The preprocessed image ready for 3D generation
+        Image.Image: The preprocessed image.
     """
     processed_image = pipeline.preprocess_image(image)
     return processed_image
@@ -51,16 +47,13 @@ def preprocess_image(image: Image.Image) -> Image.Image:
 
 def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
     """
-    Preprocess a list of input images for multi-image 3D generation.
-
-    This function is called when users upload multiple images in the gallery.
-    It processes each image to prepare them for the multi-image 3D generation pipeline.
+    Preprocess a list of input images.
 
     Args:
-        images (List[Tuple[Image.Image, str]]): The input images from the gallery
+        images (List[Tuple[Image.Image, str]]): The input images.
 
     Returns:
-        List[Image.Image]: The preprocessed images ready for 3D generation
+        List[Image.Image]: The preprocessed images.
     """
     images = [image[0] for image in images]
     processed_images = [pipeline.preprocess_image(image) for image in images]
@@ -109,23 +102,13 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
 
 def get_seed(randomize_seed: bool, seed: int) -> int:
     """
-    Get the random seed for generation.
-
-    This function is called by the generate button to determine whether to use
-    a random seed or the user-specified seed value.
-
-    Args:
-        randomize_seed (bool): Whether to generate a random seed
-        seed (int): The user-specified seed value
-
-    Returns:
-        int: The seed to use for generation
+    Get the random seed.
     """
     return np.random.randint(0, MAX_SEED) if randomize_seed else seed
 
 
-@spaces.GPU(duration=120)
-def generate_and_extract_glb(
+@spaces.GPU
+def image_to_3d(
     image: Image.Image,
     multiimages: List[Tuple[Image.Image, str]],
     is_multiimage: bool,
@@ -135,12 +118,10 @@ def generate_and_extract_glb(
     slat_guidance_strength: float,
     slat_sampling_steps: int,
     multiimage_algo: Literal["multidiffusion", "stochastic"],
-    mesh_simplify: float,
-    texture_size: int,
     req: gr.Request,
-) -> Tuple[dict, str, str, str]:
+) -> Tuple[dict, str]:
     """
-    Convert an image to a 3D model and extract GLB file.
+    Convert an image to a 3D model.
 
     Args:
         image (Image.Image): The input image.
@@ -152,18 +133,12 @@ def generate_and_extract_glb(
         slat_guidance_strength (float): The guidance strength for structured latent generation.
         slat_sampling_steps (int): The number of sampling steps for structured latent generation.
         multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
-        mesh_simplify (float): The mesh simplification factor.
-        texture_size (int): The texture resolution.
 
     Returns:
         dict: The information of the generated 3D model.
         str: The path to the video of the 3D model.
-        str: The path to the extracted GLB file.
-        str: The path to the extracted GLB file (for download).
     """
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-
-    # Generate 3D model
     if not is_multiimage:
         outputs = pipeline.run(
             image,
@@ -195,43 +170,53 @@ def generate_and_extract_glb(
             },
             mode=multiimage_algo,
         )
-
-    # Render video
     video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
     video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
     video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
     video_path = os.path.join(user_dir, 'sample.mp4')
     imageio.mimsave(video_path, video, fps=15)
-
-    # Extract GLB
-    gs = outputs['gaussian'][0]
-    mesh = outputs['mesh'][0]
+    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
+    torch.cuda.empty_cache()
+    return state, video_path
+
+
+@spaces.GPU(duration=90)
+def extract_glb(
+    state: dict,
+    mesh_simplify: float,
+    texture_size: int,
+    req: gr.Request,
+) -> Tuple[str, str]:
+    """
+    Extract a GLB file from the 3D model.
+
+    Args:
+        state (dict): The state of the generated 3D model.
+        mesh_simplify (float): The mesh simplification factor.
+        texture_size (int): The texture resolution.
+
+    Returns:
+        str: The path to the extracted GLB file.
+    """
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    gs, mesh = unpack_state(state)
     glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
     glb_path = os.path.join(user_dir, 'sample.glb')
     glb.export(glb_path)
-
-    # Pack state for optional Gaussian extraction
-    state = pack_state(gs, mesh)
-
     torch.cuda.empty_cache()
-    return state, video_path, glb_path, glb_path
+    return glb_path, glb_path
 
 
 @spaces.GPU
 def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
     """
-    Extract a Gaussian splatting file from the generated 3D model.
-
-    This function is called when the user clicks "Extract Gaussian" button.
-    It converts the 3D model state into a .ply file format containing
-    Gaussian splatting data for advanced 3D applications.
+    Extract a Gaussian file from the 3D model.
 
     Args:
-        state (dict): The state of the generated 3D model containing Gaussian data
-        req (gr.Request): Gradio request object for session management
+        state (dict): The state of the generated 3D model.
 
     Returns:
-        Tuple[str, str]: Paths to the extracted Gaussian file (for display and download)
+        str: The path to the extracted Gaussian file.
     """
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
     gs, _ = unpack_state(state)
@@ -257,17 +242,7 @@ def prepare_multi_example() -> List[Image.Image]:
 
 def split_image(image: Image.Image) -> List[Image.Image]:
     """
-    Split a multi-view image into separate view images.
-
-    This function is called when users select multi-image examples that contain
-    multiple views in a single concatenated image. It automatically splits them
-    based on alpha channel boundaries and preprocesses each view.
-
-    Args:
-        image (Image.Image): A concatenated image containing multiple views
-
-    Returns:
-        List[Image.Image]: List of individual preprocessed view images
+    Split an image into multiple views.
     """
     image = np.array(image)
     alpha = image[..., 3]
@@ -283,9 +258,8 @@ def split_image(image: Image.Image) -> List[Image.Image]:
 with gr.Blocks(delete_cache=(600, 600)) as demo:
     gr.Markdown("""
     ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
-    * Upload an image and click "Generate & Extract GLB" to create a 3D asset and automatically extract the GLB file.
-    * If you want the Gaussian file as well, click "Extract Gaussian" after generation.
-    * If the image has alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background.
+    * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
+    * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
 
     ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
     """)
@@ -315,13 +289,16 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
                     slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                     slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                 multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
+
+            generate_btn = gr.Button("Generate")
 
             with gr.Accordion(label="GLB Extraction Settings", open=False):
                 mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                 texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
-
-            generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
-            extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
+
+            with gr.Row():
+                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
+                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
             gr.Markdown("""
                         *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
                         """)
@@ -389,17 +366,26 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
         inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
-        generate_and_extract_glb,
-        inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size],
-        outputs=[output_buf, video_output, model_output, download_glb],
+        image_to_3d,
+        inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
+        outputs=[output_buf, video_output],
    ).then(
        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
-        outputs=[extract_gs_btn, download_glb],
+        outputs=[extract_glb_btn, extract_gs_btn],
    )

    video_output.clear(
-        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False)]),
-        outputs=[extract_gs_btn, download_glb, download_gs],
+        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
+        outputs=[extract_glb_btn, extract_gs_btn],
+    )
+
+    extract_glb_btn.click(
+        extract_glb,
+        inputs=[output_buf, mesh_simplify, texture_size],
+        outputs=[model_output, download_glb],
+    ).then(
+        lambda: gr.Button(interactive=True),
+        outputs=[download_glb],
    )

    extract_gs_btn.click(
@@ -412,8 +398,8 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
    )

    model_output.clear(
-        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
-        outputs=[download_glb, download_gs],
+        lambda: gr.Button(interactive=False),
+        outputs=[download_glb],
    )


@@ -425,4 +411,4 @@ if __name__ == "__main__":
         pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))    # Preload rembg
     except:
         pass
-    demo.launch(mcp_server=True)
+    demo.launch()
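
For reference, the `mcp_server=True` flag in the app.py diff above is what makes a Gradio app double as an MCP server: with a Gradio build that includes MCP support (roughly 5.28+, e.g. `pip install "gradio[mcp]"`), the app's exposed functions are also served as MCP tools, and each function's type hints and docstring are used as the tool schema and description, which is why the docstrings in this diff are more than cosmetic. Below is a minimal sketch of that pattern; it is not part of this PR, and the function, labels, and values are illustrative only.

```python
# Minimal sketch (not part of this PR), assuming a Gradio version with MCP
# support (roughly 5.28+, e.g. `pip install "gradio[mcp]"`). All names below
# are illustrative.
import gradio as gr


def letter_count(word: str, letter: str) -> int:
    """
    Count how many times a letter appears in a word.

    Args:
        word (str): The word to search.
        letter (str): The letter to count.

    Returns:
        int: The number of occurrences.
    """
    return word.lower().count(letter.lower())


with gr.Blocks() as demo:
    word = gr.Textbox(label="Word")
    letter = gr.Textbox(label="Letter")
    count = gr.Number(label="Count")
    # The docstring and type hints above double as the MCP tool description and schema.
    gr.Button("Count").click(letter_count, inputs=[word, letter], outputs=count)

if __name__ == "__main__":
    demo.launch(mcp_server=True)
```

When the flag is set, the MCP endpoint is typically served under the app's `/gradio_api/mcp/` route (SSE transport); the exact path depends on the Gradio version, so check the launch logs or the Gradio MCP docs for the running Space.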