Spaces:
Running
on
L40S
Running
on
L40S
cubev0.5 (#5)
Browse files- updates for cube v0.5 (ce16420814099de27a21e91450b755f4d804d5b8)
- bbox sliders always present (3ee2e412ae88cade005c7324b456a2b0ebd331ac)
- adding model v0.5 yaml (0b489dd3391bd560f0646aa0ab9fb14c5d3f9f09)
- removing shared url on launch (978db53a12658ece2e2b58d96944a3cf394f7a5f)
Co-authored-by: Akash Garg <captaincobb@users.noreply.huggingface.co>
- Dockerfile +1 -1
- app.py +68 -10
- cube/README.md +14 -9
- cube/cube3d/colab_cube3d.ipynb +1 -1
- cube/cube3d/configs/open_model_v0.5.yaml +33 -0
- cube/cube3d/generate.py +18 -2
- cube/cube3d/inference/engine.py +86 -17
- cube/cube3d/inference/utils.py +27 -4
- cube/cube3d/mesh_utils/postprocessing.py +3 -1
- cube/cube3d/model/gpt/dual_stream_roformer.py +5 -0
- cube/cube3d/model/transformers/cache.py +30 -3
- cube/cube3d/model/transformers/dual_stream_attention.py +1 -2
- cube/cube3d/model/transformers/roformer.py +1 -2
- cube/cube3d/vq_vae_encode_decode.py +2 -2
- requirements.txt +2 -1
Dockerfile
CHANGED
@@ -30,6 +30,6 @@ RUN git clone https://github.com/Roblox/cube.git
|
|
30 |
|
31 |
WORKDIR /home/user/app/cube
|
32 |
RUN pip install .[meshlab]
|
33 |
-
RUN huggingface-cli download Roblox/cube3d-v0.
|
34 |
|
35 |
WORKDIR /home/user/app
|
|
|
30 |
|
31 |
WORKDIR /home/user/app/cube
|
32 |
RUN pip install .[meshlab]
|
33 |
+
RUN huggingface-cli download Roblox/cube3d-v0.5 --local-dir ./model_weights
|
34 |
|
35 |
WORKDIR /home/user/app
|
app.py
CHANGED
@@ -6,6 +6,7 @@ import torch
|
|
6 |
import trimesh
|
7 |
import sys
|
8 |
from pathlib import Path
|
|
|
9 |
|
10 |
pathdir = Path(__file__).parent / 'cube'
|
11 |
sys.path.append(pathdir.as_posix())
|
@@ -16,11 +17,19 @@ sys.path.append(pathdir.as_posix())
|
|
16 |
# print(pathdir.as_posix())
|
17 |
|
18 |
from cube3d.inference.engine import EngineFast, Engine
|
|
|
19 |
from pathlib import Path
|
20 |
import uuid
|
21 |
import shutil
|
22 |
from huggingface_hub import snapshot_download
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
GLOBAL_STATE = {}
|
26 |
|
@@ -41,8 +50,8 @@ def gen_save_folder(max_size=200):
|
|
41 |
return new_folder
|
42 |
|
43 |
@spaces.GPU
|
44 |
-
def handle_text_prompt(input_prompt,
|
45 |
-
print(f"prompt: {input_prompt},
|
46 |
|
47 |
if "engine_fast" not in GLOBAL_STATE:
|
48 |
config_path = GLOBAL_STATE["config_path"]
|
@@ -56,13 +65,38 @@ def handle_text_prompt(input_prompt, variance = 0):
|
|
56 |
)
|
57 |
GLOBAL_STATE["engine_fast"] = engine_fast
|
58 |
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
# save output
|
62 |
vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
save_folder = gen_save_folder()
|
64 |
output_path = os.path.join(save_folder, "output.glb")
|
65 |
-
trimesh.Trimesh(vertices=vertices, faces=faces).export(output_path)
|
|
|
66 |
return output_path
|
67 |
|
68 |
def build_interface():
|
@@ -85,7 +119,28 @@ def build_interface():
|
|
85 |
label="Prompt",
|
86 |
lines=2,
|
87 |
)
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
with gr.Row():
|
90 |
submit_button = gr.Button("Submit", variant="primary")
|
91 |
with gr.Column(scale=3):
|
@@ -97,7 +152,11 @@ def build_interface():
|
|
97 |
handle_text_prompt,
|
98 |
inputs=[
|
99 |
input_text_box,
|
100 |
-
|
|
|
|
|
|
|
|
|
101 |
],
|
102 |
outputs=[
|
103 |
model3d
|
@@ -105,7 +164,6 @@ def build_interface():
|
|
105 |
)
|
106 |
|
107 |
return interface
|
108 |
-
|
109 |
def generate(args):
|
110 |
GLOBAL_STATE["config_path"] = args.config_path
|
111 |
GLOBAL_STATE["SAVE_DIR"] = args.save_dir
|
@@ -122,7 +180,7 @@ if __name__=="__main__":
|
|
122 |
"--config_path",
|
123 |
type=str,
|
124 |
help="Path to the config file",
|
125 |
-
default="cube/cube3d/configs/
|
126 |
)
|
127 |
parser.add_argument(
|
128 |
"--gpt_ckpt_path",
|
@@ -144,7 +202,7 @@ if __name__=="__main__":
|
|
144 |
|
145 |
args = parser.parse_args()
|
146 |
snapshot_download(
|
147 |
-
repo_id="Roblox/cube3d-v0.
|
148 |
local_dir="./model_weights"
|
149 |
)
|
150 |
generate(args)
|
|
|
6 |
import trimesh
|
7 |
import sys
|
8 |
from pathlib import Path
|
9 |
+
import numpy as np
|
10 |
|
11 |
pathdir = Path(__file__).parent / 'cube'
|
12 |
sys.path.append(pathdir.as_posix())
|
|
|
17 |
# print(pathdir.as_posix())
|
18 |
|
19 |
from cube3d.inference.engine import EngineFast, Engine
|
20 |
+
from cube3d.inference.utils import normalize_bbox
|
21 |
from pathlib import Path
|
22 |
import uuid
|
23 |
import shutil
|
24 |
from huggingface_hub import snapshot_download
|
25 |
|
26 |
+
from cube3d.mesh_utils.postprocessing import (
|
27 |
+
PYMESHLAB_AVAILABLE,
|
28 |
+
create_pymeshset,
|
29 |
+
postprocess_mesh,
|
30 |
+
save_mesh,
|
31 |
+
)
|
32 |
+
|
33 |
|
34 |
GLOBAL_STATE = {}
|
35 |
|
|
|
50 |
return new_folder
|
51 |
|
52 |
@spaces.GPU
|
53 |
+
def handle_text_prompt(input_prompt, use_bbox = True, bbox_x=1.0, bbox_y=1.0, bbox_z=1.0, hi_res=False):
|
54 |
+
print(f"prompt: {input_prompt}, use_bbox: {use_bbox}, bbox_x: {bbox_x}, bbox_y: {bbox_y}, bbox_z: {bbox_z}, hi_res: {hi_res}")
|
55 |
|
56 |
if "engine_fast" not in GLOBAL_STATE:
|
57 |
config_path = GLOBAL_STATE["config_path"]
|
|
|
65 |
)
|
66 |
GLOBAL_STATE["engine_fast"] = engine_fast
|
67 |
|
68 |
+
# Determine bounding box size based on option
|
69 |
+
bbox_size = None
|
70 |
+
if use_bbox:
|
71 |
+
bbox_size = [bbox_x, bbox_y, bbox_z]
|
72 |
+
# For "No Bounding Box", bbox_size remains None
|
73 |
+
|
74 |
+
normalized_bbox = normalize_bbox(bbox_size) if bbox_size is not None else None
|
75 |
+
|
76 |
+
resolution_base = 9.0 if hi_res else 8.0
|
77 |
+
mesh_v_f = GLOBAL_STATE["engine_fast"].t2s([input_prompt], use_kv_cache=True, resolution_base=resolution_base, bounding_box_xyz=normalized_bbox)
|
78 |
# save output
|
79 |
vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
|
80 |
+
|
81 |
+
ms = create_pymeshset(vertices, faces)
|
82 |
+
target_face_num = max(10000, int(faces.shape[0] * 0.1))
|
83 |
+
print(f"Postprocessing mesh to {target_face_num} faces")
|
84 |
+
postprocess_mesh(ms, target_face_num)
|
85 |
+
mesh = ms.current_mesh()
|
86 |
+
vertices = mesh.vertex_matrix()
|
87 |
+
faces = mesh.face_matrix()
|
88 |
+
|
89 |
+
min_extents = np.min(mesh.vertex_matrix(), axis = 0)
|
90 |
+
max_extents = np.max(mesh.vertex_matrix(), axis = 0)
|
91 |
+
|
92 |
+
mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
|
93 |
+
scene = trimesh.scene.Scene()
|
94 |
+
scene.add_geometry(mesh)
|
95 |
+
|
96 |
save_folder = gen_save_folder()
|
97 |
output_path = os.path.join(save_folder, "output.glb")
|
98 |
+
# trimesh.Trimesh(vertices=vertices, faces=faces).export(output_path)
|
99 |
+
scene.export(output_path)
|
100 |
return output_path
|
101 |
|
102 |
def build_interface():
|
|
|
119 |
label="Prompt",
|
120 |
lines=2,
|
121 |
)
|
122 |
+
|
123 |
+
use_bbox = gr.Checkbox(label="Use Bbox", value=False)
|
124 |
+
|
125 |
+
with gr.Group() as bbox_group:
|
126 |
+
bbox_x = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Length", interactive=False)
|
127 |
+
bbox_y = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Height", interactive=False)
|
128 |
+
bbox_z = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Depth", interactive=False)
|
129 |
+
|
130 |
+
# Enable/disable bbox sliders based on use_bbox checkbox
|
131 |
+
def toggle_bbox_interactivity(use_bbox):
|
132 |
+
return (
|
133 |
+
gr.Slider(interactive=use_bbox),
|
134 |
+
gr.Slider(interactive=use_bbox),
|
135 |
+
gr.Slider(interactive=use_bbox)
|
136 |
+
)
|
137 |
+
use_bbox.change(
|
138 |
+
toggle_bbox_interactivity,
|
139 |
+
inputs=[use_bbox],
|
140 |
+
outputs=[bbox_x, bbox_y, bbox_z]
|
141 |
+
)
|
142 |
+
|
143 |
+
hi_res = gr.Checkbox(label="Hi-Res", value=False)
|
144 |
with gr.Row():
|
145 |
submit_button = gr.Button("Submit", variant="primary")
|
146 |
with gr.Column(scale=3):
|
|
|
152 |
handle_text_prompt,
|
153 |
inputs=[
|
154 |
input_text_box,
|
155 |
+
use_bbox,
|
156 |
+
bbox_x,
|
157 |
+
bbox_y,
|
158 |
+
bbox_z,
|
159 |
+
hi_res
|
160 |
],
|
161 |
outputs=[
|
162 |
model3d
|
|
|
164 |
)
|
165 |
|
166 |
return interface
|
|
|
167 |
def generate(args):
|
168 |
GLOBAL_STATE["config_path"] = args.config_path
|
169 |
GLOBAL_STATE["SAVE_DIR"] = args.save_dir
|
|
|
180 |
"--config_path",
|
181 |
type=str,
|
182 |
help="Path to the config file",
|
183 |
+
default="cube/cube3d/configs/open_model_v0.5.yaml",
|
184 |
)
|
185 |
parser.add_argument(
|
186 |
"--gpt_ckpt_path",
|
|
|
202 |
|
203 |
args = parser.parse_args()
|
204 |
snapshot_download(
|
205 |
+
repo_id="Roblox/cube3d-v0.5",
|
206 |
local_dir="./model_weights"
|
207 |
)
|
208 |
generate(args)
|
cube/README.md
CHANGED
@@ -6,9 +6,10 @@
|
|
6 |
|
7 |
<div align="center">
|
8 |
<a href=https://corp.roblox.com/newsroom/2025/03/introducing-roblox-cube target="_blank"><img src=https://img.shields.io/badge/Roblox-Blog-000000.svg?logo=Roblox height=22px></a>
|
9 |
-
<a href=https://huggingface.co/Roblox/cube3d-0.1 target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-d96902.svg height=22px></a>
|
10 |
<a href=https://arxiv.org/abs/2503.15475 target="_blank"><img src=https://img.shields.io/badge/ArXiv-Report-b5212f.svg?logo=arxiv height=22px></a>
|
11 |
-
<a href=https://
|
|
|
|
|
12 |
</div>
|
13 |
|
14 |
|
@@ -27,7 +28,10 @@ towards this vision, we hope to engage others in the research community to addre
|
|
27 |
|
28 |
Cube 3D is our first step towards 3D intelligence, which involves a shape tokenizer and a text-to-shape generation model. We are unlocking the power of generating 3D assets and enhancing creativity for all artists. Our latest version of Cube 3D is now accessible to individuals, creators, researchers and businesses of all sizes so that they can experiment, innovate and scale their ideas responsibly. This release includes model weights and starting code for using our text-to-shape model to create 3D assets.
|
29 |
|
30 |
-
### Try it out on
|
|
|
|
|
|
|
31 |
|
32 |
### Install Requirements
|
33 |
|
@@ -41,6 +45,8 @@ pip install -e .[meshlab]
|
|
41 |
|
42 |
> **CUDA**: If you are using a Windows machine, you may need to install the [CUDA](https://developer.nvidia.com/cuda-downloads) toolkit as well as `torch` with cuda support via `pip install torch --index-url https://download.pytorch.org/whl/cu124 --force-reinstall`
|
43 |
|
|
|
|
|
44 |
> **Note**: `[meshlab]` is an optional dependency and can be removed by simply running `pip install -e .` for better compatibility but mesh simplification will be disabled.
|
45 |
|
46 |
### Download Models from Huggingface 🤗
|
@@ -75,7 +81,7 @@ and save it as `turntable.gif` in the specified `output` directory.
|
|
75 |
|
76 |
We provide several example output objects and their corresponding text prompts in the `examples` folder.
|
77 |
|
78 |
-
> **Note**: You must have Blender installed and available in your system's PATH to render the turntable GIF. You can download it from [Blender's official website](https://www.blender.org/). Ensure that the Blender executable is accessible from the command line.
|
79 |
|
80 |
> **Note**: If shape decoding is slow, you can try to specify a lower resolution using the `--resolution-base` flag. A lower resolution will create a coarser and lower quality output mesh but faster decoding. Values between 4.0 and 9.0 are recommended.
|
81 |
|
@@ -118,16 +124,15 @@ engine_fast = EngineFast( # only supported on CUDA devices, replace with Engine
|
|
118 |
config_path,
|
119 |
gpt_ckpt_path,
|
120 |
shape_ckpt_path,
|
121 |
-
device=torch.device("cuda"),
|
122 |
)
|
123 |
|
124 |
# inference
|
125 |
input_prompt = "A pair of noise-canceling headphones"
|
126 |
# NOTE: Reduce `resolution_base` for faster inference and lower VRAM usage
|
127 |
-
# The `
|
128 |
-
#
|
129 |
-
|
130 |
-
mesh_v_f = engine_fast.t2s([input_prompt], use_kv_cache=True, resolution_base=8.0, top_k=5)
|
131 |
|
132 |
# save output
|
133 |
vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
|
|
|
6 |
|
7 |
<div align="center">
|
8 |
<a href=https://corp.roblox.com/newsroom/2025/03/introducing-roblox-cube target="_blank"><img src=https://img.shields.io/badge/Roblox-Blog-000000.svg?logo=Roblox height=22px></a>
|
|
|
9 |
<a href=https://arxiv.org/abs/2503.15475 target="_blank"><img src=https://img.shields.io/badge/ArXiv-Report-b5212f.svg?logo=arxiv height=22px></a>
|
10 |
+
<a href=https://huggingface.co/Roblox/cube3d-0.1 target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-d96902.svg height=22px></a>
|
11 |
+
<a href=https://huggingface.co/spaces/Roblox/cube3d-interactive target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Demo-blue.svg height=22px></a>
|
12 |
+
<a href=https://colab.research.google.com/drive/1ZvTj49pjDCD_crX5WPZNTAoTTzL6-E5t target="_blank"><img src=https://img.shields.io/badge/Colab-Demo-blue.svg?logo=googlecolab height=22px></a>
|
13 |
</div>
|
14 |
|
15 |
|
|
|
28 |
|
29 |
Cube 3D is our first step towards 3D intelligence, which involves a shape tokenizer and a text-to-shape generation model. We are unlocking the power of generating 3D assets and enhancing creativity for all artists. Our latest version of Cube 3D is now accessible to individuals, creators, researchers and businesses of all sizes so that they can experiment, innovate and scale their ideas responsibly. This release includes model weights and starting code for using our text-to-shape model to create 3D assets.
|
30 |
|
31 |
+
### Try it out on
|
32 |
+
|
33 |
+
- [Google Colab](https://colab.research.google.com/drive/1ZvTj49pjDCD_crX5WPZNTAoTTzL6-E5t)
|
34 |
+
- [Hugging Face Interactive Demo](https://huggingface.co/spaces/Roblox/cube3d-interactive)
|
35 |
|
36 |
### Install Requirements
|
37 |
|
|
|
45 |
|
46 |
> **CUDA**: If you are using a Windows machine, you may need to install the [CUDA](https://developer.nvidia.com/cuda-downloads) toolkit as well as `torch` with cuda support via `pip install torch --index-url https://download.pytorch.org/whl/cu124 --force-reinstall`
|
47 |
|
48 |
+
> **MacOS**: Systems with Apple Silicon or AMD GPUs can leverage the Metal Performance Shaders (MPS) backend for PyTorch.
|
49 |
+
|
50 |
> **Note**: `[meshlab]` is an optional dependency and can be removed by simply running `pip install -e .` for better compatibility but mesh simplification will be disabled.
|
51 |
|
52 |
### Download Models from Huggingface 🤗
|
|
|
81 |
|
82 |
We provide several example output objects and their corresponding text prompts in the `examples` folder.
|
83 |
|
84 |
+
> **Note**: You must have Blender (version >= 4.3) installed and available in your system's PATH to render the turntable GIF. You can download it from [Blender's official website](https://www.blender.org/). Ensure that the Blender executable is accessible from the command line.
|
85 |
|
86 |
> **Note**: If shape decoding is slow, you can try to specify a lower resolution using the `--resolution-base` flag. A lower resolution will create a coarser and lower quality output mesh but faster decoding. Values between 4.0 and 9.0 are recommended.
|
87 |
|
|
|
124 |
config_path,
|
125 |
gpt_ckpt_path,
|
126 |
shape_ckpt_path,
|
127 |
+
device=torch.device("cuda"), # Replace with "mps" on Metal-compatible devices
|
128 |
)
|
129 |
|
130 |
# inference
|
131 |
input_prompt = "A pair of noise-canceling headphones"
|
132 |
# NOTE: Reduce `resolution_base` for faster inference and lower VRAM usage
|
133 |
+
# The `top_p` parameter controls randomness between inferences:
|
134 |
+
# Float < 1: Keep smallest set of tokens with cumulative probability ≥ top_p. Default None: deterministic generation.
|
135 |
+
mesh_v_f = engine_fast.t2s([input_prompt], use_kv_cache=True, resolution_base=8.0, top_p=0.9)
|
|
|
136 |
|
137 |
# save output
|
138 |
vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
|
cube/cube3d/colab_cube3d.ipynb
CHANGED
@@ -4345,7 +4345,7 @@
|
|
4345 |
"cell_type": "code",
|
4346 |
"source": [
|
4347 |
"input_prompt = \"vintage couch\"\n",
|
4348 |
-
"# Use a lower resolution_base to
|
4349 |
"mesh_v_f = engine.t2s([input_prompt], use_kv_cache=True, resolution_base=5.0)"
|
4350 |
],
|
4351 |
"metadata": {
|
|
|
4345 |
"cell_type": "code",
|
4346 |
"source": [
|
4347 |
"input_prompt = \"vintage couch\"\n",
|
4348 |
+
"# Use a lower resolution_base to accommodate limited GPU VRAM on Colab notebooks\n",
|
4349 |
"mesh_v_f = engine.t2s([input_prompt], use_kv_cache=True, resolution_base=5.0)"
|
4350 |
],
|
4351 |
"metadata": {
|
cube/cube3d/configs/open_model_v0.5.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gpt_model:
|
2 |
+
n_layer: 23
|
3 |
+
n_single_layer: 1
|
4 |
+
rope_theta: 10000
|
5 |
+
n_head: 12
|
6 |
+
n_embd: 1536
|
7 |
+
bias: true
|
8 |
+
eps: 1.e-6
|
9 |
+
shape_model_vocab_size: 16384
|
10 |
+
text_model_embed_dim: 768
|
11 |
+
use_pooled_text_embed: False
|
12 |
+
shape_model_embed_dim: 32
|
13 |
+
encoder_with_cls_token: true
|
14 |
+
use_bbox: true
|
15 |
+
|
16 |
+
shape_model:
|
17 |
+
encoder_with_cls_token: true
|
18 |
+
num_encoder_latents: 1024
|
19 |
+
num_decoder_latents: 0
|
20 |
+
embed_dim: 32
|
21 |
+
width: 768
|
22 |
+
num_heads: 12
|
23 |
+
out_dim: 1
|
24 |
+
eps: 1.e-6
|
25 |
+
num_freqs: 128
|
26 |
+
point_feats: 3
|
27 |
+
embed_point_feats: false
|
28 |
+
num_encoder_layers: 13
|
29 |
+
encoder_cross_attention_levels: [0, 2, 4, 8]
|
30 |
+
num_decoder_layers: 24
|
31 |
+
num_codes: 16384
|
32 |
+
|
33 |
+
text_model_pretrained_model_name_or_path: "openai/clip-vit-large-patch14"
|
cube/cube3d/generate.py
CHANGED
@@ -5,6 +5,7 @@ import torch
|
|
5 |
import trimesh
|
6 |
|
7 |
from cube3d.inference.engine import Engine, EngineFast
|
|
|
8 |
from cube3d.mesh_utils.postprocessing import (
|
9 |
PYMESHLAB_AVAILABLE,
|
10 |
create_pymeshset,
|
@@ -13,6 +14,7 @@ from cube3d.mesh_utils.postprocessing import (
|
|
13 |
)
|
14 |
from cube3d.renderer import renderer
|
15 |
|
|
|
16 |
def generate_mesh(
|
17 |
engine,
|
18 |
prompt,
|
@@ -21,12 +23,14 @@ def generate_mesh(
|
|
21 |
resolution_base=8.0,
|
22 |
disable_postprocess=False,
|
23 |
top_p=None,
|
|
|
24 |
):
|
25 |
mesh_v_f = engine.t2s(
|
26 |
[prompt],
|
27 |
use_kv_cache=True,
|
28 |
resolution_base=resolution_base,
|
29 |
top_p=top_p,
|
|
|
30 |
)
|
31 |
vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
|
32 |
obj_path = os.path.join(output_dir, f"{output_name}.obj")
|
@@ -92,6 +96,14 @@ if __name__ == "__main__":
|
|
92 |
default=None,
|
93 |
help="Float < 1: Keep smallest set of tokens with cumulative probability ≥ top_p. Default None: deterministic generation.",
|
94 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
parser.add_argument(
|
96 |
"--render-gif",
|
97 |
help="Render a turntable gif of the mesh",
|
@@ -112,7 +124,7 @@ if __name__ == "__main__":
|
|
112 |
)
|
113 |
args = parser.parse_args()
|
114 |
os.makedirs(args.output_dir, exist_ok=True)
|
115 |
-
device =
|
116 |
print(f"Using device: {device}")
|
117 |
# Initialize engine based on fast_inference flag
|
118 |
if args.fast_inference:
|
@@ -127,7 +139,10 @@ if __name__ == "__main__":
|
|
127 |
engine = Engine(
|
128 |
args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device
|
129 |
)
|
130 |
-
|
|
|
|
|
|
|
131 |
# Generate meshes based on input source
|
132 |
obj_path = generate_mesh(
|
133 |
engine,
|
@@ -137,6 +152,7 @@ if __name__ == "__main__":
|
|
137 |
args.resolution_base,
|
138 |
args.disable_postprocessing,
|
139 |
args.top_p,
|
|
|
140 |
)
|
141 |
if args.render_gif:
|
142 |
gif_path = renderer.render_turntable(obj_path, args.output_dir)
|
|
|
5 |
import trimesh
|
6 |
|
7 |
from cube3d.inference.engine import Engine, EngineFast
|
8 |
+
from cube3d.inference.utils import normalize_bbox, select_device
|
9 |
from cube3d.mesh_utils.postprocessing import (
|
10 |
PYMESHLAB_AVAILABLE,
|
11 |
create_pymeshset,
|
|
|
14 |
)
|
15 |
from cube3d.renderer import renderer
|
16 |
|
17 |
+
|
18 |
def generate_mesh(
|
19 |
engine,
|
20 |
prompt,
|
|
|
23 |
resolution_base=8.0,
|
24 |
disable_postprocess=False,
|
25 |
top_p=None,
|
26 |
+
bounding_box_xyz=None,
|
27 |
):
|
28 |
mesh_v_f = engine.t2s(
|
29 |
[prompt],
|
30 |
use_kv_cache=True,
|
31 |
resolution_base=resolution_base,
|
32 |
top_p=top_p,
|
33 |
+
bounding_box_xyz=bounding_box_xyz,
|
34 |
)
|
35 |
vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
|
36 |
obj_path = os.path.join(output_dir, f"{output_name}.obj")
|
|
|
96 |
default=None,
|
97 |
help="Float < 1: Keep smallest set of tokens with cumulative probability ≥ top_p. Default None: deterministic generation.",
|
98 |
)
|
99 |
+
parser.add_argument(
|
100 |
+
"--bounding_box_xyz",
|
101 |
+
nargs=3,
|
102 |
+
type=float,
|
103 |
+
help="Three float values for x, y, z bounding box",
|
104 |
+
default=None,
|
105 |
+
required=False,
|
106 |
+
)
|
107 |
parser.add_argument(
|
108 |
"--render-gif",
|
109 |
help="Render a turntable gif of the mesh",
|
|
|
124 |
)
|
125 |
args = parser.parse_args()
|
126 |
os.makedirs(args.output_dir, exist_ok=True)
|
127 |
+
device = select_device()
|
128 |
print(f"Using device: {device}")
|
129 |
# Initialize engine based on fast_inference flag
|
130 |
if args.fast_inference:
|
|
|
139 |
engine = Engine(
|
140 |
args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device
|
141 |
)
|
142 |
+
|
143 |
+
if args.bounding_box_xyz is not None:
|
144 |
+
args.bounding_box_xyz = normalize_bbox(tuple(args.bounding_box_xyz))
|
145 |
+
|
146 |
# Generate meshes based on input source
|
147 |
obj_path = generate_mesh(
|
148 |
engine,
|
|
|
152 |
args.resolution_base,
|
153 |
args.disable_postprocessing,
|
154 |
args.top_p,
|
155 |
+
args.bounding_box_xyz,
|
156 |
)
|
157 |
if args.render_gif:
|
158 |
gif_path = renderer.render_turntable(obj_path, args.output_dir)
|
cube/cube3d/inference/engine.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import torch
|
2 |
from tqdm import tqdm
|
3 |
from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast
|
@@ -77,12 +79,54 @@ class Engine:
|
|
77 |
self.max_id = self.shape_model.cfg.num_codes
|
78 |
|
79 |
@torch.inference_mode()
|
80 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
"""
|
82 |
Prepares the input embeddings for the model based on the provided prompts and guidance scale.
|
83 |
Args:
|
84 |
prompts (list[str]): A list of prompt strings to be encoded.
|
85 |
guidance_scale (float): A scaling factor for guidance. If greater than 0.0, additional processing is applied.
|
|
|
|
|
|
|
86 |
Returns:
|
87 |
tuple: A tuple containing:
|
88 |
- embed (torch.Tensor): The encoded input embeddings.
|
@@ -94,11 +138,19 @@ class Engine:
|
|
94 |
with torch.autocast(self.device.type, dtype=torch.bfloat16):
|
95 |
embed = self.encode_input(prompt_embeds, self.gpt_model.shape_bos_id)
|
96 |
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
if guidance_scale > 0.0:
|
99 |
embed = torch.cat([embed, embed], dim=0)
|
100 |
uncond_embeds = self.run_clip([""] * len(prompts))
|
101 |
-
|
|
|
102 |
|
103 |
return embed, cond
|
104 |
|
@@ -161,6 +213,7 @@ class Engine:
|
|
161 |
use_kv_cache: bool,
|
162 |
guidance_scale: float = 3.0,
|
163 |
top_p: float = None,
|
|
|
164 |
):
|
165 |
"""
|
166 |
Generates text using a GPT model based on the provided prompts.
|
@@ -169,11 +222,14 @@ class Engine:
|
|
169 |
use_kv_cache (bool): Whether to use key-value caching for faster generation.
|
170 |
guidance_scale (float, optional): The scale for guidance during generation. Default is 3.0.
|
171 |
top_p (float, optional): The cumulative probability threshold for nucleus sampling.
|
172 |
-
|
|
|
|
|
|
|
173 |
Returns:
|
174 |
torch.Tensor: A tensor containing the generated token IDs.
|
175 |
"""
|
176 |
-
embed, cond = self.prepare_inputs(prompts, guidance_scale)
|
177 |
|
178 |
output_ids = []
|
179 |
|
@@ -267,6 +323,7 @@ class Engine:
|
|
267 |
resolution_base: float = 8.0,
|
268 |
chunk_size: int = 100_000,
|
269 |
top_p: float = None,
|
|
|
270 |
):
|
271 |
"""
|
272 |
Generates a 3D mesh from text prompts using a GPT model and shape decoder.
|
@@ -276,12 +333,17 @@ class Engine:
|
|
276 |
guidance_scale (float, optional): The scale of guidance for the GPT model. Default is 3.0.
|
277 |
resolution_base (float, optional): The base resolution for the shape decoder. Default is 8.0.
|
278 |
chunk_size (int, optional): The chunk size for processing the shape decoding. Default is 100,000.
|
279 |
-
top_p (float, optional): The cumulative probability threshold for nucleus sampling.
|
280 |
-
|
|
|
|
|
|
|
281 |
Returns:
|
282 |
mesh_v_f: The generated 3D mesh vertices and faces.
|
283 |
"""
|
284 |
-
output_ids = self.run_gpt(
|
|
|
|
|
285 |
with torch.autocast(self.device.type, dtype=torch.bfloat16):
|
286 |
mesh_v_f = self.run_shape_decode(output_ids, resolution_base, chunk_size)
|
287 |
return mesh_v_f
|
@@ -304,6 +366,10 @@ class EngineFast(Engine):
|
|
304 |
device (torch.device): The device to run the inference on (e.g., CPU or CUDA).
|
305 |
"""
|
306 |
|
|
|
|
|
|
|
|
|
307 |
super().__init__(config_path, gpt_ckpt_path, shape_ckpt_path, device)
|
308 |
|
309 |
# CUDA Graph params
|
@@ -424,11 +490,12 @@ class EngineFast(Engine):
|
|
424 |
)
|
425 |
|
426 |
def run_gpt(
|
427 |
-
self,
|
428 |
-
prompts: list[str],
|
429 |
-
use_kv_cache: bool,
|
430 |
guidance_scale: float = 3.0,
|
431 |
-
top_p: float = None
|
|
|
432 |
):
|
433 |
"""
|
434 |
Runs the GPT model to generate text based on the provided prompts.
|
@@ -437,14 +504,18 @@ class EngineFast(Engine):
|
|
437 |
use_kv_cache (bool): Flag indicating whether to use key-value caching. (Currently not used)
|
438 |
guidance_scale (float, optional): The scale factor for guidance. Default is 3.0.
|
439 |
top_p (float, optional): The cumulative probability threshold for nucleus sampling.
|
440 |
-
|
|
|
|
|
|
|
|
|
441 |
Returns:
|
442 |
torch.Tensor: A tensor containing the generated output token IDs.
|
443 |
Raises:
|
444 |
AssertionError: If the batch size is greater than 1.
|
445 |
"""
|
446 |
|
447 |
-
embed, cond = self.prepare_inputs(prompts, guidance_scale)
|
448 |
assert len(prompts) == 1, "batch size > 1 not support for EngineFast"
|
449 |
|
450 |
batch_size, input_seq_len, _ = embed.shape
|
@@ -475,9 +546,7 @@ class EngineFast(Engine):
|
|
475 |
next_embed = next_embed.repeat(2, 1, 1)
|
476 |
self.embed_buffer[:, input_seq_len, :].copy_(next_embed.squeeze(1))
|
477 |
|
478 |
-
for i in tqdm(
|
479 |
-
range(1, self.max_new_tokens), desc=f"generating"
|
480 |
-
):
|
481 |
self._set_curr_pos_id(i)
|
482 |
self.graph.replay()
|
483 |
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
import torch
|
4 |
from tqdm import tqdm
|
5 |
from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast
|
|
|
79 |
self.max_id = self.shape_model.cfg.num_codes
|
80 |
|
81 |
@torch.inference_mode()
|
82 |
+
def prepare_conditions_with_bbox(
|
83 |
+
self,
|
84 |
+
cond: torch.Tensor,
|
85 |
+
bounding_box_tensor: Optional[torch.Tensor] = None,
|
86 |
+
):
|
87 |
+
"""
|
88 |
+
Prepares condition embeddings by incorporating bounding box information.
|
89 |
+
|
90 |
+
Concatenates bounding box embeddings to the existing condition tensor if the model
|
91 |
+
supports bounding box projection. If no bounding box is provided, uses zero padding.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
cond (torch.Tensor): The input condition embeddings tensor of shape (B, seq_len, dim).
|
95 |
+
bounding_box_xyz (Optional[torch.Tensor], optional): The size of the bounding box
|
96 |
+
as (x, y, z) dimensions represented as a tensor. If None, uses zero padding for
|
97 |
+
bounding box embeddings.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
torch.Tensor: The condition tensor with bounding box embeddings concatenated along
|
101 |
+
the sequence dimension if bounding box projection is supported, otherwise
|
102 |
+
returns the original condition tensor unchanged.
|
103 |
+
"""
|
104 |
+
if not hasattr(self.gpt_model, "bbox_proj"):
|
105 |
+
return cond
|
106 |
+
|
107 |
+
if bounding_box_tensor is None:
|
108 |
+
B = cond.shape[0]
|
109 |
+
bounding_box_tensor = torch.zeros((B, 3), dtype=cond.dtype, device=self.device)
|
110 |
+
|
111 |
+
bbox_emb = self.gpt_model.bbox_proj(bounding_box_tensor).unsqueeze(dim=1)
|
112 |
+
cond = torch.cat([cond, bbox_emb], dim=1)
|
113 |
+
return cond
|
114 |
+
|
115 |
+
@torch.inference_mode()
|
116 |
+
def prepare_inputs(
|
117 |
+
self,
|
118 |
+
prompts: list[str],
|
119 |
+
guidance_scale: float,
|
120 |
+
bounding_box_xyz: Optional[Tuple[float]] = None,
|
121 |
+
):
|
122 |
"""
|
123 |
Prepares the input embeddings for the model based on the provided prompts and guidance scale.
|
124 |
Args:
|
125 |
prompts (list[str]): A list of prompt strings to be encoded.
|
126 |
guidance_scale (float): A scaling factor for guidance. If greater than 0.0, additional processing is applied.
|
127 |
+
bounding_box_xyz (Optional[Tuple[float]], optional): The size of the bounding box for generation
|
128 |
+
as (x, y, z) dimensions. Each value must be between 0 and 1.925. If None,
|
129 |
+
uses default bounding box sizing.
|
130 |
Returns:
|
131 |
tuple: A tuple containing:
|
132 |
- embed (torch.Tensor): The encoded input embeddings.
|
|
|
138 |
with torch.autocast(self.device.type, dtype=torch.bfloat16):
|
139 |
embed = self.encode_input(prompt_embeds, self.gpt_model.shape_bos_id)
|
140 |
|
141 |
+
if bounding_box_xyz is not None:
|
142 |
+
cond_bbox = torch.atleast_2d(torch.tensor(bounding_box_xyz)).to(self.device)
|
143 |
+
uncond_bbox = torch.zeros_like(cond_bbox).to(self.device)
|
144 |
+
else:
|
145 |
+
cond_bbox = None
|
146 |
+
uncond_bbox = None
|
147 |
+
|
148 |
+
cond = self.prepare_conditions_with_bbox(prompt_embeds, cond_bbox)
|
149 |
if guidance_scale > 0.0:
|
150 |
embed = torch.cat([embed, embed], dim=0)
|
151 |
uncond_embeds = self.run_clip([""] * len(prompts))
|
152 |
+
uncond = self.prepare_conditions_with_bbox(uncond_embeds, uncond_bbox)
|
153 |
+
cond = torch.cat([cond, uncond], dim=0)
|
154 |
|
155 |
return embed, cond
|
156 |
|
|
|
213 |
use_kv_cache: bool,
|
214 |
guidance_scale: float = 3.0,
|
215 |
top_p: float = None,
|
216 |
+
bounding_box_xyz: Optional[Tuple[float]] = None,
|
217 |
):
|
218 |
"""
|
219 |
Generates text using a GPT model based on the provided prompts.
|
|
|
222 |
use_kv_cache (bool): Whether to use key-value caching for faster generation.
|
223 |
guidance_scale (float, optional): The scale for guidance during generation. Default is 3.0.
|
224 |
top_p (float, optional): The cumulative probability threshold for nucleus sampling.
|
225 |
+
If None, argmax selection is performed (deterministic generation). Otherwise, smallest set of tokens with cumulative probability ≥ top_p are kept (stochastic generation).
|
226 |
+
bounding_box_xyz (Optional[Tuple[float]], optional): The size of the bounding box for generation
|
227 |
+
as (x, y, z) dimensions. Each value must be between 0 and 1.925. If None,
|
228 |
+
uses default bounding box sizing.
|
229 |
Returns:
|
230 |
torch.Tensor: A tensor containing the generated token IDs.
|
231 |
"""
|
232 |
+
embed, cond = self.prepare_inputs(prompts, guidance_scale, bounding_box_xyz)
|
233 |
|
234 |
output_ids = []
|
235 |
|
|
|
323 |
resolution_base: float = 8.0,
|
324 |
chunk_size: int = 100_000,
|
325 |
top_p: float = None,
|
326 |
+
bounding_box_xyz: Optional[Tuple[float]] = None,
|
327 |
):
|
328 |
"""
|
329 |
Generates a 3D mesh from text prompts using a GPT model and shape decoder.
|
|
|
333 |
guidance_scale (float, optional): The scale of guidance for the GPT model. Default is 3.0.
|
334 |
resolution_base (float, optional): The base resolution for the shape decoder. Default is 8.0.
|
335 |
chunk_size (int, optional): The chunk size for processing the shape decoding. Default is 100,000.
|
336 |
+
top_p (float, optional): The cumulative probability threshold for nucleus sampling.
|
337 |
+
If None, argmax selection is performed (deterministic generation). Otherwise, smallest set of tokens with cumulative probability ≥ top_p are kept (stochastic generation).
|
338 |
+
bounding_box_xyz (Tuple[float] | None, optional): The size of the bounding box for the generated mesh
|
339 |
+
as (x, y, z) dimensions. Each value must be between 0 and 1.925. If None,
|
340 |
+
uses default bounding box sizing.
|
341 |
Returns:
|
342 |
mesh_v_f: The generated 3D mesh vertices and faces.
|
343 |
"""
|
344 |
+
output_ids = self.run_gpt(
|
345 |
+
prompts, use_kv_cache, guidance_scale, top_p, bounding_box_xyz
|
346 |
+
)
|
347 |
with torch.autocast(self.device.type, dtype=torch.bfloat16):
|
348 |
mesh_v_f = self.run_shape_decode(output_ids, resolution_base, chunk_size)
|
349 |
return mesh_v_f
|
|
|
366 |
device (torch.device): The device to run the inference on (e.g., CPU or CUDA).
|
367 |
"""
|
368 |
|
369 |
+
assert (
|
370 |
+
device.type == "cuda"
|
371 |
+
), "EngineFast is only supported on cuda devices, please use Engine on non-cuda devices"
|
372 |
+
|
373 |
super().__init__(config_path, gpt_ckpt_path, shape_ckpt_path, device)
|
374 |
|
375 |
# CUDA Graph params
|
|
|
490 |
)
|
491 |
|
492 |
def run_gpt(
|
493 |
+
self,
|
494 |
+
prompts: list[str],
|
495 |
+
use_kv_cache: bool,
|
496 |
guidance_scale: float = 3.0,
|
497 |
+
top_p: float = None,
|
498 |
+
bounding_box_xyz: Optional[Tuple[float]] = None,
|
499 |
):
|
500 |
"""
|
501 |
Runs the GPT model to generate text based on the provided prompts.
|
|
|
504 |
use_kv_cache (bool): Flag indicating whether to use key-value caching. (Currently not used)
|
505 |
guidance_scale (float, optional): The scale factor for guidance. Default is 3.0.
|
506 |
top_p (float, optional): The cumulative probability threshold for nucleus sampling.
|
507 |
+
If None, argmax selection is performed. Otherwise, smallest
|
508 |
+
set of tokens with cumulative probability ≥ top_p are kept.
|
509 |
+
bounding_box_xyz (Tuple[float] | None, optional): The size of the bounding box for the generated mesh
|
510 |
+
as (x, y, z) dimensions. Each value must be between 0 and 1.925. If None,
|
511 |
+
uses default bounding box sizing.
|
512 |
Returns:
|
513 |
torch.Tensor: A tensor containing the generated output token IDs.
|
514 |
Raises:
|
515 |
AssertionError: If the batch size is greater than 1.
|
516 |
"""
|
517 |
|
518 |
+
embed, cond = self.prepare_inputs(prompts, guidance_scale, bounding_box_xyz)
|
519 |
assert len(prompts) == 1, "batch size > 1 not support for EngineFast"
|
520 |
|
521 |
batch_size, input_seq_len, _ = embed.shape
|
|
|
546 |
next_embed = next_embed.repeat(2, 1, 1)
|
547 |
self.embed_buffer[:, input_seq_len, :].copy_(next_embed.squeeze(1))
|
548 |
|
549 |
+
for i in tqdm(range(1, self.max_new_tokens), desc=f"generating"):
|
|
|
|
|
550 |
self._set_curr_pos_id(i)
|
551 |
self.graph.replay()
|
552 |
|
cube/cube3d/inference/utils.py
CHANGED
@@ -1,10 +1,17 @@
|
|
1 |
import logging
|
2 |
-
from typing import Any, Optional
|
3 |
|
4 |
import torch
|
5 |
from omegaconf import DictConfig, OmegaConf
|
6 |
from safetensors.torch import load_model
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
def load_config(cfg_path: str) -> Any:
|
10 |
"""
|
@@ -49,8 +56,24 @@ def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
|
|
49 |
Returns:
|
50 |
None
|
51 |
"""
|
52 |
-
assert ckpt_path.endswith(
|
53 |
-
|
54 |
-
)
|
55 |
|
56 |
load_model(model, ckpt_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import logging
|
2 |
+
from typing import Any, Optional, Tuple
|
3 |
|
4 |
import torch
|
5 |
from omegaconf import DictConfig, OmegaConf
|
6 |
from safetensors.torch import load_model
|
7 |
|
8 |
+
BOUNDING_BOX_MAX_SIZE = 1.925
|
9 |
+
|
10 |
+
|
11 |
+
def normalize_bbox(bounding_box_xyz: Tuple[float]):
|
12 |
+
max_l = max(bounding_box_xyz)
|
13 |
+
return [BOUNDING_BOX_MAX_SIZE * elem / max_l for elem in bounding_box_xyz]
|
14 |
+
|
15 |
|
16 |
def load_config(cfg_path: str) -> Any:
|
17 |
"""
|
|
|
56 |
Returns:
|
57 |
None
|
58 |
"""
|
59 |
+
assert ckpt_path.endswith(
|
60 |
+
".safetensors"
|
61 |
+
), f"Checkpoint path '{ckpt_path}' is not a safetensors file"
|
62 |
|
63 |
load_model(model, ckpt_path)
|
64 |
+
|
65 |
+
|
66 |
+
def select_device() -> Any:
|
67 |
+
"""
|
68 |
+
Selects the appropriate PyTorch device for tensor allocation.
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
Any: The `torch.device` object.
|
72 |
+
"""
|
73 |
+
return torch.device(
|
74 |
+
"cuda"
|
75 |
+
if torch.cuda.is_available()
|
76 |
+
else "mps"
|
77 |
+
if torch.backends.mps.is_available()
|
78 |
+
else "cpu"
|
79 |
+
)
|
cube/cube3d/mesh_utils/postprocessing.py
CHANGED
@@ -75,10 +75,12 @@ def save_mesh(ms: pymeshlab.MeshSet, output_path: str):
|
|
75 |
logging.info(f"Mesh saved to {output_path}.")
|
76 |
|
77 |
|
78 |
-
def postprocess_mesh(ms: pymeshlab.MeshSet, target_face_num: int
|
79 |
"""
|
80 |
Postprocess the mesh to the target number of faces.
|
81 |
"""
|
82 |
cleanup(ms)
|
83 |
remove_floaters(ms)
|
84 |
simplify_mesh(ms, target_face_num)
|
|
|
|
|
|
75 |
logging.info(f"Mesh saved to {output_path}.")
|
76 |
|
77 |
|
78 |
+
def postprocess_mesh(ms: pymeshlab.MeshSet, target_face_num: int):
|
79 |
"""
|
80 |
Postprocess the mesh to the target number of faces.
|
81 |
"""
|
82 |
cleanup(ms)
|
83 |
remove_floaters(ms)
|
84 |
simplify_mesh(ms, target_face_num)
|
85 |
+
mesh = ms.current_mesh()
|
86 |
+
return mesh.vertex_matrix(), mesh.face_matrix()
|
cube/cube3d/model/gpt/dual_stream_roformer.py
CHANGED
@@ -34,6 +34,8 @@ class DualStreamRoformer(nn.Module):
|
|
34 |
|
35 |
encoder_with_cls_token: bool = True
|
36 |
|
|
|
|
|
37 |
def __init__(self, cfg: Config) -> None:
|
38 |
"""
|
39 |
Initializes the DualStreamRoFormer model.
|
@@ -108,6 +110,9 @@ class DualStreamRoformer(nn.Module):
|
|
108 |
|
109 |
self.lm_head = nn.Linear(self.cfg.n_embd, self.vocab_size, bias=False)
|
110 |
|
|
|
|
|
|
|
111 |
def encode_text(self, text_embed):
|
112 |
"""
|
113 |
Encodes the given text embeddings by projecting them through a linear transformation.
|
|
|
34 |
|
35 |
encoder_with_cls_token: bool = True
|
36 |
|
37 |
+
use_bbox: bool = False
|
38 |
+
|
39 |
def __init__(self, cfg: Config) -> None:
|
40 |
"""
|
41 |
Initializes the DualStreamRoFormer model.
|
|
|
110 |
|
111 |
self.lm_head = nn.Linear(self.cfg.n_embd, self.vocab_size, bias=False)
|
112 |
|
113 |
+
if self.cfg.use_bbox:
|
114 |
+
self.bbox_proj = nn.Linear(3, self.cfg.n_embd)
|
115 |
+
|
116 |
def encode_text(self, text_embed):
|
117 |
"""
|
118 |
Encodes the given text embeddings by projecting them through a linear transformation.
|
cube/cube3d/model/transformers/cache.py
CHANGED
@@ -1,9 +1,36 @@
|
|
1 |
-
from dataclasses import dataclass
|
2 |
-
|
3 |
import torch
|
4 |
|
5 |
-
|
6 |
@dataclass
|
7 |
class Cache:
|
8 |
key_states: torch.Tensor
|
9 |
value_states: torch.Tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
|
|
2 |
import torch
|
3 |
|
|
|
4 |
@dataclass
|
5 |
class Cache:
|
6 |
key_states: torch.Tensor
|
7 |
value_states: torch.Tensor
|
8 |
+
_supports_index_copy: bool = field(init=False) # For CUDA graph support
|
9 |
+
|
10 |
+
def __post_init__(self):
|
11 |
+
self._supports_index_copy = self._check_index_copy_support()
|
12 |
+
|
13 |
+
def _check_index_copy_support(self) -> bool:
|
14 |
+
"""Verifies support for `index_copy_` on device."""
|
15 |
+
try:
|
16 |
+
device = self.key_states.device
|
17 |
+
dummy = torch.tensor([0, 0], device=device)
|
18 |
+
dummy.index_copy_(0, torch.tensor([0], device=device), torch.tensor([1], device=device))
|
19 |
+
return True
|
20 |
+
except NotImplementedError:
|
21 |
+
return False
|
22 |
+
|
23 |
+
def update(self, curr_pos_id: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> None:
|
24 |
+
"""
|
25 |
+
Updates the cache based on device operator support.
|
26 |
+
Args:
|
27 |
+
curr_pos_id (torch.Tensor): Current position indices for decoding.
|
28 |
+
k (torch.Tensor): The keys to update
|
29 |
+
v (torch.Tensor): The values to update
|
30 |
+
"""
|
31 |
+
if self._supports_index_copy: # CUDA/CPU
|
32 |
+
self.key_states.index_copy_(2, curr_pos_id, k)
|
33 |
+
self.value_states.index_copy_(2, curr_pos_id, v)
|
34 |
+
else: # MPS
|
35 |
+
self.key_states[:, :, curr_pos_id:curr_pos_id +1, ...].copy_(k)
|
36 |
+
self.value_states[:, :, curr_pos_id:curr_pos_id +1, ...].copy_(v)
|
cube/cube3d/model/transformers/dual_stream_attention.py
CHANGED
@@ -198,8 +198,7 @@ class DualStreamAttentionWithRotaryEmbedding(nn.Module):
|
|
198 |
kv_cache.value_states[:, :, : k.shape[2], :].copy_(v)
|
199 |
else:
|
200 |
assert curr_pos_id is not None
|
201 |
-
kv_cache.
|
202 |
-
kv_cache.value_states.index_copy_(2, curr_pos_id, v)
|
203 |
k = kv_cache.key_states
|
204 |
v = kv_cache.value_states
|
205 |
|
|
|
198 |
kv_cache.value_states[:, :, : k.shape[2], :].copy_(v)
|
199 |
else:
|
200 |
assert curr_pos_id is not None
|
201 |
+
kv_cache.update(curr_pos_id, k, v)
|
|
|
202 |
k = kv_cache.key_states
|
203 |
v = kv_cache.value_states
|
204 |
|
cube/cube3d/model/transformers/roformer.py
CHANGED
@@ -115,8 +115,7 @@ class SelfAttentionWithRotaryEmbedding(nn.Module):
|
|
115 |
kv_cache.value_states[:, :, : k.shape[2], :].copy_(v)
|
116 |
else:
|
117 |
assert curr_pos_id is not None
|
118 |
-
kv_cache.
|
119 |
-
kv_cache.value_states.index_copy_(2, curr_pos_id, v)
|
120 |
k = kv_cache.key_states
|
121 |
v = kv_cache.value_states
|
122 |
|
|
|
115 |
kv_cache.value_states[:, :, : k.shape[2], :].copy_(v)
|
116 |
else:
|
117 |
assert curr_pos_id is not None
|
118 |
+
kv_cache.update(curr_pos_id, k, v)
|
|
|
119 |
k = kv_cache.key_states
|
120 |
v = kv_cache.value_states
|
121 |
|
cube/cube3d/vq_vae_encode_decode.py
CHANGED
@@ -5,7 +5,7 @@ import numpy as np
|
|
5 |
import torch
|
6 |
import trimesh
|
7 |
|
8 |
-
from cube3d.inference.utils import load_config, load_model_weights, parse_structured
|
9 |
from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder
|
10 |
|
11 |
MESH_SCALE = 0.96
|
@@ -125,7 +125,7 @@ if __name__ == "__main__":
|
|
125 |
help="Path to save the recovered mesh file.",
|
126 |
)
|
127 |
args = parser.parse_args()
|
128 |
-
device =
|
129 |
logging.info(f"Using device: {device}")
|
130 |
|
131 |
cfg = load_config(args.config_path)
|
|
|
5 |
import torch
|
6 |
import trimesh
|
7 |
|
8 |
+
from cube3d.inference.utils import load_config, load_model_weights, parse_structured, select_device
|
9 |
from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder
|
10 |
|
11 |
MESH_SCALE = 0.96
|
|
|
125 |
help="Path to save the recovered mesh file.",
|
126 |
)
|
127 |
args = parser.parse_args()
|
128 |
+
device = select_device()
|
129 |
logging.info(f"Using device: {device}")
|
130 |
|
131 |
cfg = load_config(args.config_path)
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
gradio
|
2 |
torch
|
3 |
trimesh
|
4 |
-
|
|
|
|
1 |
gradio
|
2 |
torch
|
3 |
trimesh
|
4 |
+
pymeshlab
|
5 |
+
git+https://github.com/Roblox/cube.git
|