File size: 4,416 Bytes
0964e01
be383f9
6d250b3
be383f9
 
 
 
ad63d32
ef1530a
abdc26c
 
ef1530a
ff14b75
 
 
 
0b2312f
0c04a17
be383f9
 
 
7e70d08
6d250b3
 
be383f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0964e01
6cf650b
f6a2f50
0964e01
 
 
 
 
 
 
 
 
 
 
 
 
f6a2f50
 
be383f9
 
 
a68b44f
be383f9
 
 
 
 
 
 
 
 
 
 
f6a2f50
be383f9
 
 
 
 
 
 
 
 
 
 
466d1ab
be383f9
 
 
 
 
 
 
 
 
 
f6a2f50
 
be383f9
 
 
 
 
 
 
 
76ac4e2
0964e01
76ac4e2
 
 
 
 
 
 
be383f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e70d08
 
 
 
76ac4e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import spaces
import argparse
import gradio as gr
import os
import torch
import trimesh
import sys
from pathlib import Path

# Append the `cube` directory that sits next to this file to the module search
# path; presumably this is what lets the `cube3d` import further down resolve.
pathdir = Path(__file__).parent / 'cube'
sys.path.append(pathdir.as_posix())

# print(__file__)
# print(os.listdir())
# print(os.listdir('cube'))
# print(pathdir.as_posix())

from cube3d.inference.engine import EngineFast, Engine
from pathlib import Path
import uuid
import shutil
from huggingface_hub import snapshot_download


# Process-wide mutable state shared by the functions in this module
# (e.g. the "SAVE_DIR" and "config_path" keys, and the cached engine).
GLOBAL_STATE = {}


def gen_save_folder(max_size=200):
    """Create and return a fresh, uniquely named output folder.

    The folder is created under ``GLOBAL_STATE["SAVE_DIR"]``. Before creating
    it, the oldest existing sub-folders (ordered by ``st_ctime``) are deleted
    until there is room for the new one, so that at most ``max_size`` folders
    remain after the call.

    Args:
        max_size: Maximum number of sub-folders kept in the save directory,
            counting the one created by this call.

    Returns:
        The path of the newly created folder, as a string.

    Raises:
        KeyError: If ``GLOBAL_STATE["SAVE_DIR"]`` has not been set.
    """
    save_root = Path(GLOBAL_STATE["SAVE_DIR"])
    os.makedirs(save_root, exist_ok=True)

    # Oldest first, so the front of the list is what gets evicted.
    existing = sorted(
        (entry for entry in save_root.iterdir() if entry.is_dir()),
        key=lambda entry: entry.stat().st_ctime,
    )

    # Evict as many folders as needed to leave room for the new one. Removing
    # just one per call would never bring an over-full directory back under
    # the cap; slicing also copes with an empty directory and max_size <= 0.
    surplus = len(existing) - max_size + 1
    for stale_dir in existing[:max(surplus, 0)]:
        shutil.rmtree(stale_dir)
        print(f"Removed the oldest folder: {stale_dir}")

    new_folder = os.path.join(GLOBAL_STATE["SAVE_DIR"], str(uuid.uuid4()))
    os.makedirs(new_folder, exist_ok=True)
    print(f"Created new folder: {new_folder}")

    return new_folder

@spaces.GPU
def handle_text_prompt(input_prompt, variance=0):
    """Generate a mesh from a text prompt and return the exported GLB path.

    An ``EngineFast`` instance is built on first use and cached under
    ``GLOBAL_STATE["engine_fast"]``. The config path is read from
    ``GLOBAL_STATE["config_path"]``; the checkpoint paths are read from
    ``GLOBAL_STATE["gpt_ckpt_path"]`` and ``GLOBAL_STATE["shape_ckpt_path"]``
    when those keys are present, otherwise the default files under
    ``./model_weights`` are used.

    Args:
        input_prompt: Prompt text, passed to the engine's ``t2s`` call as a
            single-element list.
        variance: ``0`` passes ``top_p=None`` to the engine; any other value
            is mapped to ``top_p = (100 - variance) / 100``.

    Returns:
        Path of the ``output.glb`` file written into a folder obtained from
        ``gen_save_folder()``.
    """
    print(f"prompt: {input_prompt}, variance: {variance}")

    if "engine_fast" not in GLOBAL_STATE:
        # Honour configured checkpoint locations instead of always loading the
        # hard-coded files; the fallbacks are the original literal paths.
        gpt_ckpt_path = GLOBAL_STATE.get(
            "gpt_ckpt_path", "./model_weights/shape_gpt.safetensors"
        )
        shape_ckpt_path = GLOBAL_STATE.get(
            "shape_ckpt_path", "./model_weights/shape_tokenizer.safetensors"
        )
        GLOBAL_STATE["engine_fast"] = EngineFast(
            GLOBAL_STATE["config_path"],
            gpt_ckpt_path,
            shape_ckpt_path,
            device=torch.device("cuda"),
        )

    top_p = None if variance == 0 else (100 - variance) / 100.0
    mesh_v_f = GLOBAL_STATE["engine_fast"].t2s(
        [input_prompt], use_kv_cache=True, resolution_base=8.0, top_p=top_p
    )

    # Only one prompt was submitted, so only the first result is exported.
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
    save_folder = gen_save_folder()
    output_path = os.path.join(save_folder, "output.glb")
    trimesh.Trimesh(vertices=vertices, faces=faces).export(output_path)
    return output_path

