Spaces:
Sleeping
Sleeping
Create tsr/models/isosurface.py
Browse files- tsr/models/isosurface.py +52 -0
tsr/models/isosurface.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Optional, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torchmcubes import marching_cubes
|
7 |
+
|
8 |
+
|
9 |
+
class IsosurfaceHelper(nn.Module):
|
10 |
+
points_range: Tuple[float, float] = (0, 1)
|
11 |
+
|
12 |
+
@property
|
13 |
+
def grid_vertices(self) -> torch.FloatTensor:
|
14 |
+
raise NotImplementedError
|
15 |
+
|
16 |
+
|
17 |
+
class MarchingCubeHelper(IsosurfaceHelper):
|
18 |
+
def __init__(self, resolution: int) -> None:
|
19 |
+
super().__init__()
|
20 |
+
self.resolution = resolution
|
21 |
+
self.mc_func: Callable = marching_cubes
|
22 |
+
self._grid_vertices: Optional[torch.FloatTensor] = None
|
23 |
+
|
24 |
+
@property
|
25 |
+
def grid_vertices(self) -> torch.FloatTensor:
|
26 |
+
if self._grid_vertices is None:
|
27 |
+
# keep the vertices on CPU so that we can support very large resolution
|
28 |
+
x, y, z = (
|
29 |
+
torch.linspace(*self.points_range, self.resolution),
|
30 |
+
torch.linspace(*self.points_range, self.resolution),
|
31 |
+
torch.linspace(*self.points_range, self.resolution),
|
32 |
+
)
|
33 |
+
x, y, z = torch.meshgrid(x, y, z, indexing="ij")
|
34 |
+
verts = torch.cat(
|
35 |
+
[x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
|
36 |
+
).reshape(-1, 3)
|
37 |
+
self._grid_vertices = verts
|
38 |
+
return self._grid_vertices
|
39 |
+
|
40 |
+
def forward(
|
41 |
+
self,
|
42 |
+
level: torch.FloatTensor,
|
43 |
+
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
|
44 |
+
level = -level.view(self.resolution, self.resolution, self.resolution)
|
45 |
+
try:
|
46 |
+
v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
|
47 |
+
except AttributeError:
|
48 |
+
print("torchmcubes was not compiled with CUDA support, use CPU version instead.")
|
49 |
+
v_pos, t_pos_idx = self.mc_func(level.detach().cpu(), 0.0)
|
50 |
+
v_pos = v_pos[..., [2, 1, 0]]
|
51 |
+
v_pos = v_pos / (self.resolution - 1.0)
|
52 |
+
return v_pos.to(level.device), t_pos_idx.to(level.device)
|