# SemanticBoost/TADA/lib/common/mesh_utils.py
import torch
import kaolin as kal
import numpy as np
from pathlib import Path
import pymeshlab as pml
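# module-level default device shared by the helpers below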
if torch.cuda.is_available():
device = torch.device("cuda:0")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
def get_camera_from_view(elev, azim, r=3.0):
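    """Camera transform from spherical angles, treating `elev` as the polar
    angle from the +z axis and `azim` as the angle in the xy-plane; the
    camera sits at radius r and looks back through the origin."""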
x = r * torch.cos(azim) * torch.sin(elev)
y = r * torch.sin(azim) * torch.sin(elev)
z = r * torch.cos(elev)
# print(elev,azim,x,y,z)
pos = torch.tensor([x, y, z]).unsqueeze(0)
look_at = -pos
direction = torch.tensor([0.0, 1.0, 0.0]).unsqueeze(0)
camera_proj = kal.render.camera.generate_transformation_matrix(pos, look_at, direction)
return camera_proj
def get_camera_from_view2(elev, azim, r=3.0):
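    """y-up variant of get_camera_from_view: `elev` is measured from the
    xz-plane toward +y, `azim` sweeps around the y axis."""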
x = r * torch.cos(elev) * torch.cos(azim)
y = r * torch.sin(elev)
z = r * torch.cos(elev) * torch.sin(azim)
# print(elev,azim,x,y,z)
pos = torch.tensor([x, y, z]).unsqueeze(0)
look_at = -pos
direction = torch.tensor([0.0, 1.0, 0.0]).unsqueeze(0)
camera_proj = kal.render.camera.generate_transformation_matrix(pos, look_at, direction)
return camera_proj
def get_homogenous_coordinates(V):
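    """Append a column of ones: (N, 3) vertices -> (N, 4) homogeneous coordinates."""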
N, D = V.shape
bottom = torch.ones(N, device=device).unsqueeze(1)
return torch.cat([V, bottom], dim=1)
def apply_affine(verts, A):
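    """Apply a 3x4 affine matrix A to (N, 3) vertices, promoting A to 4x4 first."""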
verts = verts.to(device)
verts = get_homogenous_coordinates(verts)
A = torch.cat([A, torch.tensor([0.0, 0.0, 0.0, 1.0], device=device).unsqueeze(0)], dim=0)
transformed_verts = A @ verts.T
transformed_verts = transformed_verts[:-1]
return transformed_verts.T
def standardize_mesh(mesh):
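    """Center the mesh and scale it so vertex distances from the centroid have unit standard deviation."""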
verts = mesh.vertices
center = verts.mean(dim=0)
verts -= center
scale = torch.std(torch.norm(verts, p=2, dim=1))
verts /= scale
mesh.vertices = verts
return mesh
def normalize_mesh(mesh):
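    """Center the mesh and scale it to fit inside the unit sphere."""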
verts = mesh.vertices
# Compute center of bounding box
# center = torch.mean(torch.column_stack([torch.max(verts, dim=0)[0], torch.min(verts, dim=0)[0]]))
center = verts.mean(dim=0)
verts = verts - center
scale = torch.max(torch.norm(verts, p=2, dim=1))
verts = verts / scale
mesh.vertices = verts
return mesh
def get_texture_map_from_color(mesh, color, H=224, W=224):
num_faces = mesh.faces.shape[0]
texture_map = torch.zeros(1, H, W, 3).to(device)
texture_map[:, :, :] = color
return texture_map.permute(0, 3, 1, 2)
def get_face_attributes_from_color(mesh, color):
num_faces = mesh.faces.shape[0]
face_attributes = torch.zeros(1, num_faces, 3, 3).to(device)
face_attributes[:, :, :] = color
return face_attributes
def sample_bary(faces, vertices):
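    """Sample one point per face via random barycentric weights and append the
    new points to the vertex list."""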
num_faces = faces.shape[0]
num_vertices = vertices.shape[0]
    # random barycentric weights per face (rand, not randn, so the weights stay
    # in [0, 1] and sum to 1) TODO: improve sampling
    A = torch.rand(num_faces)
    B = torch.rand(num_faces) * (1 - A)
    C = 1 - (A + B)
    bary = torch.vstack([A, B, C]).to(device)
# compute xyz of new vertices and new uvs (if mesh has them)
new_vertices = torch.zeros(num_faces, 3).to(device)
new_uvs = torch.zeros(num_faces, 2).to(device)
face_verts = kal.ops.mesh.index_vertices_by_faces(vertices.unsqueeze(0), faces)
for f in range(num_faces):
new_vertices[f] = bary[:, f] @ face_verts[:, f]
new_vertices = torch.cat([vertices, new_vertices])
return new_vertices
def _add_vertices(mesh):
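    """1-to-3 subdivision: insert one interior point per face and replace the
    face with three faces sharing that point, extending uvs/normals to match."""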
faces = torch.as_tensor(mesh.faces)
vertices = torch.as_tensor(mesh.vertices)
num_faces = faces.shape[0]
num_vertices = vertices.shape[0]
    # random barycentric weights per face (rand, not randn, so the weights stay
    # in [0, 1] and sum to 1) TODO: improve sampling
    A = torch.rand(num_faces)
    B = torch.rand(num_faces) * (1 - A)
    C = 1 - (A + B)
    bary = torch.vstack([A, B, C]).to(device)
# compute xyz of new vertices and new uvs (if mesh has them)
new_vertices = torch.zeros(num_faces, 3).to(device)
new_uvs = torch.zeros(num_faces, 2).to(device)
face_verts = kal.ops.mesh.index_vertices_by_faces(vertices.unsqueeze(0), faces)
face_uvs = mesh.face_uvs
for f in range(num_faces):
new_vertices[f] = bary[:, f] @ face_verts[:, f]
if face_uvs is not None:
new_uvs[f] = bary[:, f] @ face_uvs[:, f]
# update face and face_uvs of mesh
new_vertices = torch.cat([vertices, new_vertices])
new_faces = []
new_face_uvs = []
new_vertex_normals = []
for i in range(num_faces):
old_face = faces[i]
a, b, c = old_face[0], old_face[1], old_face[2]
d = num_vertices + i
new_faces.append(torch.tensor([a, b, d]).to(device))
new_faces.append(torch.tensor([a, d, c]).to(device))
new_faces.append(torch.tensor([d, b, c]).to(device))
        if face_uvs is not None:
            old_face_uvs = face_uvs[0, i]
            # separate names so the vertex indices a, b, c above are not shadowed
            # (they are still needed by the normal computation below)
            ua, ub, uc = old_face_uvs[0], old_face_uvs[1], old_face_uvs[2]
            ud = new_uvs[i]
            new_face_uvs.append(torch.vstack([ua, ub, ud]))
            new_face_uvs.append(torch.vstack([ua, ud, uc]))
            new_face_uvs.append(torch.vstack([ud, ub, uc]))
if mesh.face_normals is not None:
new_vertex_normals.append(mesh.face_normals[i])
else:
e1 = vertices[b] - vertices[a]
e2 = vertices[c] - vertices[a]
norm = torch.cross(e1, e2)
norm /= torch.norm(norm)
# Double check sign against existing vertex normals
if torch.dot(norm, mesh.vertex_normals[a]) < 0:
norm = -norm
new_vertex_normals.append(norm)
vertex_normals = torch.cat([mesh.vertex_normals, torch.stack(new_vertex_normals)])
if face_uvs is not None:
new_face_uvs = torch.vstack(new_face_uvs).unsqueeze(0).view(1, 3 * num_faces, 3, 2)
new_faces = torch.vstack(new_faces)
return new_vertices, new_faces, vertex_normals, new_face_uvs
def get_rgb_per_vertex(vertices, faces, face_rgbs):
    num_vertex = vertices.shape[0]
    num_faces = faces.shape[0]
    vertex_color = torch.zeros(num_vertex, 3)
    for v in range(num_vertex):
        for f in range(num_faces):
            face = faces[f]
            if v in face:
                vertex_color[v] = face_rgbs[f]
    return vertex_color
def get_barycentric(p, faces):
# faces num_points x 3 x 3
# p num_points x 3
# source: https://gamedev.stackexchange.com/questions/23743/whats-the-most-efficient-way-to-find-barycentric-coordinates
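    # Writing p - a = v*(b - a) + w*(c - a) and dotting both sides with
    # v0 = b - a and v1 = c - a gives a 2x2 system solved below by Cramer's rule:
    #   [d00 d01] [v]   [d20]
    #   [d01 d11] [w] = [d21]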
a, b, c = faces[:, 0], faces[:, 1], faces[:, 2]
v0, v1, v2 = b - a, c - a, p - a
d00 = torch.sum(v0 * v0, dim=1)
d01 = torch.sum(v0 * v1, dim=1)
d11 = torch.sum(v1 * v1, dim=1)
d20 = torch.sum(v2 * v0, dim=1)
d21 = torch.sum(v2 * v1, dim=1)
denom = d00 * d11 - d01 * d01
v = (d11 * d20 - d01 * d21) / denom
w = (d00 * d21 - d01 * d20) / denom
u = 1 - (w + v)
return torch.vstack([u, v, w]).T
def get_uv_assignment(num_faces):
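    """Pack each face into its own right triangle on an M x M grid
    (M = ceil(sqrt(num_faces))), then rescale the atlas to roughly [-1, 1]^2."""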
M = int(np.ceil(np.sqrt(num_faces)))
uv_map = torch.zeros(1, num_faces, 3, 2).to(device)
px, py = 0, 0
count = 0
for i in range(M):
px = 0
for j in range(M):
uv_map[:, count] = torch.tensor([[px, py],
[px + 1, py],
[px + 1, py + 1]])
px += 2
count += 1
if count >= num_faces:
hw = torch.max(uv_map.view(-1, 2), dim=0)[0]
uv_map = (uv_map - hw / 2.0) / (hw / 2)
return uv_map
py += 2
def get_texture_visual(res, nt, mesh):
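    """Rasterize a res x res texture by querying the color field `nt` (xyz -> rgb)
    at the 3D point corresponding to every texel inside each face's UV triangle."""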
faces_vt = kal.ops.mesh.index_vertices_by_faces(mesh.vertices.unsqueeze(0), mesh.faces).squeeze(0)
    # so as not to include the endpoint, generate res+1 points and take the first res
uv = torch.cartesian_prod(torch.linspace(-1, 1, res + 1)[:-1], torch.linspace(-1, 1, res + 1))[:-1].to(device)
image = torch.zeros(res, res, 3).to(device)
# image[:,:,:] = torch.tensor([0.0,1.0,0.0]).to(device)
image = image.permute(2, 0, 1)
num_faces = mesh.faces.shape[0]
uv_map = get_uv_assignment(num_faces).squeeze(0)
zero = torch.tensor([0.0, 0.0, 0.0]).to(device)
one = torch.tensor([1.0, 1.0, 1.0]).to(device)
for face in range(num_faces):
bary = get_barycentric(uv, uv_map[face].repeat(len(uv), 1, 1))
maskA = torch.logical_and(bary[:, 0] >= 0.0, bary[:, 0] <= 1.0)
maskB = torch.logical_and(bary[:, 1] >= 0.0, bary[:, 1] <= 1.0)
maskC = torch.logical_and(bary[:, 2] >= 0.0, bary[:, 2] <= 1.0)
mask = torch.logical_and(maskA, maskB)
mask = torch.logical_and(maskC, mask)
inside_triangle = bary[mask]
inside_triangle_uv = inside_triangle @ uv_map[face]
inside_triangle_xyz = inside_triangle @ faces_vt[face]
inside_triangle_rgb = nt(inside_triangle_xyz)
pixels = (inside_triangle_uv + 1.0) / 2.0
pixels = pixels * res
pixels = torch.floor(pixels).type(torch.int64)
image[:, pixels[:, 0], pixels[:, 1]] = inside_triangle_rgb.T
return image
# Get rotation matrix about vector through origin
def getRotMat(axis, theta):
"""
axis: np.array, normalized vector
theta: radians
"""
import math
axis = axis / np.linalg.norm(axis)
cprod = np.array([[0, -axis[2], axis[1]],
[axis[2], 0, -axis[0]],
[-axis[1], axis[0], 0]])
rot = math.cos(theta) * np.identity(3) + math.sin(theta) * cprod + \
(1 - math.cos(theta)) * np.outer(axis, axis)
return rot
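# e.g. getRotMat(np.array([0.0, 1.0, 0.0]), np.pi / 2) rotates 90 degrees about +y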
# Map vertices and subset of faces to 0-indexed vertices, keeping only relevant vertices
def trimMesh(vertices, faces):
unique_v = np.sort(np.unique(faces.flatten()))
v_val = np.arange(len(unique_v))
v_map = dict(zip(unique_v, v_val))
new_faces = np.array([v_map[i] for i in faces.flatten()]).reshape(faces.shape[0], faces.shape[1])
new_v = vertices[unique_v]
return new_v, new_faces
# ================== VISUALIZATION =======================
# Back out camera parameters from view transform matrix
def extract_from_gl_viewmat(gl_mat):
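    """Recover the camera position and a look-at target from an OpenGL view
    matrix, whose rotation rows are (side, up, -forward) and whose last column
    holds -R @ camera_location."""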
gl_mat = gl_mat.reshape(4, 4)
s = gl_mat[0, :3]
u = gl_mat[1, :3]
f = -1 * gl_mat[2, :3]
coord = gl_mat[:3, 3] # first 3 entries of the last column
camera_location = np.array([-s, -u, f]).T @ coord
target = camera_location + f * 10 # any scale
return camera_location, target
def psScreenshot(vertices, faces, axis, angles, save_path, name="mesh", frame_folder="frames", scalars=None,
colors=None,
defined_on="faces", highlight_faces=None, highlight_color=[1, 0, 0], highlight_radius=None,
cmap=None, sminmax=None, cpos=None, clook=None, save_video=False, save_base=False,
ground_plane="tile_reflection", debug=False, edge_color=[0, 0, 0], edge_width=1, material=None):
import polyscope as ps
ps.init()
# Set camera to look at same fixed position in centroid of original mesh
# center = np.mean(vertices, axis = 0)
# pos = center + np.array([0, 0, 3])
# ps.look_at(pos, center)
ps.set_ground_plane_mode(ground_plane)
    frame_path = f"{save_path}/{frame_folder}"
    Path(frame_path).mkdir(parents=True, exist_ok=True)  # create the folder before any screenshot is written
    if save_base:
        ps_mesh = ps.register_surface_mesh("mesh", vertices, faces, enabled=True,
                                           edge_color=edge_color, edge_width=edge_width, material=material)
        ps.screenshot(f"{frame_path}/{name}.png")
        ps.remove_all_structures()
# Convert 2D to 3D by appending Z-axis
if vertices.shape[1] == 2:
vertices = np.concatenate((vertices, np.zeros((len(vertices), 1))), axis=1)
for i in range(len(angles)):
rot = getRotMat(axis, angles[i])
rot_verts = np.transpose(rot @ np.transpose(vertices))
ps_mesh = ps.register_surface_mesh("mesh", rot_verts, faces, enabled=True,
edge_color=edge_color, edge_width=edge_width, material=material)
if scalars is not None:
ps_mesh.add_scalar_quantity(f"scalar", scalars, defined_on=defined_on,
cmap=cmap, enabled=True, vminmax=sminmax)
if colors is not None:
ps_mesh.add_color_quantity(f"color", colors, defined_on=defined_on,
enabled=True)
if highlight_faces is not None:
# Create curve to highlight faces
curve_v, new_f = trimMesh(rot_verts, faces[highlight_faces, :])
curve_edges = []
for face in new_f:
curve_edges.extend(
[[face[0], face[1]], [face[1], face[2]], [face[2], face[0]]])
curve_edges = np.array(curve_edges)
ps_curve = ps.register_curve_network("curve", curve_v, curve_edges, color=highlight_color,
radius=highlight_radius)
if cpos is None or clook is None:
ps.reset_camera_to_home_view()
else:
ps.look_at(cpos, clook)
        if debug:
ps.show()
ps.screenshot(f"{frame_path}/{name}_{i}.png")
ps.remove_all_structures()
    if save_video:
import glob
from PIL import Image
fp_in = f"{frame_path}/{name}_*.png"
fp_out = f"{save_path}/{name}.gif"
img, *imgs = [Image.open(f) for f in sorted(glob.glob(fp_in))]
img.save(fp=fp_out, format='GIF', append_images=imgs,
save_all=True, duration=200, loop=0)
# ================== POSITIONAL ENCODERS =============================
class FourierFeatureTransform(torch.nn.Module):
"""
An implementation of Gaussian Fourier feature mapping.
"Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains":
https://arxiv.org/abs/2006.10739
https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html
Given an input of size [batches, num_input_channels, width, height],
returns a tensor of size [batches, mapping_size*2, width, height].
"""
def __init__(self, num_input_channels, mapping_size=256, scale=10, exclude=0):
super().__init__()
self._num_input_channels = num_input_channels
self._mapping_size = mapping_size
self.exclude = exclude
        B = torch.randn((num_input_channels, mapping_size)) * scale
        B_sort = sorted(B, key=lambda x: torch.norm(x, p=2))
        self._B = torch.stack(B_sort)  # rows sorted by norm (for SAPE-style progressive encoding)
    def forward(self, x):
        batches, channels = x.shape
        assert channels == self._num_input_channels, \
            "Expected input to have {} channels (got {} channels)".format(self._num_input_channels, channels)
        res = x @ self._B.to(x.device)
        res = 2 * np.pi * res
        # concatenate the raw input with its sine/cosine Fourier features
        return torch.cat([x, torch.sin(res), torch.cos(res)], dim=1)
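
# Example usage (hypothetical shapes): encode 3D points into Fourier features.
#   enc = FourierFeatureTransform(3, mapping_size=256, scale=10)
#   feats = enc(torch.rand(1024, 3))  # -> (1024, 3 + 2 * 256)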
def poisson_mesh_reconstruction(points, normals=None):
# points/normals: [N, 3] np.ndarray
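    # note: the draw_geometries() calls below open blocking visualization windows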
import open3d as o3d
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points)
# outlier removal
pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)
# normals
if normals is None:
pcd.estimate_normals()
else:
pcd.normals = o3d.utility.Vector3dVector(normals[ind])
# visualize
o3d.visualization.draw_geometries([pcd], point_show_normal=False)
mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=9)
# vertices_to_remove = densities < np.quantile(densities, 0.1)
# mesh.remove_vertices_by_mask(vertices_to_remove)
# visualize
o3d.visualization.draw_geometries([mesh])
vertices = np.asarray(mesh.vertices)
triangles = np.asarray(mesh.triangles)
print(f'[INFO] poisson mesh reconstruction: {points.shape} --> V: {vertices.shape} / F:{triangles.shape}')
return vertices, triangles
def decimate_mesh(verts, faces, target, backend='pymeshlab', remesh=False, optimalplacement=True):
    # optimalplacement: default True; for flat meshes it must be False to prevent spike artifacts.
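    # target: desired number of faces after decimation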
_ori_vert_shape = verts.shape
_ori_face_shape = faces.shape
if backend == 'pyfqmr':
import pyfqmr
solver = pyfqmr.Simplify()
solver.setMesh(verts, faces)
solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)
verts, faces, normals = solver.getMesh()
else:
m = pml.Mesh(verts, faces)
ms = pml.MeshSet()
ms.add_mesh(m, 'mesh') # will copy!
# filters
# ms.meshing_decimation_clustering(threshold=pml.Percentage(1))
ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int(target), optimalplacement=optimalplacement)
if remesh:
# ms.apply_coord_taubin_smoothing()
ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1))
# extract mesh
m = ms.current_mesh()
verts = m.vertex_matrix()
faces = m.face_matrix()
print(f'[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')
return verts, faces
def clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, remesh=False):
    # verts: [N, 3]
    # faces: [M, 3]
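    # v_pct: merge-close-vertices threshold (percent of bbox diagonal);
    # min_f / min_d: drop connected components below this face count / diameter percentage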
_ori_vert_shape = verts.shape
_ori_face_shape = faces.shape
m = pml.Mesh(verts, faces)
ms = pml.MeshSet()
ms.add_mesh(m, 'mesh') # will copy!
# filters
    ms.meshing_remove_unreferenced_vertices()  # vertices not referenced by any face
    if v_pct > 0:
        ms.meshing_merge_close_vertices(threshold=pml.Percentage(v_pct))  # threshold is v_pct% of bounding box diagonal
ms.meshing_remove_duplicate_faces() # faces defined by the same verts
ms.meshing_remove_null_faces() # faces with area == 0
if min_d > 0:
ms.meshing_remove_connected_component_by_diameter(mincomponentdiag=pml.Percentage(min_d))
if min_f > 0:
ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)
if repair:
# ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)
ms.meshing_repair_non_manifold_edges(method=0)
ms.meshing_repair_non_manifold_vertices(vertdispratio=0)
if remesh:
# ms.apply_coord_taubin_smoothing()
ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1))
# extract mesh
m = ms.current_mesh()
verts = m.vertex_matrix()
faces = m.face_matrix()
print(f'[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')
return verts, faces
def laplace_regularizer_const(v_pos, t_pos_idx):
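    """Uniform (umbrella) Laplacian regularizer: accumulate neighbor differences
    per vertex over all incident triangles, average, and penalize the squared norm."""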
term = torch.zeros_like(v_pos)
norm = torch.zeros_like(v_pos[..., 0:1])
v0 = v_pos[t_pos_idx[:, 0], :]
v1 = v_pos[t_pos_idx[:, 1], :]
v2 = v_pos[t_pos_idx[:, 2], :]
term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1, 3), (v1 - v0) + (v2 - v0))
term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1, 3), (v0 - v1) + (v2 - v1))
term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1, 3), (v0 - v2) + (v1 - v2))
two = torch.ones_like(v0) * 2.0
norm.scatter_add_(0, t_pos_idx[:, 0:1], two)
norm.scatter_add_(0, t_pos_idx[:, 1:2], two)
norm.scatter_add_(0, t_pos_idx[:, 2:3], two)
term = term / torch.clamp(norm, min=1.0)
return torch.mean(term ** 2)
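
# Example usage (hypothetical tensors): smoothness loss for an optimized mesh.
#   loss = laplace_regularizer_const(verts, faces_idx)
#   with verts: (V, 3) float tensor, faces_idx: (F, 3) long tensor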