aca2024 committed on
Commit
045e960
·
1 Parent(s): e28f1ec

update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -215
app.py CHANGED
@@ -12,15 +12,9 @@ from spann3r.datasets import Demo
12
  from torch.utils.data import DataLoader
13
  import trimesh
14
  from scipy.spatial.transform import Rotation
15
- from transformers import AutoModelForImageSegmentation
16
- from torchvision import transforms
17
- from PIL import Image
18
- import open3d as o3d
19
- from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
20
-
21
 
22
  # Default values
23
- DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
24
  DEFAULT_DUST3R_PATH = 'https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth'
25
  DEFAULT_DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
26
 
@@ -29,45 +23,15 @@ OPENGL = np.array([[1, 0, 0, 0],
29
  [0, 0, -1, 0],
30
  [0, 0, 0, 1]])
31
 
32
- def export_geometry(geometry, as_pointcloud=False):
33
- if as_pointcloud:
34
- if not isinstance(geometry, o3d.geometry.PointCloud):
35
- raise ValueError("Expected an Open3D PointCloud object when as_pointcloud is True")
36
- output_path = tempfile.mktemp(suffix='.ply')
37
- else:
38
- if not isinstance(geometry, o3d.geometry.TriangleMesh):
39
- raise ValueError("Expected an Open3D TriangleMesh object when as_pointcloud is False")
40
- output_path = tempfile.mktemp(suffix='.obj')
41
-
42
- # Apply rotation
43
- rot = np.eye(4)
44
- rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
45
- transform = np.linalg.inv(OPENGL @ rot)
46
- geometry.transform(transform)
47
-
48
- # Export the geometry
49
- if as_pointcloud:
50
- o3d.io.write_point_cloud(output_path, geometry, write_ascii=False, compressed=True)
51
- else:
52
- o3d.io.write_triangle_mesh(output_path, geometry, write_ascii=False, compressed=True)
53
-
54
- return output_path
55
-
56
-
57
- def extract_frames(video_path: str, duration: float = 20.0, fps: float = 3.0) -> str:
58
  temp_dir = tempfile.mkdtemp()
59
  output_path = os.path.join(temp_dir, "%03d.jpg")
60
-
61
- filter_complex = f"select='if(lt(t,{duration}),1,0)',fps={fps}"
62
-
63
  command = [
64
  "ffmpeg",
65
  "-i", video_path,
66
- "-vf", filter_complex,
67
- "-vsync", "0",
68
  output_path
69
  ]
70
-
71
  subprocess.run(command, check=True)
72
  return temp_dir
73
 
@@ -141,42 +105,9 @@ def pts3d_to_trimesh(img, pts3d, valid=None):
141
  return dict(vertices=vertices, face_colors=face_colors, faces=faces)
142
 
143
  model = load_model(DEFAULT_CKPT_PATH, DEFAULT_DEVICE)
144
- birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True)
145
- birefnet.to(DEFAULT_DEVICE)
146
- birefnet.eval()
147
 
148
- def extract_object(birefnet, image):
149
- # Data settings
150
- image_size = (1024, 1024)
151
- transform_image = transforms.Compose([
152
- transforms.Resize(image_size),
153
- transforms.ToTensor(),
154
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
155
- ])
156
-
157
- input_images = transform_image(image).unsqueeze(0).to(DEFAULT_DEVICE)
158
-
159
- # Prediction
160
- with torch.no_grad():
161
- preds = birefnet(input_images)[-1].sigmoid().cpu()
162
- pred = preds[0].squeeze()
163
- pred_pil = transforms.ToPILImage()(pred)
164
- mask = pred_pil.resize(image.size)
165
- return mask
166
-
167
- def generate_mask(image: np.ndarray):
168
- # Convert numpy array to PIL Image
169
- pil_image = Image.fromarray((image * 255).astype(np.uint8))
170
-
171
- # Extract object and get mask
172
- mask = extract_object(birefnet, pil_image)
173
-
174
- # Convert mask to numpy array
175
- mask_np = np.array(mask) / 255.0
176
- return mask_np
177
  @torch.no_grad()
178
- def reconstruct(video_path, conf_thresh, kf_every,
179
- as_pointcloud=False, remove_background=False, refine=False):
180
  # Extract frames from video
181
  demo_path = extract_frames(video_path)
182
 
