Spaces:
Running
on
L40S
Running
on
L40S
File size: 4,416 Bytes
0964e01 be383f9 6d250b3 be383f9 ad63d32 ef1530a abdc26c ef1530a ff14b75 0b2312f 0c04a17 be383f9 7e70d08 6d250b3 be383f9 0964e01 6cf650b f6a2f50 0964e01 f6a2f50 be383f9 a68b44f be383f9 f6a2f50 be383f9 466d1ab be383f9 f6a2f50 be383f9 76ac4e2 0964e01 76ac4e2 be383f9 7e70d08 76ac4e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import spaces
import argparse
import gradio as gr
import os
import torch
import trimesh
import sys
from pathlib import Path
pathdir = Path(__file__).parent / 'cube'
sys.path.append(pathdir.as_posix())
# print(__file__)
# print(os.listdir())
# print(os.listdir('cube'))
# print(pathdir.as_posix())
from cube3d.inference.engine import EngineFast, Engine
from pathlib import Path
import uuid
import shutil
from huggingface_hub import snapshot_download
# Process-wide mutable state shared between the CLI entry point and the Gradio
# callbacks. Keys read/written in this file include:
#   "config_path" - model config path, set by generate()
#   "SAVE_DIR"    - root directory for generated outputs, set by generate()
#   "engine_fast" - EngineFast instance, created lazily by handle_text_prompt()
GLOBAL_STATE = {}
def gen_save_folder(max_size=200):
    """Create a fresh, uniquely named output folder under ``GLOBAL_STATE["SAVE_DIR"]``.

    The save directory is treated as a bounded cache. Before the new folder is
    created, the oldest sub-folders (ordered by ``st_ctime``) are deleted so
    that at most ``max_size`` sub-folders remain once the new one exists.

    Args:
        max_size: Maximum number of sub-folders to keep in the save directory,
            counting the one created by this call.

    Returns:
        str: Path of the newly created folder.
    """
    save_dir = GLOBAL_STATE["SAVE_DIR"]
    os.makedirs(save_dir, exist_ok=True)
    dirs = [f for f in Path(save_dir).iterdir() if f.is_dir()]
    # Evict as many of the oldest folders as needed to make room for the new
    # one. Removing only a single folder per call would leave a directory that
    # is already over the limit (e.g. after lowering max_size) permanently
    # over it.
    excess = len(dirs) - max_size + 1
    if excess > 0:
        for oldest_dir in sorted(dirs, key=lambda d: d.stat().st_ctime)[:excess]:
            shutil.rmtree(oldest_dir)
            print(f"Removed the oldest folder: {oldest_dir}")
    new_folder = os.path.join(save_dir, str(uuid.uuid4()))
    os.makedirs(new_folder, exist_ok=True)
    print(f"Created new folder: {new_folder}")
    return new_folder
@spaces.GPU
def handle_text_prompt(input_prompt, variance=0):
    """Generate a mesh from a text prompt and return the path of a ``.glb`` file.

    The ``EngineFast`` instance is built on first use and cached in
    ``GLOBAL_STATE["engine_fast"]`` so later requests reuse it.

    Args:
        input_prompt: Text description of the shape to generate.
        variance: Slider value (the UI limits it to 0..99). ``0`` passes
            ``top_p=None`` to the engine; otherwise
            ``top_p = (100 - variance) / 100``, i.e. 0.99 at 1 down to 0.01 at 99.

    Returns:
        str: Path of ``output.glb`` inside a freshly created save folder.

    Raises:
        gr.Error: If the prompt is empty or whitespace-only.
    """
    print(f"prompt: {input_prompt}, variance: {variance}")
    # The prompt Textbox starts out as None; reject empty/blank prompts with a
    # message surfaced in the UI instead of passing them to the engine.
    if not input_prompt or not input_prompt.strip():
        raise gr.Error("Please enter a prompt.")
    if "engine_fast" not in GLOBAL_STATE:
        config_path = GLOBAL_STATE["config_path"]
        # Honour --gpt_ckpt_path / --shape_ckpt_path when generate() stored
        # them; otherwise use the files snapshot_download places in ./model_weights.
        gpt_ckpt_path = (
            GLOBAL_STATE.get("gpt_ckpt_path")
            or "./model_weights/shape_gpt.safetensors"
        )
        shape_ckpt_path = (
            GLOBAL_STATE.get("shape_ckpt_path")
            or "./model_weights/shape_tokenizer.safetensors"
        )
        # Build the engine once and cache it so later requests skip model loading.
        GLOBAL_STATE["engine_fast"] = EngineFast(
            config_path,
            gpt_ckpt_path,
            shape_ckpt_path,
            device=torch.device("cuda"),
        )
    top_p = None if variance == 0 else (100 - variance) / 100.0
    mesh_v_f = GLOBAL_STATE["engine_fast"].t2s(
        [input_prompt], use_kv_cache=True, resolution_base=8.0, top_p=top_p
    )
    # Presumably one result per prompt; items 0 and 1 of that result are used
    # as vertices and faces (layout assumed from usage — confirm against t2s).
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
    save_folder = gen_save_folder()
    output_path = os.path.join(save_folder, "output.glb")
    trimesh.Trimesh(vertices=vertices, faces=faces).export(output_path)
    return output_path


