rgndgn commited on
Commit
efccc85
·
verified ·
1 Parent(s): 96e3b91

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +247 -247
gradio_app.py CHANGED
@@ -1,248 +1,248 @@
1
- import spaces
2
- import os
3
- import tempfile
4
- from typing import Any
5
- import torch
6
- import numpy as np
7
- from PIL import Image
8
- import gradio as gr
9
- import trimesh
10
- from transparent_background import Remover
11
-
12
- import subprocess
13
-
14
- def install_cuda_toolkit():
15
- # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
16
- CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
17
- CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
18
- subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
19
- subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
20
- subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
21
-
22
- os.environ["CUDA_HOME"] = "/usr/local/cuda"
23
- os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
24
- os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
25
- os.environ["CUDA_HOME"],
26
- "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
27
- )
28
- # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
29
- os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
30
-
31
- install_cuda_toolkit()
32
-
33
- # Import and setup SPAR3D
34
- os.system("USE_CUDA=1 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper")
35
- import spar3d.utils as spar3d_utils
36
- from spar3d.system import SPAR3D
37
-
38
- # Constants
39
- COND_WIDTH = 512
40
- COND_HEIGHT = 512
41
- COND_DISTANCE = 2.2
42
- COND_FOVY = 0.591627
43
- BACKGROUND_COLOR = [0.5, 0.5, 0.5]
44
-
45
- # Initialize models
46
- device = spar3d_utils.get_device()
47
- bg_remover = Remover()
48
- spar3d_model = SPAR3D.from_pretrained(
49
- "stabilityai/stable-point-aware-3d",
50
- config_name="config.yaml",
51
- weight_name="model.safetensors"
52
- ).eval().to(device)
53
-
54
- # Initialize camera parameters
55
- c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
56
- intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
57
- COND_FOVY, COND_HEIGHT, COND_WIDTH
58
- )
59
-
60
- def create_rgba_image(rgb_image: Image.Image, mask: np.ndarray = None) -> Image.Image:
61
- """Create an RGBA image from RGB image and optional mask."""
62
- rgba_image = rgb_image.convert('RGBA')
63
- if mask is not None:
64
- # Ensure mask is 2D before converting to alpha
65
- if len(mask.shape) > 2:
66
- mask = mask.squeeze()
67
- alpha = Image.fromarray((mask * 255).astype(np.uint8))
68
- rgba_image.putalpha(alpha)
69
- return rgba_image
70
-
71
- def create_batch(input_image: Image.Image) -> dict[str, Any]:
72
- """Prepare image batch for model input."""
73
- # Resize and convert input image to numpy array
74
- resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
75
- img_array = np.array(resized_image).astype(np.float32) / 255.0
76
-
77
- # Extract RGB and alpha channels
78
- if img_array.shape[-1] == 4: # RGBA
79
- rgb = img_array[..., :3]
80
- mask = img_array[..., 3:4]
81
- else: # RGB
82
- rgb = img_array
83
- mask = np.ones((*img_array.shape[:2], 1), dtype=np.float32)
84
-
85
- # Convert to tensors while keeping channel-last format
86
- rgb = torch.from_numpy(rgb).float() # [H, W, 3]
87
- mask = torch.from_numpy(mask).float() # [H, W, 1]
88
-
89
- # Create background blend (match channel-last format)
90
- bg_tensor = torch.tensor(BACKGROUND_COLOR).view(1, 1, 3) # [1, 1, 3]
91
-
92
- # Blend RGB with background using mask (all in channel-last format)
93
- rgb_cond = torch.lerp(bg_tensor, rgb, mask) # [H, W, 3]
94
-
95
- # Move channels to correct dimension and add batch dimension
96
- # Important: For SPAR3D image tokenizer, we need [B, H, W, C] format
97
- rgb_cond = rgb_cond.unsqueeze(0) # [1, H, W, 3]
98
- mask = mask.unsqueeze(0) # [1, H, W, 1]
99
-
100
- # Create the batch dictionary
101
- batch = {
102
- "rgb_cond": rgb_cond, # [1, H, W, 3]
103
- "mask_cond": mask, # [1, H, W, 1]
104
- "c2w_cond": c2w_cond.unsqueeze(0), # [1, 4, 4]
105
- "intrinsic_cond": intrinsic.unsqueeze(0), # [1, 3, 3]
106
- "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0), # [1, 3, 3]
107
- }
108
-
109
- for k, v in batch.items():
110
- print(f"[debug] {k} final shape:", v.shape)
111
-
112
- return batch
113
-
114
- def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
115
- """Process batch through model and generate point cloud."""
116
-
117
- batch_size = batch["rgb_cond"].shape[0]
118
- assert batch_size == 1, f"Expected batch size 1, got {batch_size}"
119
-
120
- # Generate point cloud tokens
121
- try:
122
- cond_tokens = system.forward_pdiff_cond(batch)
123
- except Exception as e:
124
- print("\n[ERROR] Failed in forward_pdiff_cond:")
125
- print(e)
126
- print("\nInput tensor properties:")
127
- print("rgb_cond dtype:", batch["rgb_cond"].dtype)
128
- print("rgb_cond device:", batch["rgb_cond"].device)
129
- print("rgb_cond requires_grad:", batch["rgb_cond"].requires_grad)
130
- raise
131
-
132
- # Sample points
133
- sample_iter = system.sampler.sample_batch_progressive(
134
- batch_size,
135
- cond_tokens,
136
- guidance_scale=guidance_scale,
137
- device=device
138
- )
139
-
140
- # Get final samples
141
- for x in sample_iter:
142
- samples = x["xstart"]
143
-
144
- pc_cond = samples.permute(0, 2, 1).float()
145
-
146
- # Normalize point cloud
147
- pc_cond = spar3d_utils.normalize_pc_bbox(pc_cond)
148
-
149
- # Subsample to 512 points
150
- pc_cond = pc_cond[:, torch.randperm(pc_cond.shape[1])[:512]]
151
-
152
- return pc_cond
153
-
154
- @spaces.GPU
155
- @torch.inference_mode()
156
- def generate_and_process_3d(image: Image.Image) -> tuple[str | None, Image.Image | None]:
157
- """Generate image from prompt and convert to 3D model."""
158
-
159
- # Generate random seed
160
- seed = np.random.randint(0, np.iinfo(np.int32).max)
161
-
162
- try:
163
- rgb_image = image.convert('RGB')
164
-
165
- # bg_remover returns a PIL Image already, no need to convert
166
- no_bg_image = bg_remover.process(rgb_image)
167
- print(f"[debug] no_bg_image type: {type(no_bg_image)}, mode: {no_bg_image.mode}")
168
-
169
- # Convert to RGBA if not already
170
- rgba_image = no_bg_image.convert('RGBA')
171
- print(f"[debug] rgba_image mode: {rgba_image.mode}")
172
-
173
- processed_image = spar3d_utils.foreground_crop(
174
- rgba_image,
175
- crop_ratio=1.3,
176
- newsize=(COND_WIDTH, COND_HEIGHT),
177
- no_crop=False
178
- )
179
-
180
- # Show the processed image alpha channel for debugging
181
- alpha = np.array(processed_image)[:, :, 3]
182
- print(f"[debug] Alpha channel stats - min: {alpha.min()}, max: {alpha.max()}, unique: {np.unique(alpha)}")
183
-
184
- # Prepare batch for processing
185
- batch = create_batch(processed_image)
186
- batch = {k: v.to(device) for k, v in batch.items()}
187
-
188
- # Generate point cloud
189
- pc_cond = forward_model(
190
- batch,
191
- spar3d_model,
192
- guidance_scale=3.0,
193
- seed=seed,
194
- device=device
195
- )
196
- batch["pc_cond"] = pc_cond
197
-
198
- # Generate mesh
199
- with torch.no_grad():
200
- with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
201
- trimesh_mesh, _ = spar3d_model.generate_mesh(
202
- batch,
203
- 1024, # texture_resolution
204
- remesh="none",
205
- vertex_count=-1,
206
- estimate_illumination=True
207
- )
208
- trimesh_mesh = trimesh_mesh[0]
209
-
210
- # Export to GLB
211
- temp_dir = tempfile.mkdtemp()
212
- output_path = os.path.join(temp_dir, 'output.glb')
213
-
214
- trimesh_mesh.export(output_path, file_type="glb", include_normals=True)
215
-
216
- return output_path
217
-
218
- except Exception as e:
219
- print(f"Error during generation: {str(e)}")
220
- import traceback
221
- traceback.print_exc()
222
- return None
223
-
224
- # Create Gradio app using Blocks
225
- with gr.Blocks() as demo:
226
- gr.Markdown("This space is based on [Stable Point-Aware 3D](https://huggingface.co/spaces/stabilityai/stable-point-aware-3d) by Stability AI, [Text to 3D](https://huggingface.co/spaces/jbilcke-hf/text-to-3d) by jbilcke-hf.")
227
-
228
- with gr.Row():
229
- input_img = gr.Image(
230
- type="pil", label="Input Image", sources="upload", image_mode="RGBA"
231
- )
232
-
233
- with gr.Row():
234
- model_output = gr.Model3D(
235
- label="Generated .GLB model",
236
- clear_color=[0.0, 0.0, 0.0, 0.0],
237
- )
238
-
239
- # Event handler
240
- input_img.upload(
241
- fn=generate_and_process_3d,
242
- inputs=[input_img],
243
- outputs=[model_output],
244
- api_name="generate"
245
- )
246
-
247
- if __name__ == "__main__":
248
  demo.queue().launch()
 