def build_interface():
    """Assemble the Gradio Blocks UI and return it.

    Layout: a heading, then one row with a narrow column (prompt textbox,
    variance slider, submit button) and a wider column holding the 3D viewer.
    Clicking the button calls ``handle_text_prompt`` with the prompt and the
    variance and shows the returned file in the viewer.
    """
    title = "Cube 3D"
    soft_theme = gr.themes.Soft()
    with gr.Blocks(theme=soft_theme, title=title, fill_width=True) as demo_ui:
        gr.Markdown(
            f"""
            # {title}
            # Check out our [Github](https://github.com/Roblox/cube) to try it on your own machine!
            """
        )

        with gr.Row():
            # Left-hand side: user inputs.
            with gr.Column(scale=2):
                with gr.Group():
                    prompt_box = gr.Textbox(value=None, label="Prompt", lines=2)
                    variance_slider = gr.Slider(
                        minimum=0,
                        maximum=99,
                        step=1,
                        value=0,
                        label="Variance",
                    )
                with gr.Row():
                    run_button = gr.Button("Submit", variant="primary")
            # Right-hand side: read-only viewer for the generated model.
            with gr.Column(scale=3):
                viewer = gr.Model3D(label="Output", height="45em", interactive=False)

        # Registered while still inside the Blocks context, like the components.
        run_button.click(
            handle_text_prompt,
            inputs=[prompt_box, variance_slider],
            outputs=[viewer],
        )

    return demo_ui

def generate(args):
    """Store run-time settings in ``GLOBAL_STATE`` and launch the Gradio app.

    Args:
        args: Namespace-like object with ``config_path`` and ``save_dir``
            attributes. If it also carries ``gpt_ckpt_path`` and/or
            ``shape_ckpt_path``, those are recorded in ``GLOBAL_STATE`` under
            the same keys instead of being discarded, so that code building
            the engine can look them up.
    """
    GLOBAL_STATE["config_path"] = args.config_path
    GLOBAL_STATE["SAVE_DIR"] = args.save_dir

    # Optional attributes: callers that pass only config_path/save_dir keep
    # working, and nothing is stored for values that were not supplied.
    for key in ("gpt_ckpt_path", "shape_ckpt_path"):
        value = getattr(args, key, None)
        if value is not None:
            GLOBAL_STATE[key] = value

    os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)

    demo = build_interface()
    # One request at a time by default.
    demo.queue(default_concurrency_limit=1)
    demo.launch()

if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # registered in this order.
    cli_options = (
        (
            "--config_path",
            {
                "type": str,
                "help": "Path to the config file",
                "default": "cube/cube3d/configs/open_model.yaml",
            },
        ),
        (
            "--gpt_ckpt_path",
            {
                "type": str,
                "help": "Path to the gpt ckpt path",
                "default": "model_weights/shape_gpt.safetensors",
            },
        ),
        (
            "--shape_ckpt_path",
            {
                "type": str,
                "help": "Path to the shape ckpt path",
                "default": "model_weights/shape_tokenizer.safetensors",
            },
        ),
        (
            "--save_dir",
            {
                "type": str,
                "default": "gradio_save_dir",
            },
        ),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        arg_parser.add_argument(flag, **options)
    args = arg_parser.parse_args()

    # Download the "Roblox/cube3d-v0.1" repository into ./model_weights
    # before the app starts.
    snapshot_download(repo_id="Roblox/cube3d-v0.1", local_dir="./model_weights")

    generate(args)