@@ -197,156 +128,67 @@ def reconstruct(video_path, conf_thresh, kf_every,
197
  fps = len(batch) / (end - start)
198
  print(f'Finished reconstruction for {demo_name}, FPS: {fps:.2f}')
199
 
200
- try:
201
- # Process results
202
- pcds = []
203
- for j, view in enumerate(batch):
204
- image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
205
- image = (image + 1) / 2
206
- pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
207
- pts_normal = pts2normal(preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'][0]).cpu().numpy()
208
- conf = preds[j]['conf'][0].cpu().data.numpy()
209
- conf_sig = (conf - 1) / conf
210
- if remove_background:
211
- mask = generate_mask(image)
212
- else:
213
- mask = np.ones_like(conf)
214
-
215
- combined_mask = (conf_sig > conf_thresh) & (mask > 0.5)
216
-
217
- pcd = o3d.geometry.PointCloud()
218
- pcd.points = o3d.utility.Vector3dVector(pts[combined_mask])
219
- pcd.colors = o3d.utility.Vector3dVector(image[combined_mask])
220
- pcd.normals = o3d.utility.Vector3dVector(pts_normal[combined_mask])
221
- pcds.append(pcd)
222
- except Exception as e:
223
- print(repr(e))
224
-
225
- print(f'Finished Process results {demo_name}')
226
 
227
- pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
228
 
229
  if as_pointcloud:
230
- o3d_geometry = pcd_combined
 
 
 
 
231
  else:
232
- o3d_geometry = point2mesh(pcd_combined)
233
-
234
- # Create coarse result
235
-
236
- print(f'Create coarse result {demo_name}')
237
-
238
- coarse_output_path = export_geometry(o3d_geometry, as_pointcloud)
239
 
240
- print(f'Finished Create coarse result {demo_name}')
241
-
242
- yield coarse_output_path, None
 
 
 
 
243
 
244
- if refine:
245
- # Perform global optimization
246
- print("Performing global registration...")
247
- transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.001)
248
-
249
- if as_pointcloud:
250
- o3d_geometry = transformed_pcds
251
- else:
252
- o3d_geometry = point2mesh(transformed_pcds)
253
-
254
- # Create coarse result
255
- refined_output_path = export_geometry(o3d_geometry, as_pointcloud)
256
-
257
- print(f'Perform global optimization {demo_name}')
258
-
259
- yield coarse_output_path, refined_output_path
260
-
261
  # Clean up temporary directory
262
  os.system(f"rm -rf {demo_path}")
263
-
264
- # Update the Gradio interface with improved layout
265
- with gr.Blocks(
266
- title="StableSpann3r: Making Spann3r stable with Odometry Backend",
267
- css="""
268
- #download {
269
- height: 118px;
270
- }
271
- .slider .inner {
272
- width: 5px;
273
- background: #FFF;
274
- }
275
- .viewport {
276
- aspect-ratio: 4/3;
277
- }
278
- .tabs button.selected {
279
- font-size: 20px !important;
280
- color: crimson !important;
281
- }
282
- h1 {
283
- text-align: center;
284
- display: block;
285
- }
286
- h2 {
287
- text-align: center;
288
- display: block;
289
- }
290
- h3 {
291
- text-align: center;
292
- display: block;
293
- }
294
- .md_feedback li {
295
- margin-bottom: 0px !important;
296
- }
297
- """,
298
- head="""
299
- <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
300
- <script>
301
- window.dataLayer = window.dataLayer || [];
302
- function gtag() {dataLayer.push(arguments);}
303
- gtag('js', new Date());
304
- gtag('config', 'G-1FWSVCGZTG');
305
- </script>
306
- """,
307
- ) as iface:
308
- gr.Markdown(
309
- """
310
- # StableSpann3r: Making Spann3r stable with Odometry Backend
311
- <p align="center">
312
- <a title="Website" href="https://stable-x.github.io/StableSpann3r/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
313
- <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
314
- </a>
315
- <a title="arXiv" href="https://arxiv.org/abs/XXXX.XXXXX" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
316
- <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
317
- </a>
318
- <a title="Github" href="https://github.com/Stable-X/StableSpann3r" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
319
- <img src="https://img.shields.io/github/stars/Stable-X/StableSpann3r?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
320
- </a>
321
- <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
322
- <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
323
- </a>
324
- </p>
325
- """
326
- )
327
- with gr.Row():
328
- with gr.Column(scale=1):
329
- video_input = gr.Video(label="Input Video")
330
- with gr.Row():
331
- conf_thresh = gr.Slider(0, 1, value=1e-3, label="Confidence Threshold")
332
- kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
333
- with gr.Row():
334
- remove_background = gr.Checkbox(label="Remove Background", value=False)
335
- refine = gr.Checkbox(label="Enable Backend", value=False)
336
- as_pointcloud = gr.Checkbox(label="As Pointcloud", value=False)
337
- reconstruct_btn = gr.Button("Reconstruct")
338
-
339
- with gr.Column(scale=2):
340
- with gr.Tab("Coarse Model"):
341
- coarse_model = gr.Model3D(label="Coarse 3D Model", display_mode="solid", clear_color=[0.0, 0.0, 0.0, 0.0])
342
- with gr.Tab("Refined Model"):
343
- refined_model = gr.Model3D(label="Refined 3D Model", display_mode="solid", clear_color=[0.0, 0.0, 0.0, 0.0])
344
 
345
- reconstruct_btn.click(
346
- fn=reconstruct,
347
- inputs=[video_input, conf_thresh, kf_every, as_pointcloud, remove_background, refine],
348
- outputs=[coarse_model, refined_model]
349
- )
 
 
 
 
 
 
 
 
 
 
 
350
 
351
  if __name__ == "__main__":
352
- iface.launch(server_name="0.0.0.0")
 
12
  from torch.utils.data import DataLoader
13
  import trimesh
14
  from scipy.spatial.transform import Rotation
 
 
 
 
 
 
15
 
16
  # Default values
17
+ DEFAULT_CKPT_PATH = 'https://huggingface.co/spaces/Stable-X/StableSpann3R/resolve/main/checkpoints/spann3r.pth'
18
  DEFAULT_DUST3R_PATH = 'https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth'
19
  DEFAULT_DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
20
 
 
23
  [0, 0, -1, 0],
24
  [0, 0, 0, 1]])