def build_interface():
    """Build and return the Gradio Blocks UI for text-to-3D generation.

    The left column holds the prompt textbox, a "Variance" slider and a
    Submit button; the wider right column holds a read-only 3D viewer. Clicking
    Submit calls ``handle_text_prompt(prompt, variance)`` and displays the
    returned file in the viewer.
    """
    title = "Cube 3D"
    with gr.Blocks(theme=gr.themes.Soft(), title=title, fill_width=True) as interface:
        gr.Markdown(
            f"""
# {title}
# Check out our [Github](https://github.com/Roblox/cube) to try it on your own machine!
"""
        )
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Group():
                    input_text_box = gr.Textbox(
                        value=None,
                        label="Prompt",
                        lines=2,
                    )
                    variance = gr.Slider(minimum=0, maximum=99, step=1, value=0, label="Variance")
                with gr.Row():
                    submit_button = gr.Button("Submit", variant="primary")
            with gr.Column(scale=3):
                model3d = gr.Model3D(
                    label="Output", height="45em", interactive=False
                )
        submit_button.click(
            handle_text_prompt,
            inputs=[
                input_text_box,
                variance,
            ],
            outputs=[
                model3d,
            ],
        )
    return interface


def generate(args):
    """Store the CLI settings in ``GLOBAL_STATE`` and launch the Gradio app.

    Args:
        args: Parsed CLI namespace. ``config_path`` and ``save_dir`` are
            required. ``gpt_ckpt_path`` / ``shape_ckpt_path`` are stored when
            present so ``handle_text_prompt`` loads those checkpoints.
    """
    GLOBAL_STATE["config_path"] = args.config_path
    GLOBAL_STATE["SAVE_DIR"] = args.save_dir
    # Without this, the --*_ckpt_path flags were parsed but silently ignored.
    # getattr keeps callers that pass a namespace without these fields working.
    for key in ("gpt_ckpt_path", "shape_ckpt_path"):
        value = getattr(args, key, None)
        if value is not None:
            GLOBAL_STATE[key] = value
    os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)
    demo = build_interface()
    # Queue requests with a default concurrency limit of 1 per event.
    demo.queue(default_concurrency_limit=1)
    demo.launch()
if __name__=="__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_path",
type=str,
help="Path to the config file",
default="cube/cube3d/configs/open_model.yaml",
)
parser.add_argument(
"--gpt_ckpt_path",
type=str,
help="Path to the gpt ckpt path",
default="model_weights/shape_gpt.safetensors",
)
parser.add_argument(
"--shape_ckpt_path",
type=str,
help="Path to the shape ckpt path",
default="model_weights/shape_tokenizer.safetensors",
)
parser.add_argument(
"--save_dir",
type=str,
default="gradio_save_dir",
)
args = parser.parse_args()
snapshot_download(
repo_id="Roblox/cube3d-v0.1",
local_dir="./model_weights"
)
generate(args)
|