import math
from pathlib import Path

import kaolin as kal
import numpy as np
import pymeshlab as pml
import torch

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")


def get_camera_from_view(elev, azim, r=3.0):
    # Spherical camera position with elevation measured from the +z pole.
    x = r * torch.cos(azim) * torch.sin(elev)
    y = r * torch.sin(azim) * torch.sin(elev)
    z = r * torch.cos(elev)
    pos = torch.tensor([x, y, z]).unsqueeze(0)
    look_at = -pos
    direction = torch.tensor([0.0, 1.0, 0.0]).unsqueeze(0)

    camera_proj = kal.render.camera.generate_transformation_matrix(pos, look_at, direction)
    return camera_proj


def get_camera_from_view2(elev, azim, r=3.0):
    # Spherical camera position with elevation measured from the xz-plane.
    x = r * torch.cos(elev) * torch.cos(azim)
    y = r * torch.sin(elev)
    z = r * torch.cos(elev) * torch.sin(azim)
    pos = torch.tensor([x, y, z]).unsqueeze(0)
    look_at = -pos
    direction = torch.tensor([0.0, 1.0, 0.0]).unsqueeze(0)

    camera_proj = kal.render.camera.generate_transformation_matrix(pos, look_at, direction)
    return camera_proj


def get_homogenous_coordinates(V):
    # Append a column of ones: [N, D] -> [N, D + 1].
    N, D = V.shape
    bottom = torch.ones(N, device=device).unsqueeze(1)
    return torch.cat([V, bottom], dim=1)


def apply_affine(verts, A):
    # Apply a 3x4 affine matrix A to [N, 3] vertices via homogeneous coordinates.
    verts = verts.to(device)
    verts = get_homogenous_coordinates(verts)
    A = torch.cat([A, torch.tensor([0.0, 0.0, 0.0, 1.0], device=device).unsqueeze(0)], dim=0)
    transformed_verts = A @ verts.T
    transformed_verts = transformed_verts[:-1]
    return transformed_verts.T


def standardize_mesh(mesh):
    # Center vertices and scale so vertex norms have unit standard deviation.
    verts = mesh.vertices
    center = verts.mean(dim=0)
    verts -= center
    scale = torch.std(torch.norm(verts, p=2, dim=1))
    verts /= scale
    mesh.vertices = verts
    return mesh


def normalize_mesh(mesh):
    # Center vertices and scale so the mesh fits inside the unit sphere.
    verts = mesh.vertices

    # Compute center of bounding box
    # center = torch.mean(torch.column_stack([torch.max(verts, dim=0)[0], torch.min(verts, dim=0)[0]]))
    center = verts.mean(dim=0)
    verts = verts - center
    scale = torch.max(torch.norm(verts, p=2, dim=1))
    verts = verts / scale
    mesh.vertices = verts
    return mesh


def get_texture_map_from_color(mesh, color, H=224, W=224):
    # Build a constant-color texture map of shape [1, 3, H, W].
    texture_map = torch.zeros(1, H, W, 3).to(device)
    texture_map[:, :, :] = color
    return texture_map.permute(0, 3, 1, 2)


def get_face_attributes_from_color(mesh, color):
    # Build constant per-face vertex colors of shape [1, F, 3, 3].
    num_faces = mesh.faces.shape[0]
    face_attributes = torch.zeros(1, num_faces, 3, 3).to(device)
    face_attributes[:, :, :] = color
    return face_attributes


def sample_bary(faces, vertices):
    num_faces = faces.shape[0]

    # Get random barycentric coordinates for each face. TODO: improve sampling
    A = torch.randn(num_faces)
    B = torch.randn(num_faces) * (1 - A)
    C = 1 - (A + B)
    bary = torch.vstack([A, B, C]).to(device)

    # Compute xyz of the new vertex sampled on each face
    new_vertices = torch.zeros(num_faces, 3).to(device)
    face_verts = kal.ops.mesh.index_vertices_by_faces(vertices.unsqueeze(0), faces)
    for f in range(num_faces):
        new_vertices[f] = bary[:, f] @ face_verts[:, f]
    new_vertices = torch.cat([vertices, new_vertices])
    return new_vertices
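
# Usage sketch (illustrative, not part of the original pipeline): sample a view
# transform for a camera 30 degrees above the horizon and 45 degrees around the
# object. The angles and radius are arbitrary example values; elev/azim must be
# tensors, since the functions above call torch.cos/torch.sin on them.
def _example_camera_transform():
    elev = torch.tensor(np.radians(30.0))
    azim = torch.tensor(np.radians(45.0))
    camera_transform = get_camera_from_view2(elev, azim, r=2.5)
    return camera_transform  # batched view transform consumed by kaolin rendering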

def _add_vertices(mesh):
    # Subdivide: sample one new vertex per face, then split each face into three.
    faces = torch.as_tensor(mesh.faces)
    vertices = torch.as_tensor(mesh.vertices)
    num_faces = faces.shape[0]
    num_vertices = vertices.shape[0]

    # Get random barycentric coordinates for each face. TODO: improve sampling
    A = torch.randn(num_faces)
    B = torch.randn(num_faces) * (1 - A)
    C = 1 - (A + B)
    bary = torch.vstack([A, B, C]).to(device)

    # Compute xyz of new vertices and new uvs (if mesh has them)
    new_vertices = torch.zeros(num_faces, 3).to(device)
    new_uvs = torch.zeros(num_faces, 2).to(device)
    face_verts = kal.ops.mesh.index_vertices_by_faces(vertices.unsqueeze(0), faces)
    face_uvs = mesh.face_uvs
    for f in range(num_faces):
        new_vertices[f] = bary[:, f] @ face_verts[:, f]
        if face_uvs is not None:
            new_uvs[f] = bary[:, f] @ face_uvs[:, f]

    # Update faces and face_uvs of the mesh
    new_vertices = torch.cat([vertices, new_vertices])
    new_faces = []
    new_face_uvs = []
    new_vertex_normals = []
    for i in range(num_faces):
        old_face = faces[i]
        a, b, c = old_face[0], old_face[1], old_face[2]
        d = num_vertices + i
        new_faces.append(torch.tensor([a, b, d]).to(device))
        new_faces.append(torch.tensor([a, d, c]).to(device))
        new_faces.append(torch.tensor([d, b, c]).to(device))

        if face_uvs is not None:
            old_face_uvs = face_uvs[0, i]
            # Use separate names for the uv corners so the vertex indices
            # a, b, c above are not clobbered (the normals branch needs them).
            ua, ub, uc = old_face_uvs[0], old_face_uvs[1], old_face_uvs[2]
            ud = new_uvs[i]
            new_face_uvs.append(torch.vstack([ua, ub, ud]))
            new_face_uvs.append(torch.vstack([ua, ud, uc]))
            new_face_uvs.append(torch.vstack([ud, ub, uc]))

        if mesh.face_normals is not None:
            # The new vertex lies on the face, so it inherits the face normal.
            new_vertex_normals.append(mesh.face_normals[i])
        else:
            e1 = vertices[b] - vertices[a]
            e2 = vertices[c] - vertices[a]
            norm = torch.cross(e1, e2, dim=0)
            norm /= torch.norm(norm)

            # Double check sign against existing vertex normals
            if torch.dot(norm, mesh.vertex_normals[a]) < 0:
                norm = -norm
            new_vertex_normals.append(norm)

    vertex_normals = torch.cat([mesh.vertex_normals, torch.stack(new_vertex_normals)])
    if face_uvs is not None:
        new_face_uvs = torch.vstack(new_face_uvs).unsqueeze(0).view(1, 3 * num_faces, 3, 2)
    new_faces = torch.vstack(new_faces)

    return new_vertices, new_faces, vertex_normals, new_face_uvs


def get_rgb_per_vertex(vertices, faces, face_rgbs):
    # Assign each vertex the color of a face containing it.
    num_vertex = vertices.shape[0]
    num_faces = faces.shape[0]
    vertex_color = torch.zeros(num_vertex, 3)

    for v in range(num_vertex):
        for f in range(num_faces):
            face = faces[f]
            if v in face:
                vertex_color[v] = face_rgbs[f]
    return vertex_color


def get_barycentric(p, faces):
    # faces: num_points x 3 x 3
    # p: num_points x 3
    # source: https://gamedev.stackexchange.com/questions/23743/whats-the-most-efficient-way-to-find-barycentric-coordinates
    a, b, c = faces[:, 0], faces[:, 1], faces[:, 2]
    v0, v1, v2 = b - a, c - a, p - a
    d00 = torch.sum(v0 * v0, dim=1)
    d01 = torch.sum(v0 * v1, dim=1)
    d11 = torch.sum(v1 * v1, dim=1)
    d20 = torch.sum(v2 * v0, dim=1)
    d21 = torch.sum(v2 * v1, dim=1)
    denom = d00 * d11 - d01 * d01
    v = (d11 * d20 - d01 * d21) / denom
    w = (d00 * d21 - d01 * d20) / denom
    u = 1 - (w + v)
    return torch.vstack([u, v, w]).T


def get_uv_assignment(num_faces):
    # Pack one right triangle per face into an M x M grid of 2x2 uv cells,
    # then normalize the coordinates to [-1, 1].
    M = int(np.ceil(np.sqrt(num_faces)))
    uv_map = torch.zeros(1, num_faces, 3, 2).to(device)
    px, py = 0, 0
    count = 0
    for i in range(M):
        px = 0
        for j in range(M):
            uv_map[:, count] = torch.tensor([[px, py], [px + 1, py], [px + 1, py + 1]])
            px += 2
            count += 1
            if count >= num_faces:
                hw = torch.max(uv_map.view(-1, 2), dim=0)[0]
                uv_map = (uv_map - hw / 2.0) / (hw / 2)
                return uv_map
        py += 2
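
# Sanity-check sketch (illustrative, not part of the original code): the
# barycentric coordinates of a point inside a triangle are non-negative,
# sum to one, and reconstruct the point when combined with the corners.
def _example_barycentric_roundtrip():
    tri = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # [1, 3, 3]
    p = torch.tensor([[0.25, 0.25, 0.0]])  # [1, 3]
    bary = get_barycentric(p, tri)  # [1, 3] -> [0.5, 0.25, 0.25]
    assert torch.allclose(bary @ tri[0], p, atol=1e-6)
    return bary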

def get_texture_visual(res, nt, mesh):
    # Render the neural texture nt (xyz -> rgb) into a res x res uv image.
    faces_vt = kal.ops.mesh.index_vertices_by_faces(mesh.vertices.unsqueeze(0), mesh.faces).squeeze(0)

    # So as not to include the endpoint, generate res + 1 points per axis and
    # take the first res (otherwise uv == 1.0 maps to the out-of-range pixel res)
    uv = torch.cartesian_prod(torch.linspace(-1, 1, res + 1)[:-1], torch.linspace(-1, 1, res + 1)[:-1]).to(device)

    image = torch.zeros(res, res, 3).to(device)
    image = image.permute(2, 0, 1)

    num_faces = mesh.faces.shape[0]
    uv_map = get_uv_assignment(num_faces).squeeze(0)

    for face in range(num_faces):
        bary = get_barycentric(uv, uv_map[face].repeat(len(uv), 1, 1))

        # Keep only the uv samples that fall inside this face's triangle
        maskA = torch.logical_and(bary[:, 0] >= 0.0, bary[:, 0] <= 1.0)
        maskB = torch.logical_and(bary[:, 1] >= 0.0, bary[:, 1] <= 1.0)
        maskC = torch.logical_and(bary[:, 2] >= 0.0, bary[:, 2] <= 1.0)
        mask = torch.logical_and(maskA, maskB)
        mask = torch.logical_and(maskC, mask)

        inside_triangle = bary[mask]
        inside_triangle_uv = inside_triangle @ uv_map[face]
        inside_triangle_xyz = inside_triangle @ faces_vt[face]
        inside_triangle_rgb = nt(inside_triangle_xyz)

        pixels = (inside_triangle_uv + 1.0) / 2.0
        pixels = pixels * res
        pixels = torch.floor(pixels).type(torch.int64)
        image[:, pixels[:, 0], pixels[:, 1]] = inside_triangle_rgb.T

    return image


# Get rotation matrix about a vector through the origin (Rodrigues' formula)
def getRotMat(axis, theta):
    """
    axis: np.array, normalized vector
    theta: radians
    """
    axis = axis / np.linalg.norm(axis)
    cprod = np.array([[0, -axis[2], axis[1]],
                      [axis[2], 0, -axis[0]],
                      [-axis[1], axis[0], 0]])
    rot = math.cos(theta) * np.identity(3) + math.sin(theta) * cprod + \
          (1 - math.cos(theta)) * np.outer(axis, axis)
    return rot


# Map vertices and a subset of faces to 0-indexed vertices, keeping only referenced vertices
def trimMesh(vertices, faces):
    unique_v = np.sort(np.unique(faces.flatten()))
    v_val = np.arange(len(unique_v))
    v_map = dict(zip(unique_v, v_val))
    new_faces = np.array([v_map[i] for i in faces.flatten()]).reshape(faces.shape[0], faces.shape[1])
    new_v = vertices[unique_v]
    return new_v, new_faces


# ================== VISUALIZATION =======================

# Back out camera parameters from a GL view transform matrix
def extract_from_gl_viewmat(gl_mat):
    gl_mat = gl_mat.reshape(4, 4)
    s = gl_mat[0, :3]
    u = gl_mat[1, :3]
    f = -1 * gl_mat[2, :3]
    coord = gl_mat[:3, 3]  # first 3 entries of the last column
    camera_location = np.array([-s, -u, f]).T @ coord
    target = camera_location + f * 10  # any scale
    return camera_location, target
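
# Quick-check sketch (illustrative): Rodrigues' formula above should rotate
# the x-axis onto the y-axis for a 90-degree turn about z.
def _example_rotmat():
    rot = getRotMat(np.array([0.0, 0.0, 1.0]), np.pi / 2)
    assert np.allclose(rot @ np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0]), atol=1e-8)
    return rot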

def psScreenshot(vertices, faces, axis, angles, save_path, name="mesh", frame_folder="frames", scalars=None,
                 colors=None, defined_on="faces", highlight_faces=None, highlight_color=[1, 0, 0],
                 highlight_radius=None, cmap=None, sminmax=None, cpos=None, clook=None, save_video=False,
                 save_base=False, ground_plane="tile_reflection", debug=False, edge_color=[0, 0, 0],
                 edge_width=1, material=None):
    import polyscope as ps

    ps.init()
    # Set camera to look at same fixed position in centroid of original mesh
    # center = np.mean(vertices, axis=0)
    # pos = center + np.array([0, 0, 3])
    # ps.look_at(pos, center)
    ps.set_ground_plane_mode(ground_plane)

    frame_path = f"{save_path}/{frame_folder}"
    Path(frame_path).mkdir(parents=True, exist_ok=True)  # create before any screenshot is written

    if save_base:
        ps_mesh = ps.register_surface_mesh("mesh", vertices, faces, enabled=True, edge_color=edge_color,
                                           edge_width=edge_width, material=material)
        ps.screenshot(f"{frame_path}/{name}.png")
        ps.remove_all_structures()

    # Convert 2D to 3D by appending a zero z-coordinate
    if vertices.shape[1] == 2:
        vertices = np.concatenate((vertices, np.zeros((len(vertices), 1))), axis=1)

    for i in range(len(angles)):
        rot = getRotMat(axis, angles[i])
        rot_verts = np.transpose(rot @ np.transpose(vertices))

        ps_mesh = ps.register_surface_mesh("mesh", rot_verts, faces, enabled=True, edge_color=edge_color,
                                           edge_width=edge_width, material=material)
        if scalars is not None:
            ps_mesh.add_scalar_quantity("scalar", scalars, defined_on=defined_on, cmap=cmap, enabled=True,
                                        vminmax=sminmax)
        if colors is not None:
            ps_mesh.add_color_quantity("color", colors, defined_on=defined_on, enabled=True)
        if highlight_faces is not None:
            # Create curve network to highlight faces
            curve_v, new_f = trimMesh(rot_verts, faces[highlight_faces, :])
            curve_edges = []
            for face in new_f:
                curve_edges.extend([[face[0], face[1]], [face[1], face[2]], [face[2], face[0]]])
            curve_edges = np.array(curve_edges)
            ps_curve = ps.register_curve_network("curve", curve_v, curve_edges, color=highlight_color,
                                                 radius=highlight_radius)

        if cpos is None or clook is None:
            ps.reset_camera_to_home_view()
        else:
            ps.look_at(cpos, clook)

        if debug:
            ps.show()
        ps.screenshot(f"{frame_path}/{name}_{i}.png")
        ps.remove_all_structures()

    if save_video:
        import glob
        from PIL import Image

        fp_in = f"{frame_path}/{name}_*.png"
        fp_out = f"{save_path}/{name}.gif"
        img, *imgs = [Image.open(f) for f in sorted(glob.glob(fp_in))]
        img.save(fp=fp_out, format='GIF', append_images=imgs, save_all=True, duration=200, loop=0)


# ================== POSITIONAL ENCODERS =============================
class FourierFeatureTransform(torch.nn.Module):
    """
    An implementation of Gaussian Fourier feature mapping.

    "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains":
    https://arxiv.org/abs/2006.10739
    https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html

    Given an input of size [batches, num_input_channels], returns a tensor of size
    [batches, num_input_channels + 2 * mapping_size].
    """

    def __init__(self, num_input_channels, mapping_size=256, scale=10, exclude=0):
        super().__init__()

        self._num_input_channels = num_input_channels
        self._mapping_size = mapping_size
        self.exclude = exclude
        B = torch.randn((num_input_channels, mapping_size)) * scale
        B_sort = sorted(B, key=lambda x: torch.norm(x, p=2))
        self._B = torch.stack(B_sort)  # sorted by frequency norm, for SAPE

    def forward(self, x):
        batches, channels = x.shape
        assert channels == self._num_input_channels, \
            "Expected input to have {} channels (got {} channels)".format(self._num_input_channels, channels)

        res = x @ self._B.to(x.device)
        res = 2 * np.pi * res
        return torch.cat([x, torch.sin(res), torch.cos(res)], dim=1)


def poisson_mesh_reconstruction(points, normals=None):
    # points/normals: [N, 3] np.ndarray
    import open3d as o3d

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)

    # outlier removal
    pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)

    # normals
    if normals is None:
        pcd.estimate_normals()
    else:
        pcd.normals = o3d.utility.Vector3dVector(normals[ind])

    # visualize
    o3d.visualization.draw_geometries([pcd], point_show_normal=False)

    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=9)

    # vertices_to_remove = densities < np.quantile(densities, 0.1)
    # mesh.remove_vertices_by_mask(vertices_to_remove)

    # visualize
    o3d.visualization.draw_geometries([mesh])

    vertices = np.asarray(mesh.vertices)
    triangles = np.asarray(mesh.triangles)

    print(f'[INFO] poisson mesh reconstruction: {points.shape} --> V: {vertices.shape} / F: {triangles.shape}')

    return vertices, triangles
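
# Usage sketch for FourierFeatureTransform (illustrative): lift [N, 3] point
# coordinates to [N, 3 + 2 * mapping_size] Fourier features before feeding an MLP.
def _example_fourier_features():
    encoder = FourierFeatureTransform(num_input_channels=3, mapping_size=128, scale=10)
    points = torch.rand(1024, 3)
    return encoder(points)  # [1024, 3 + 2 * 128] = [1024, 259]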

def decimate_mesh(verts, faces, target, backend='pymeshlab', remesh=False, optimalplacement=True):
    # optimalplacement: default is True, but for flat meshes it must be False to prevent spike artifacts.
    _ori_vert_shape = verts.shape
    _ori_face_shape = faces.shape

    if backend == 'pyfqmr':
        import pyfqmr
        solver = pyfqmr.Simplify()
        solver.setMesh(verts, faces)
        solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)
        verts, faces, normals = solver.getMesh()
    else:
        m = pml.Mesh(verts, faces)
        ms = pml.MeshSet()
        ms.add_mesh(m, 'mesh')  # will copy!

        # filters
        # ms.meshing_decimation_clustering(threshold=pml.Percentage(1))
        ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int(target), optimalplacement=optimalplacement)

        if remesh:
            # ms.apply_coord_taubin_smoothing()
            ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1))

        # extract mesh
        m = ms.current_mesh()
        verts = m.vertex_matrix()
        faces = m.face_matrix()

    print(f'[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')

    return verts, faces


def clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, remesh=False):
    # verts: [N, 3]
    # faces: [M, 3]
    _ori_vert_shape = verts.shape
    _ori_face_shape = faces.shape

    m = pml.Mesh(verts, faces)
    ms = pml.MeshSet()
    ms.add_mesh(m, 'mesh')  # will copy!

    # filters
    ms.meshing_remove_unreferenced_vertices()  # verts not referenced by any faces

    if v_pct > 0:
        ms.meshing_merge_close_vertices(threshold=pml.Percentage(v_pct))  # v_pct percent of bounding box diagonal

    ms.meshing_remove_duplicate_faces()  # faces defined by the same verts
    ms.meshing_remove_null_faces()  # faces with area == 0

    if min_d > 0:
        ms.meshing_remove_connected_component_by_diameter(mincomponentdiag=pml.Percentage(min_d))

    if min_f > 0:
        ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)

    if repair:
        # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)
        ms.meshing_repair_non_manifold_edges(method=0)
        ms.meshing_repair_non_manifold_vertices(vertdispratio=0)

    if remesh:
        # ms.apply_coord_taubin_smoothing()
        ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1))

    # extract mesh
    m = ms.current_mesh()
    verts = m.vertex_matrix()
    faces = m.face_matrix()

    print(f'[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')

    return verts, faces


def laplace_regularizer_const(v_pos, t_pos_idx):
    # Uniform Laplacian smoothing term: accumulate, for each vertex, the offsets
    # to its neighbors over all incident triangles, average them, and penalize
    # the mean squared norm of the result.
    term = torch.zeros_like(v_pos)
    norm = torch.zeros_like(v_pos[..., 0:1])

    v0 = v_pos[t_pos_idx[:, 0], :]
    v1 = v_pos[t_pos_idx[:, 1], :]
    v2 = v_pos[t_pos_idx[:, 2], :]

    term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1, 3), (v1 - v0) + (v2 - v0))
    term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1, 3), (v0 - v1) + (v2 - v1))
    term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1, 3), (v0 - v2) + (v1 - v2))

    two = torch.ones_like(v0) * 2.0
    norm.scatter_add_(0, t_pos_idx[:, 0:1], two)
    norm.scatter_add_(0, t_pos_idx[:, 1:2], two)
    norm.scatter_add_(0, t_pos_idx[:, 2:3], two)

    term = term / torch.clamp(norm, min=1.0)

    return torch.mean(term ** 2)
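
# Pipeline sketch (illustrative; the target face count is an arbitrary example
# value): clean a raw mesh, then decimate it toward a target face count.
def _example_cleanup_pipeline(verts, faces):
    verts, faces = clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True)
    verts, faces = decimate_mesh(verts, faces, target=50000, backend='pymeshlab')
    return verts, faces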