Spaces:
Sleeping
Sleeping
import argparse | |
import os | |
import sys | |
from glob import glob | |
from typing import Any, Union | |
import numpy as np | |
import torch | |
import trimesh | |
from huggingface_hub import snapshot_download | |
from PIL import Image | |
sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
from triposg.pipelines.pipeline_triposg_scribble import TripoSGScribblePipeline | |
def run_triposg_scribble( | |
pipe: Any, | |
image_input: Union[str, Image.Image], | |
prompt: str, | |
seed: int, | |
num_inference_steps: int = 16, | |
scribble_confidence: float = 0.4, | |
prompt_confidence: float = 1.0 | |
) -> trimesh.Scene: | |
img_pil = Image.open(image_input).convert("RGB") | |
outputs = pipe( | |
image=img_pil, | |
prompt=prompt, | |
generator=torch.Generator(device=pipe.device).manual_seed(seed), | |
num_inference_steps=num_inference_steps, | |
guidance_scale=0, # this is a CFG-distilled model | |
attention_kwargs={"cross_attention_scale": prompt_confidence, "cross_attention_2_scale": scribble_confidence}, | |
use_flash_decoder=False, # there're some boundary problems when using flash decoder with this model | |
dense_octree_depth=9, hierarchical_octree_depth=9 # 256 resolution for faster inference | |
).samples[0] | |
mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1])) | |
return mesh | |
if __name__ == "__main__": | |
device = "cuda" | |
dtype = torch.float16 | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--image-input", type=str, required=True) | |
parser.add_argument("--prompt", type=str, required=True) | |
parser.add_argument("--output-path", type=str, default="./output.glb") | |
parser.add_argument("--seed", type=int, default=42) | |
parser.add_argument("--num-inference-steps", type=int, default=16) | |
# feel free to tune the scribble confidence, 0.3-0.5 often gives good results for hand-drawn sketches | |
parser.add_argument("--scribble-conf", type=float, default=0.4) | |
parser.add_argument("--prompt-conf", type=float, default=1.0) | |
args = parser.parse_args() | |
# download pretrained weights | |
triposg_scribble_weights_dir = "pretrained_weights/TripoSG-scribble" | |
snapshot_download(repo_id="VAST-AI/TripoSG-scribble", local_dir=triposg_scribble_weights_dir) | |
# init tripoSG pipeline | |
pipe: TripoSGScribblePipeline = TripoSGScribblePipeline.from_pretrained(triposg_scribble_weights_dir).to(device, dtype) | |
# run inference | |
run_triposg_scribble( | |
pipe, | |
image_input=args.image_input, | |
prompt=args.prompt, | |
seed=args.seed, | |
num_inference_steps=args.num_inference_steps, | |
scribble_confidence=args.scribble_conf, | |
prompt_confidence=args.prompt_conf, | |
).export(args.output_path) | |
print("Mesh saved to", args.output_path) | |