25
 
26
def extract_frames(video_path: str) -> str:
    """Extract still frames from a video into a fresh temporary directory.

    Invokes ``ffmpeg`` with the ``fps=1`` video filter and writes the frames
    as JPEG files named by the ``%03d.jpg`` pattern inside a directory created
    with ``tempfile.mkdtemp()``.

    Args:
        video_path: Path of the input video, passed to ``ffmpeg -i``.

    Returns:
        Path of the temporary directory containing the extracted frames.
        The caller owns the directory and is responsible for deleting it.

    Raises:
        subprocess.CalledProcessError: If ``ffmpeg`` exits with a non-zero
            status.
        OSError: If the ``ffmpeg`` executable cannot be started (for example
            when it is not installed).
    """
    # Local import so this function does not depend on a file-level import
    # that may not exist.
    import shutil

    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "%03d.jpg")
    command = [
        "ffmpeg",
        "-i", video_path,
        "-vf", "fps=1",
        output_path,
    ]
    try:
        subprocess.run(command, check=True)
    except (OSError, subprocess.SubprocessError):
        # Do not leak the temporary directory when extraction fails; the
        # caller never receives its path, so nobody else could clean it up.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir
37
 
 
105
  return dict(vertices=vertices, face_colors=face_colors, faces=faces)
106
 
107
  model = load_model(DEFAULT_CKPT_PATH, DEFAULT_DEVICE)
 
 
 
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  @torch.no_grad()
110
+ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False):
 
111
  # Extract frames from video
112
  demo_path = extract_frames(video_path)
113
 
 
128
  fps = len(batch) / (end - start)
129
  print(f'Finished reconstruction for {demo_name}, FPS: {fps:.2f}')
130
 
131
+ # Process results
132
+ pts_all, images_all, conf_all = [], [], []
133
+ for j, view in enumerate(batch):
134
+ image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
135
+ pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
136
+ conf = preds[j]['conf'][0].cpu().data.numpy()
137
+
138
+ images_all.append((image[None, ...] + 1.0)/2.0)
139
+ pts_all.append(pts[None, ...])
140
+ conf_all.append(conf[None, ...])
141
+
142
+ images_all = np.concatenate(images_all, axis=0)
143
+ pts_all = np.concatenate(pts_all, axis=0) * 10
144
+ conf_all = np.concatenate(conf_all, axis=0)
145
+
146
+ # Create point cloud or mesh
147
+ conf_sig_all = (conf_all-1) / conf_all
148
+ mask = conf_sig_all > conf_thresh
 
 
 
 
 
 
 
 
149
 
150
+ scene = trimesh.Scene()
151
 
152
  if as_pointcloud:
153
+ pcd = trimesh.PointCloud(
154
+ vertices=pts_all[mask].reshape(-1, 3),
155
+ colors=images_all[mask].reshape(-1, 3)
156
+ )
157
+ scene.add_geometry(pcd)
158
  else:
159
+ meshes = []
160
+ for i in range(len(images_all)):
161
+ meshes.append(pts3d_to_trimesh(images_all[i], pts_all[i], mask[i]))
162
+ mesh = trimesh.Trimesh(**cat_meshes(meshes))
163
+ scene.add_geometry(mesh)
 
 
164
 
165
+ rot = np.eye(4)
166
+ rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
167
+ scene.apply_transform(np.linalg.inv(OPENGL @ rot))
168
+
169
+ # Save the scene as GLB
170
+ output_path = tempfile.mktemp(suffix='.glb')
171
+ scene.export(output_path)
172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  # Clean up temporary directory
174
  os.system(f"rm -rf {demo_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
+ return output_path, f"Reconstruction completed. FPS: {fps:.2f}"
177
+
178
# Input widgets, listed in the positional order expected by `reconstruct`:
# video path, confidence threshold, keyframe interval, point-cloud toggle.
_input_components = [
    gr.Video(label="Input Video"),
    gr.Slider(0, 1, value=1e-3, label="Confidence Threshold"),
    gr.Slider(1, 30, step=1, value=5, label="Keyframe Interval"),
    gr.Checkbox(label="As Pointcloud", value=False),
]

# Output widgets: the exported model file and a status message.
_output_components = [
    gr.Model3D(label="3D Model (GLB)", display_mode="solid"),
    gr.Textbox(label="Status"),
]

# Gradio interface wiring the widgets above to `reconstruct`.
iface = gr.Interface(
    fn=reconstruct,
    inputs=_input_components,
    outputs=_output_components,
    title="3D Reconstruction with Spatial Memory",
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()