1
+ import spaces
2
+ import os
3
+ import tempfile
4
+ from typing import Any
5
+ import torch
6
+ import numpy as np
7
+ from PIL import Image
8
+ import gradio as gr
9
+ import trimesh
10
+ from transparent_background import Remover
11
+
12
+ import subprocess
13
+
14
+ def install_cuda_toolkit():
15
+ # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
16
+ CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
17
+ CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
18
+ subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
19
+ subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
20
+ subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
21
+
22
+ os.environ["CUDA_HOME"] = "/usr/local/cuda"
23
+ os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
24
+ os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
25
+ os.environ["CUDA_HOME"],
26
+ "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
27
+ )
28
+ # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
29
+ os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
30
+
31
+ install_cuda_toolkit()
32
+
33
+ # Import and setup SPAR3D
34
+ os.system("USE_CUDA=1 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper")
35
+ import spar3d.utils as spar3d_utils
36
+ from spar3d.system import SPAR3D
37
+
38
+ # Constants
39
+ COND_WIDTH = 512
40
+ COND_HEIGHT = 512
41
+ COND_DISTANCE = 2.2
42
+ COND_FOVY = 0.591627
43
+ BACKGROUND_COLOR = [0.5, 0.5, 0.5]
44
+
45
+ # Initialize models
46
+ device = spar3d_utils.get_device()
47
+ bg_remover = Remover()
48
+ spar3d_model = SPAR3D.from_pretrained(
49
+ "stabilityai/stable-point-aware-3d",
50
+ config_name="config.yaml",
51
+ weight_name="model.safetensors"
52
+ ).eval().to(device)
53
+
54
+ # Initialize camera parameters
55
+ c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
56
+ intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
57
+ COND_FOVY, COND_HEIGHT, COND_WIDTH
58
+ )
59
+
60
+ def create_rgba_image(rgb_image: Image.Image, mask: np.ndarray = None) -> Image.Image:
61
+ """Create an RGBA image from RGB image and optional mask."""
62
+ rgba_image = rgb_image.convert('RGBA')
63
+ if mask is not None:
64
+ # Ensure mask is 2D before converting to alpha
65
+ if len(mask.shape) > 2:
66
+ mask = mask.squeeze()
67
+ alpha = Image.fromarray((mask * 255).astype(np.uint8))
68
+ rgba_image.putalpha(alpha)
69
+ return rgba_image
70
+
71
+ def create_batch(input_image: Image.Image) -> dict[str, Any]:
72
+ """Prepare image batch for model input."""
73
+ # Resize and convert input image to numpy array
74
+ resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
75
+ img_array = np.array(resized_image).astype(np.float32) / 255.0
76
+
77
+ # Extract RGB and alpha channels
78
+ if img_array.shape[-1] == 4: # RGBA
79
+ rgb = img_array[..., :3]
80
+ mask = img_array[..., 3:4]
81
+ else: # RGB
82
+ rgb = img_array
83
+ mask = np.ones((*img_array.shape[:2], 1), dtype=np.float32)
84
+
85
+ # Convert to tensors while keeping channel-last format
86
+ rgb = torch.from_numpy(rgb).float() # [H, W, 3]
87
+ mask = torch.from_numpy(mask).float() # [H, W, 1]
88
+
89
+ # Create background blend (match channel-last format)
90
+ bg_tensor = torch.tensor(BACKGROUND_COLOR).view(1, 1, 3) # [1, 1, 3]
91
+
92
+ # Blend RGB with background using mask (all in channel-last format)
93
+ rgb_cond = torch.lerp(bg_tensor, rgb, mask) # [H, W, 3]
94
+
95
+ # Move channels to correct dimension and add batch dimension
96
+ # Important: For SPAR3D image tokenizer, we need [B, H, W, C] format
97
+ rgb_cond = rgb_cond.unsqueeze(0) # [1, H, W, 3]
98
+ mask = mask.unsqueeze(0) # [1, H, W, 1]
99
+
100
+ # Create the batch dictionary
101
+ batch = {
102
+ "rgb_cond": rgb_cond, # [1, H, W, 3]
103
+ "mask_cond": mask, # [1, H, W, 1]
104
+ "c2w_cond": c2w_cond.unsqueeze(0), # [1, 4, 4]
105
+ "intrinsic_cond": intrinsic.unsqueeze(0), # [1, 3, 3]
106
+ "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0), # [1, 3, 3]
107
+ }
108
+
109
+ for k, v in batch.items():
110
+ print(f"[debug] {k} final shape:", v.shape)
111
+
112
+ return batch
113
+
114
+ def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
115
+ """Process batch through model and generate point cloud."""
116
+
117
+ batch_size = batch["rgb_cond"].shape[0]
118
+ assert batch_size == 1, f"Expected batch size 1, got {batch_size}"
119
+
120
+ # Generate point cloud tokens
121
+ try:
122
+ cond_tokens = system.forward_pdiff_cond(batch)
123
+ except Exception as e:
124
+ print("\n[ERROR] Failed in forward_pdiff_cond:")
125
+ print(e)
126
+ print("\nInput tensor properties:")
127
+ print("rgb_cond dtype:", batch["rgb_cond"].dtype)
128
+ print("rgb_cond device:", batch["rgb_cond"].device)
129
+ print("rgb_cond requires_grad:", batch["rgb_cond"].requires_grad)
130
+ raise
131
+
132
+ # Sample points
133
+ sample_iter = system.sampler.sample_batch_progressive(
134
+ batch_size,
135
+ cond_tokens,
136
+ guidance_scale=guidance_scale,
137
+ device=device
138
+ )
139
+
140
+ # Get final samples
141
+ for x in sample_iter:
142
+ samples = x["xstart"]
143
+
144
+ pc_cond = samples.permute(0, 2, 1).float()
145
+
146
+ # Normalize point cloud
147
+ pc_cond = spar3d_utils.normalize_pc_bbox(pc_cond)
148
+
149
+ # Subsample to 512 points
150
+ pc_cond = pc_cond[:, torch.randperm(pc_cond.shape[1])[:512]]
151
+
152
+ return pc_cond
153
+
154
+ @spaces.GPU
155
+ @torch.inference_mode()
156
+ def generate_and_process_3d(image: Image.Image) -> tuple[str | None, Image.Image | None]:
157
+ """Generate image from prompt and convert to 3D model."""
158
+
159
+ # Generate random seed
160
+ seed = np.random.randint(0, np.iinfo(np.int32).max)
161
+
162
+ try:
163
+ rgb_image = image.convert('RGB')
164
+
165
+ # bg_remover returns a PIL Image already, no need to convert
166
+ no_bg_image = bg_remover.process(rgb_image)
167
+ print(f"[debug] no_bg_image type: {type(no_bg_image)}, mode: {no_bg_image.mode}")
168
+
169
+ # Convert to RGBA if not already
170
+ rgba_image = no_bg_image.convert('RGBA')
171
+ print(f"[debug] rgba_image mode: {rgba_image.mode}")
172
+
173
+ processed_image = spar3d_utils.foreground_crop(
174
+ rgba_image,
175
+ crop_ratio=1.3,
176
+ newsize=(COND_WIDTH, COND_HEIGHT),
177
+ no_crop=False
178
+ )
179
+
180
+ # Show the processed image alpha channel for debugging
181
+ alpha = np.array(processed_image)[:, :, 3]
182
+ print(f"[debug] Alpha channel stats - min: {alpha.min()}, max: {alpha.max()}, unique: {np.unique(alpha)}")
183
+
184
+ # Prepare batch for processing
185
+ batch = create_batch(processed_image)
186
+ batch = {k: v.to(device) for k, v in batch.items()}
187
+
188
+ # Generate point cloud
189
+ pc_cond = forward_model(
190
+ batch,
191
+ spar3d_model,
192
+ guidance_scale=3.0,
193
+ seed=seed,
194
+ device=device
195
+ )
196
+ batch["pc_cond"] = pc_cond
197
+
198
+ # Generate mesh
199
+ with torch.no_grad():
200
+ with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
201
+ trimesh_mesh, _ = spar3d_model.generate_mesh(
202
+ batch,
203
+ 1024, # texture_resolution
204
+ remesh="none",
205
+ vertex_count=-1,
206
+ estimate_illumination=True
207
+ )
208
+ trimesh_mesh = trimesh_mesh[0]
209
+
210
+ # Export to GLB
211
+ temp_dir = tempfile.mkdtemp()
212
+ output_path = os.path.join(temp_dir, 'mesh.glb')
213
+
214
+ trimesh_mesh.export(output_path, file_type="glb", include_normals=True)
215
+
216
+ return output_path
217
+
218
+ except Exception as e:
219
+ print(f"Error during generation: {str(e)}")
220
+ import traceback
221
+ traceback.print_exc()
222
+ return None
223
+
224
+ # Create Gradio app using Blocks
225
+ with gr.Blocks() as demo:
226
+ gr.Markdown("This space is based on [Stable Point-Aware 3D](https://huggingface.co/spaces/stabilityai/stable-point-aware-3d) by Stability AI, [Text to 3D](https://huggingface.co/spaces/jbilcke-hf/text-to-3d) by jbilcke-hf.")
227
+
228
+ with gr.Row():
229
+ input_img = gr.Image(
230
+ type="pil", label="Input Image", sources="upload", image_mode="RGBA"
231
+ )
232
+
233
+ with gr.Row():
234
+ model_output = gr.Model3D(
235
+ label="Generated .GLB model",
236
+ clear_color=[0.0, 0.0, 0.0, 0.0],
237
+ )
238
+
239
+ # Event handler
240
+ input_img.upload(
241
+ fn=generate_and_process_3d,
242
+ inputs=[input_img],
243
+ outputs=[model_output],
244
+ api_name="generate"
245
+ )
246
+
247
+ if __name__ == "__main__":
248
  demo.queue().launch()