|
import spaces |
|
import gradio as gr |
|
import glob |
|
import hashlib |
|
from PIL import Image |
|
import os |
|
import shlex |
|
import subprocess |
|
|
|
# ---------------------------------------------------------------------------
# Import-time environment bootstrap: checkpoint download and pip installs.
# These run as module-level side effects every time the app starts.
# ---------------------------------------------------------------------------
os.makedirs("./ckpt", exist_ok=True)

# Segment Anything ViT-H checkpoint.  Skip the download when a non-empty file
# is already present, and download to a ".part" name first so that a failed or
# interrupted transfer never leaves a truncated file at the final path.
_sam_ckpt_path = "./ckpt/sam_vit_h_4b8939.pth"
if not (os.path.isfile(_sam_ckpt_path) and os.path.getsize(_sam_ckpt_path) > 0):
    _sam_ckpt_part = _sam_ckpt_path + ".part"
    _wget_status = subprocess.call([
        "wget", "-q",
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "-O", _sam_ckpt_part,
    ])
    if _wget_status == 0:
        os.replace(_sam_ckpt_part, _sam_ckpt_path)
    else:
        # Best-effort, like the original: report the failure but keep going.
        print(f"warning: SAM checkpoint download failed (wget exit status {_wget_status})")

# pip is pinned to 24.0 before the bundled nvdiffrast wheel is installed.
# NOTE(review): presumably newer pip releases reject this wheel's version
# string ("0.3.1.torch") -- confirm before changing the pin.
for _pip_cmd in (
    ["pip", "install", "pip==24.0"],
    [
        "pip", "install",
        "package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl",
        "--force-reinstall", "--no-deps",
    ],
):
    _pip_status = subprocess.run(_pip_cmd).returncode
    if _pip_status != 0:
        # Surface the failure instead of silently ignoring the exit status.
        print(f"warning: {' '.join(_pip_cmd)!r} exited with status {_pip_status}")
|
|
|
from infer_api import InferAPI |
|
|
|
# Per-stage settings handed to InferAPI below.  Only the canonicalization and
# SLRM stages point at a YAML config file; the multiview and refine stages
# receive empty dicts.
config_canocalize = dict(config_path='./configs/canonicalization-infer.yaml')
config_multiview = dict()
config_slrm = dict(config_path='./configs/mesh-slrm-infer.yaml')
config_refine = dict()
|
|
|
# Sample inputs offered in the two gr.Examples galleries.  glob.glob() returns
# matches in arbitrary (filesystem) order, so sort them to keep the example
# order stable across machines and restarts.
EXAMPLE_IMAGES = sorted(glob.glob("./input_cases/*"))
EXAMPLE_APOSE_IMAGES = sorted(glob.glob("./input_cases_apose/*"))
|
|
|
# Single shared inference object used by every stage callback below
# (genStage1 .. genStage4).  Built once at import time from the four
# per-stage config dicts, passed positionally in pipeline order.
infer_api = InferAPI(config_canocalize, config_multiview, config_slrm, config_refine)
|
|
|
# Markdown/HTML banner rendered at the top of the page via gr.Markdown.
# The paper title link previously had an empty href (a dead link); it now
# points at the same arXiv URL as the "Paper" link below it.
_HEADER_ = '''
<h2><b>[CVPR 2025] StdGEN 🤗 Gradio Demo</b></h2>
This is official demo for our CVPR 2025 paper <a href='https://arxiv.org/abs/2411.05738' target='_blank'>StdGEN: Semantic-Decomposed 3D Character Generation from Single Images</a>.

Code: <a href='https://github.com/hyz317/StdGEN' target='_blank'>GitHub</a>. Paper: <a href='https://arxiv.org/abs/2411.05738' target='_blank'>ArXiv</a>.

❗️❗️❗️**Important Notes:** This is only a **PREVIEW** version with **coarse precision geometry and texture** due to limited online resource. We skip some refinement process and perform only color back-projection to clothes and hair. Please refer to GitHub repo for complete version.
1. Refinement stage takes about ~2.5min, and the mesh result may possibly delayed due to the server load, please wait patiently.

2. You can upload any reference image (with or without background), A-pose images are also supported (white bkg required). If the image has an alpha channel (transparency), background segmentation will be automatically performed. Alternatively, you can pre-segment the background using other tools and upload the result directly.

3. Real person images generally work well, but note that normals may appear smoother than expected. You can try to use other monocular normal estimation models.

4. The base human model in the output is uncolored due to potential NSFW concerns. If you need colored results, please refer to the official GitHub repository for instructions.
'''
|
|
|
# Markdown footer rendered at the bottom of the page via gr.Markdown:
# star request, BibTeX citation and contact address.  Raw string, so the
# braces and any backslashes in the BibTeX entry are taken literally.
# NOTE(review): the trailing "[](...)" on the first line is a link with empty
# text -- looks like a badge image that went missing; confirm.
_CITE_ = r"""
If StdGEN is helpful, please help to ⭐ the <a href='https://github.com/hyz317/StdGEN' target='_blank'>GitHub Repo</a>. Thanks! [](https://github.com/hyz317/StdGEN)
---
📝 **Citation**
If you find our work useful for your research or applications, please cite using this bibtex:
```bibtex
@article{he2024stdgen,
title={StdGEN: Semantic-Decomposed 3D Character Generation from Single Images},
author={He, Yuze and Zhou, Yanning and Zhao, Wang and Wu, Zhongkai and Xiao, Kaiwen and Yang, Wei and Liu, Yong-Jin and Han, Xiao},
journal={arXiv preprint arXiv:2411.05738},
year={2024}
}
```
📧 **Contact**
If you have any questions, feel free to open a discussion or contact us at <b>hyz22@mails.tsinghua.edu.cn</b>.
"""
|
|
|
# Process-wide, in-memory result caches, one per pipeline stage.  They are
# keyed by the "<md5>_<seed>" string built in the stage callbacks and are
# never evicted.
cache_arbitrary = dict()                      # stage 1: hash -> A-pose PNG path
cache_multiview = [dict() for _ in range(3)]  # stage 2: only index 0 is used in this file
cache_slrm = dict()                           # stage 3: hash -> genStage3 result
cache_refine = dict()                         # stage 4: hash -> genStage4 result

