Remove multi-image functionality
Browse files
app.py
CHANGED
@@ -2,7 +2,6 @@ import os
|
|
2 |
import shlex
|
3 |
import shutil
|
4 |
import subprocess
|
5 |
-
from typing import Literal
|
6 |
|
7 |
os.environ["SPCONV_ALGO"] = "native"
|
8 |
|
@@ -121,65 +120,44 @@ def get_seed(randomize_seed: bool, seed: int) -> int:
|
|
121 |
@spaces.GPU
|
122 |
def image_to_3d(
|
123 |
image: Image.Image,
|
124 |
-
multiimages: list[tuple[Image.Image, str]],
|
125 |
-
is_multiimage: bool,
|
126 |
seed: int,
|
127 |
ss_guidance_strength: float,
|
128 |
ss_sampling_steps: int,
|
129 |
slat_guidance_strength: float,
|
130 |
slat_sampling_steps: int,
|
131 |
-
multiimage_algo: Literal["multidiffusion", "stochastic"],
|
132 |
req: gr.Request,
|
133 |
) -> tuple[dict, str]:
|
134 |
"""Convert an image to a 3D model.
|
135 |
|
136 |
Args:
|
137 |
image (Image.Image): The input image.
|
138 |
-
multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
|
139 |
-
is_multiimage (bool): Whether is in multi-image mode.
|
140 |
seed (int): The random seed.
|
141 |
ss_guidance_strength (float): The guidance strength for sparse structure generation.
|
142 |
ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
|
143 |
slat_guidance_strength (float): The guidance strength for structured latent generation.
|
144 |
slat_sampling_steps (int): The number of sampling steps for structured latent generation.
|
145 |
-
multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
|
146 |
|
147 |
Returns:
|
148 |
dict: The information of the generated 3D model.
|
149 |
str: The path to the video of the 3D model.
|
150 |
"""
|
151 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
outputs = pipeline.run_multi_image(
|
169 |
-
[image[0] for image in multiimages],
|
170 |
-
seed=seed,
|
171 |
-
formats=["gaussian", "mesh"],
|
172 |
-
preprocess_image=False,
|
173 |
-
sparse_structure_sampler_params={
|
174 |
-
"steps": ss_sampling_steps,
|
175 |
-
"cfg_strength": ss_guidance_strength,
|
176 |
-
},
|
177 |
-
slat_sampler_params={
|
178 |
-
"steps": slat_sampling_steps,
|
179 |
-
"cfg_strength": slat_guidance_strength,
|
180 |
-
},
|
181 |
-
mode=multiimage_algo,
|
182 |
-
)
|
183 |
video = render_utils.render_video(outputs["gaussian"][0], num_frames=120)["color"]
|
184 |
video_geo = render_utils.render_video(outputs["mesh"][0], num_frames=120)["normal"]
|
185 |
video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
|
@@ -248,19 +226,6 @@ def prepare_multi_example() -> list[Image.Image]:
|
|
248 |
return images
|
249 |
|
250 |
|
251 |
-
def split_image(image: Image.Image) -> list[Image.Image]:
|
252 |
-
"""Split an image into multiple views."""
|
253 |
-
image = np.array(image)
|
254 |
-
alpha = image[..., 3]
|
255 |
-
alpha = np.any(alpha > 0, axis=0)
|
256 |
-
start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
|
257 |
-
end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
|
258 |
-
images = []
|
259 |
-
for s, e in zip(start_pos, end_pos, strict=False):
|
260 |
-
images.append(Image.fromarray(image[:, s : e + 1]))
|
261 |
-
return [preprocess_image(image) for image in images]
|
262 |
-
|
263 |
-
|
264 |
with gr.Blocks(delete_cache=(600, 600)) as demo:
|
265 |
gr.Markdown("""
|
266 |
## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
|
@@ -272,51 +237,35 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
|
|
272 |
|
273 |
with gr.Row():
|
274 |
with gr.Column():
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
height=300,
|
283 |
-
)
|
284 |
-
with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
|
285 |
-
multiimage_prompt = gr.Gallery(
|
286 |
-
label="Image Prompt",
|
287 |
-
format="png",
|
288 |
-
type="pil",
|
289 |
-
height=300,
|
290 |
-
columns=3,
|
291 |
-
)
|
292 |
-
gr.Markdown("""
|
293 |
-
Input different views of the object in separate images.
|
294 |
-
|
295 |
-
*NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
|
296 |
-
""")
|
297 |
|
298 |
with gr.Accordion(label="Generation Settings", open=False):
|
299 |
-
seed = gr.Slider(
|
300 |
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
|
301 |
gr.Markdown("Stage 1: Sparse Structure Generation")
|
302 |
with gr.Row():
|
303 |
-
ss_guidance_strength = gr.Slider(
|
304 |
-
|
|
|
|
|
305 |
gr.Markdown("Stage 2: Structured Latent Generation")
|
306 |
with gr.Row():
|
307 |
-
slat_guidance_strength = gr.Slider(
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
label="Multi-image Algorithm",
|
312 |
-
value="stochastic",
|
313 |
-
)
|
314 |
|
315 |
generate_btn = gr.Button("Generate")
|
316 |
|
317 |
with gr.Accordion(label="GLB Extraction Settings", open=False):
|
318 |
-
mesh_simplify = gr.Slider(0.9, 0.98,
|
319 |
-
texture_size = gr.Slider(
|
320 |
|
321 |
with gr.Row():
|
322 |
extract_glb_btn = gr.Button("Extract GLB", interactive=False)
|
@@ -333,101 +282,72 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
|
|
333 |
download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
|
334 |
download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
|
335 |
|
336 |
-
is_multiimage = gr.State(False) # noqa: FBT003
|
337 |
output_buf = gr.State()
|
338 |
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
examples_per_page=64,
|
348 |
-
)
|
349 |
-
with gr.Row(visible=False) as multiimage_example:
|
350 |
-
examples_multi = gr.Examples(
|
351 |
-
examples=prepare_multi_example(),
|
352 |
-
inputs=[image_prompt],
|
353 |
-
fn=split_image,
|
354 |
-
outputs=[multiimage_prompt],
|
355 |
-
run_on_click=True,
|
356 |
-
examples_per_page=8,
|
357 |
-
)
|
358 |
|
359 |
# Handlers
|
360 |
demo.load(start_session)
|
361 |
demo.unload(end_session)
|
362 |
|
363 |
-
single_image_input_tab.select(
|
364 |
-
lambda: (False, gr.Row.update(visible=True), gr.Row.update(visible=False)),
|
365 |
-
outputs=[is_multiimage, single_image_example, multiimage_example],
|
366 |
-
)
|
367 |
-
multiimage_input_tab.select(
|
368 |
-
lambda: (True, gr.Row.update(visible=False), gr.Row.update(visible=True)),
|
369 |
-
outputs=[is_multiimage, single_image_example, multiimage_example],
|
370 |
-
)
|
371 |
-
|
372 |
image_prompt.upload(
|
373 |
-
preprocess_image,
|
374 |
-
inputs=
|
375 |
-
outputs=
|
376 |
-
)
|
377 |
-
multiimage_prompt.upload(
|
378 |
-
preprocess_images,
|
379 |
-
inputs=[multiimage_prompt],
|
380 |
-
outputs=[multiimage_prompt],
|
381 |
)
|
382 |
|
383 |
generate_btn.click(
|
384 |
-
get_seed,
|
385 |
inputs=[randomize_seed, seed],
|
386 |
-
outputs=
|
387 |
).then(
|
388 |
-
image_to_3d,
|
389 |
inputs=[
|
390 |
image_prompt,
|
391 |
-
multiimage_prompt,
|
392 |
-
is_multiimage,
|
393 |
seed,
|
394 |
ss_guidance_strength,
|
395 |
ss_sampling_steps,
|
396 |
slat_guidance_strength,
|
397 |
slat_sampling_steps,
|
398 |
-
multiimage_algo,
|
399 |
],
|
400 |
outputs=[output_buf, video_output],
|
401 |
).then(
|
402 |
-
lambda: (gr.Button(interactive=True), gr.Button(interactive=True)),
|
403 |
outputs=[extract_glb_btn, extract_gs_btn],
|
404 |
)
|
405 |
|
406 |
video_output.clear(
|
407 |
-
lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
|
408 |
outputs=[extract_glb_btn, extract_gs_btn],
|
409 |
)
|
410 |
|
411 |
extract_glb_btn.click(
|
412 |
-
extract_glb,
|
413 |
inputs=[output_buf, mesh_simplify, texture_size],
|
414 |
outputs=[model_output, download_glb],
|
415 |
).then(
|
416 |
-
lambda: gr.Button(interactive=True),
|
417 |
outputs=[download_glb],
|
418 |
)
|
419 |
|
420 |
extract_gs_btn.click(
|
421 |
-
extract_gaussian,
|
422 |
inputs=[output_buf],
|
423 |
outputs=[model_output, download_gs],
|
424 |
).then(
|
425 |
-
lambda: gr.Button(interactive=True),
|
426 |
outputs=[download_gs],
|
427 |
)
|
428 |
|
429 |
model_output.clear(
|
430 |
-
lambda: gr.Button(interactive=False),
|
431 |
outputs=[download_glb],
|
432 |
)
|
433 |
|
|
|
2 |
import shlex
|
3 |
import shutil
|
4 |
import subprocess
|
|
|
5 |
|
6 |
os.environ["SPCONV_ALGO"] = "native"
|
7 |
|
|
|
120 |
@spaces.GPU
|
121 |
def image_to_3d(
|
122 |
image: Image.Image,
|
|
|
|
|
123 |
seed: int,
|
124 |
ss_guidance_strength: float,
|
125 |
ss_sampling_steps: int,
|
126 |
slat_guidance_strength: float,
|
127 |
slat_sampling_steps: int,
|
|
|
128 |
req: gr.Request,
|
129 |
) -> tuple[dict, str]:
|
130 |
"""Convert an image to a 3D model.
|
131 |
|
132 |
Args:
|
133 |
image (Image.Image): The input image.
|
|
|
|
|
134 |
seed (int): The random seed.
|
135 |
ss_guidance_strength (float): The guidance strength for sparse structure generation.
|
136 |
ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
|
137 |
slat_guidance_strength (float): The guidance strength for structured latent generation.
|
138 |
slat_sampling_steps (int): The number of sampling steps for structured latent generation.
|
|
|
139 |
|
140 |
Returns:
|
141 |
dict: The information of the generated 3D model.
|
142 |
str: The path to the video of the 3D model.
|
143 |
"""
|
144 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
145 |
+
|
146 |
+
outputs = pipeline.run(
|
147 |
+
image,
|
148 |
+
seed=seed,
|
149 |
+
formats=["gaussian", "mesh"],
|
150 |
+
preprocess_image=False,
|
151 |
+
sparse_structure_sampler_params={
|
152 |
+
"steps": ss_sampling_steps,
|
153 |
+
"cfg_strength": ss_guidance_strength,
|
154 |
+
},
|
155 |
+
slat_sampler_params={
|
156 |
+
"steps": slat_sampling_steps,
|
157 |
+
"cfg_strength": slat_guidance_strength,
|
158 |
+
},
|
159 |
+
)
|
160 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
video = render_utils.render_video(outputs["gaussian"][0], num_frames=120)["color"]
|
162 |
video_geo = render_utils.render_video(outputs["mesh"][0], num_frames=120)["normal"]
|
163 |
video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
|
|
|
226 |
return images
|
227 |
|
228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
with gr.Blocks(delete_cache=(600, 600)) as demo:
|
230 |
gr.Markdown("""
|
231 |
## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
|
|
|
237 |
|
238 |
with gr.Row():
|
239 |
with gr.Column():
|
240 |
+
image_prompt = gr.Image(
|
241 |
+
label="Image Prompt",
|
242 |
+
format="png",
|
243 |
+
image_mode="RGBA",
|
244 |
+
type="pil",
|
245 |
+
height=300,
|
246 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
|
248 |
with gr.Accordion(label="Generation Settings", open=False):
|
249 |
+
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
|
250 |
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
|
251 |
gr.Markdown("Stage 1: Sparse Structure Generation")
|
252 |
with gr.Row():
|
253 |
+
ss_guidance_strength = gr.Slider(
|
254 |
+
label="Guidance Strength", minimum=0.0, maximum=10.0, step=0.1, value=7.5
|
255 |
+
)
|
256 |
+
ss_sampling_steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=50, step=1, value=12)
|
257 |
gr.Markdown("Stage 2: Structured Latent Generation")
|
258 |
with gr.Row():
|
259 |
+
slat_guidance_strength = gr.Slider(
|
260 |
+
label="Guidance Strength", minimum=0.0, maximum=10.0, step=0.1, value=3.0
|
261 |
+
)
|
262 |
+
slat_sampling_steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=50, step=1, value=12)
|
|
|
|
|
|
|
263 |
|
264 |
generate_btn = gr.Button("Generate")
|
265 |
|
266 |
with gr.Accordion(label="GLB Extraction Settings", open=False):
|
267 |
+
mesh_simplify = gr.Slider(label="Simplify", minimum=0.9, maximum=0.98, step=0.01, value=0.95)
|
268 |
+
texture_size = gr.Slider(label="Texture Size", minimum=512, maximum=2048, step=512, value=1024)
|
269 |
|
270 |
with gr.Row():
|
271 |
extract_glb_btn = gr.Button("Extract GLB", interactive=False)
|
|
|
282 |
download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
|
283 |
download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
|
284 |
|
|
|
285 |
output_buf = gr.State()
|
286 |
|
287 |
+
examples = gr.Examples(
|
288 |
+
examples=[f"assets/example_image/{image}" for image in os.listdir("assets/example_image")],
|
289 |
+
inputs=[image_prompt],
|
290 |
+
fn=preprocess_image,
|
291 |
+
outputs=[image_prompt],
|
292 |
+
run_on_click=True,
|
293 |
+
examples_per_page=64,
|
294 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
|
296 |
# Handlers
|
297 |
demo.load(start_session)
|
298 |
demo.unload(end_session)
|
299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
image_prompt.upload(
|
301 |
+
fn=preprocess_image,
|
302 |
+
inputs=image_prompt,
|
303 |
+
outputs=image_prompt,
|
|
|
|
|
|
|
|
|
|
|
304 |
)
|
305 |
|
306 |
generate_btn.click(
|
307 |
+
fn=get_seed,
|
308 |
inputs=[randomize_seed, seed],
|
309 |
+
outputs=seed,
|
310 |
).then(
|
311 |
+
fn=image_to_3d,
|
312 |
inputs=[
|
313 |
image_prompt,
|
|
|
|
|
314 |
seed,
|
315 |
ss_guidance_strength,
|
316 |
ss_sampling_steps,
|
317 |
slat_guidance_strength,
|
318 |
slat_sampling_steps,
|
|
|
319 |
],
|
320 |
outputs=[output_buf, video_output],
|
321 |
).then(
|
322 |
+
fn=lambda: (gr.Button(interactive=True), gr.Button(interactive=True)),
|
323 |
outputs=[extract_glb_btn, extract_gs_btn],
|
324 |
)
|
325 |
|
326 |
video_output.clear(
|
327 |
+
fn=lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
|
328 |
outputs=[extract_glb_btn, extract_gs_btn],
|
329 |
)
|
330 |
|
331 |
extract_glb_btn.click(
|
332 |
+
fn=extract_glb,
|
333 |
inputs=[output_buf, mesh_simplify, texture_size],
|
334 |
outputs=[model_output, download_glb],
|
335 |
).then(
|
336 |
+
fn=lambda: gr.Button(interactive=True),
|
337 |
outputs=[download_glb],
|
338 |
)
|
339 |
|
340 |
extract_gs_btn.click(
|
341 |
+
fn=extract_gaussian,
|
342 |
inputs=[output_buf],
|
343 |
outputs=[model_output, download_gs],
|
344 |
).then(
|
345 |
+
fn=lambda: gr.Button(interactive=True),
|
346 |
outputs=[download_gs],
|
347 |
)
|
348 |
|
349 |
model_output.clear(
|
350 |
+
fn=lambda: gr.Button(interactive=False),
|
351 |
outputs=[download_glb],
|
352 |
)
|
353 |
|