# Directory where intermediate PNGs are written so cached entries can be
# reloaded from disk.
tmp_path = '/tmp'
|
|
|
|
|
def arbitrary_to_apose(image, seed):
    """Stage 1: convert an arbitrary reference image into an A-pose image.

    Results are cached per (image bytes, seed): the generated image is written
    to ``tmp_path`` as a PNG and its path is remembered in ``cache_arbitrary``.

    Args:
        image: Input image as a numpy array (what ``gr.Image(type="numpy")``
            delivers), or ``None`` when nothing has been uploaded.
        seed: Seed forwarded to ``infer_api.genStage1``; also part of the
            cache key.

    Returns:
        The A-pose image (freshly generated, or re-opened from the cached PNG).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Without this guard Image.fromarray(None) fails with an opaque
        # AttributeError; show the user an actionable message instead.
        raise gr.Error("Please upload a reference image first.")

    image = Image.fromarray(image)
    # MD5 is used purely as a cache key over the raw pixel bytes.
    image_hash = str(hashlib.md5(image.tobytes()).hexdigest()) + '_' + str(seed)

    cached_path = cache_arbitrary.get(image_hash)
    # Only trust the cache entry if its PNG still exists on disk; otherwise
    # fall through and regenerate.
    if cached_path is not None and os.path.exists(cached_path):
        print(f'loaded cached apose image: {image_hash}')
        return Image.open(cached_path)

    apose_img = infer_api.genStage1(image, seed)
    cached_path = f'{tmp_path}/{image_hash}.png'
    apose_img.save(cached_path)
    cache_arbitrary[image_hash] = cached_path
    print(f'cached apose image: {image_hash}')
    return apose_img
|
|
|
def apose_to_multiview(apose_img, seed):
    """Stage 2: generate multi-view images (and normals) from an A-pose image.

    The first-level images and normals returned by ``infer_api.genStage2`` are
    saved as PNGs under ``tmp_path`` and their paths are stored in
    ``cache_multiview[0]`` so later stages can reload them by hash.

    Args:
        apose_img: A-pose image as a numpy array, or ``None`` when the image
            component is empty.
        seed: Seed forwarded to ``infer_api.genStage2``; also part of the
            cache key.

    Returns:
        A ``(images, image_hash)`` pair: the list of multi-view images and the
        "<md5>_<seed>" key identifying this result in the caches.

    Raises:
        gr.Error: If no A-pose image was provided.
    """
    if apose_img is None:
        # Image.fromarray(None) would otherwise fail with an opaque error.
        raise gr.Error("Please provide an A-pose image first.")

    apose_img = Image.fromarray(apose_img)
    # MD5 is used purely as a cache key over the raw pixel bytes.
    image_hash = str(hashlib.md5(apose_img.tobytes()).hexdigest()) + '_' + str(seed)

    cached = cache_multiview[0].get(image_hash)
    # Reuse the cached entry only if every PNG it points at is still on disk.
    if cached is not None and all(
        os.path.exists(path) for path in cached["images"] + cached["normals"]
    ):
        print(f'loaded cached multiview images: {image_hash}')
        return [Image.open(img_path) for img_path in cached["images"]], image_hash

    results = infer_api.genStage2(apose_img, seed, num_levels=1)
    first_level = results[0]

    # Persist both image kinds, collecting the paths as they are written.
    saved_paths = {"images": [], "normals": []}
    for kind, paths in saved_paths.items():
        for idx, img in enumerate(first_level[kind]):
            path = f'{tmp_path}/{image_hash}_{kind}_{idx}.png'
            img.save(path)
            paths.append(path)

    cache_multiview[0][image_hash] = saved_paths
    print(f'cached multiview images: {image_hash}')
    return first_level["images"], image_hash
|
|
|
def multiview_to_mesh(images, image_hash):
    """Stage 3: reconstruct meshes from the multi-view images.

    The result of ``infer_api.genStage3`` is memoised in ``cache_slrm`` under
    ``image_hash``.

    Args:
        images: Multi-view images as delivered by the gallery component.
        image_hash: Cache key produced by the multi-view stage.

    Returns:
        A tuple holding every element of the stage-3 result followed by
        ``image_hash``.  (The click handler wires this to three part meshes,
        the whole mesh and a state value, so the stage-3 result presumably
        has four entries -- confirm against InferAPI.)
    """
    try:
        mesh_files = cache_slrm[image_hash]
    except KeyError:
        # Cache miss: run the reconstruction and remember its output.
        mesh_files = infer_api.genStage3(images)
        cache_slrm[image_hash] = mesh_files
        print(f'cached slrm files: {image_hash}')
    else:
        print(f'loaded cached slrm files: {image_hash}')
    return (*mesh_files, image_hash)
|
|
|
def refine_mesh(mesh1, mesh2, mesh3, seed, image_hash):
    """Stage 4: refine the three reconstructed part meshes.

    Re-runs ``infer_api.genStage2`` with ``num_levels=2`` on the first cached
    multi-view image, replaces the first-level result with the images and
    normals saved by the multi-view stage, and passes everything to
    ``infer_api.genStage4``.

    Args:
        mesh1, mesh2, mesh3: The three part meshes from the reconstruction
            stage.
        seed: Seed forwarded to ``infer_api.genStage2``.
        image_hash: Cache key produced by the multi-view stage.

    Returns:
        Whatever ``infer_api.genStage4`` returns (cached per hash and seed).

    Raises:
        gr.Error: If there is no cached multi-view result for ``image_hash``,
            i.e. the earlier stages have not been run.
    """
    stage2_paths = cache_multiview[0].get(image_hash)
    if stage2_paths is None:
        # Previously a bare KeyError; tell the user what to do instead.
        raise gr.Error(
            "Please generate multi-view images and reconstruct the mesh before refining."
        )

    # The refined output depends on the seed passed here as well as on the
    # stage-2 hash, so both must be part of the cache key; keying on the hash
    # alone returned a stale result after the seed was changed.
    refine_key = f'{image_hash}_{seed}'
    if refine_key in cache_refine:
        print(f'loaded cached refined mesh: {image_hash}')
        return cache_refine[refine_key]

    # Only open the image on a cache miss; it is not needed on a hit.
    apose_img = Image.open(stage2_paths["images"][0])
    results = infer_api.genStage2(apose_img, seed, num_levels=2)
    # Reuse the exact first-level images/normals that stage 3 was built from.
    results[0] = {
        "images": [Image.open(img_path) for img_path in stage2_paths["images"]],
        "normals": [Image.open(img_path) for img_path in stage2_paths["normals"]],
    }
    refined = infer_api.genStage4([mesh1, mesh2, mesh3], results)
    cache_refine[refine_key] = refined
    print(f'cached refined mesh: {image_hash}')
    return refined
|
|
|
# Page layout and event wiring.  The four pipeline stages are laid out left to
# right / top to bottom, and each button triggers the matching stage callback.
# NOTE(review): the nesting below is the conventional reading of this layout;
# confirm in particular whether refine_btn belongs inside or outside the Row
# that holds the reconstructed meshes.
with gr.Blocks(title="StdGEN: Semantically Decomposed 3D Character Generation from Single Images") as demo:
    gr.Markdown(_HEADER_)
    with gr.Row():
        # Column 1: reference image -> A-pose image.
        with gr.Column():
            gr.Markdown("## 1. Reference Image to A-pose Image")
            input_image = gr.Image(label="Input Reference Image", type="numpy", width=384, height=384)
            gr.Examples(
                examples=EXAMPLE_IMAGES,
                inputs=input_image,
                label="Click to use sample images",
            )
            seed_input = gr.Number(
                label="Seed",
                value=52,
                precision=0,
                interactive=True
            )
            pose_btn = gr.Button("Convert")
        # Column 2: A-pose image -> multi-view images.  The A-pose image can
        # come from stage 1 or be supplied directly (upload / examples).
        with gr.Column():
            gr.Markdown("## 2. Multi-view Generation")
            a_pose_image = gr.Image(label="A-pose Result", type="numpy", width=384, height=384)
            gr.Examples(
                examples=EXAMPLE_APOSE_IMAGES,
                inputs=a_pose_image,
                label="Click to use sample A-pose images",
            )
            seed_input2 = gr.Number(
                label="Seed",
                value=50,
                precision=0,
                interactive=True
            )
            # Holds the image hash returned by apose_to_multiview.
            state2 = gr.State(value="")
            view_btn = gr.Button("Generate Multi-view Images")

        # Column 3: multi-view images -> reconstructed meshes.
        with gr.Column():
            gr.Markdown("## 3. Semantic-aware Reconstruction")
            # NOTE(review): height is the *string* "None", not the None
            # singleton -- confirm this is intentional.
            multiview_gallery = gr.Gallery(
                label="Multi-view results",
                columns=2,
                interactive=False,
                height="None"
            )
            # Holds the image hash returned by multiview_to_mesh.
            state3 = gr.State(value="")
            mesh_btn = gr.Button("Reconstruct")

    with gr.Row():
        mesh_cols = [gr.Model3D(label=f"Mesh {i+1}", interactive=False, height=384) for i in range(3)]
        full_mesh = gr.Model3D(label="Whole Mesh", height=384)
    refine_btn = gr.Button("Refine")

    gr.Markdown("## 4. Mesh refinement")
    with gr.Row():
        refined_meshes = [gr.Model3D(label=f"refined mesh {i+1}", height=384) for i in range(3)]
        refined_full_mesh = gr.Model3D(label="refined whole mesh", height=384)

    gr.Markdown(_CITE_)


    # Stage 1: reference image + seed -> A-pose image.
    pose_btn.click(
        arbitrary_to_apose,
        inputs=[input_image, seed_input],
        outputs=a_pose_image
    )

    # Stage 2: A-pose image + seed -> gallery, plus the hash kept in state2.
    view_btn.click(
        apose_to_multiview,
        inputs=[a_pose_image, seed_input2],
        outputs=[multiview_gallery, state2]
    )

    # Stage 3: gallery images + hash -> three part meshes, whole mesh, and the
    # hash forwarded into state3.
    mesh_btn.click(
        multiview_to_mesh,
        inputs=[multiview_gallery, state2],
        outputs=[*mesh_cols, full_mesh, state3]
    )

    # Stage 4: part meshes + stage-2 seed + hash -> refined meshes.
    # NOTE(review): the first three outputs are wired in the order 3, 1, 2 --
    # presumably to match the order in which genStage4 returns its parts;
    # confirm against InferAPI.
    refine_btn.click(
        refine_mesh,
        inputs=[*mesh_cols, seed_input2, state3],
        outputs=[refined_meshes[2], refined_meshes[0], refined_meshes[1], refined_full_mesh]
    )
|
|
|
# Start the Gradio server only when run as a script (not when imported).
# ssr_mode=False is passed explicitly to demo.launch.
if __name__ == "__main__":
    demo.launch(ssr_mode=False)