import torch


def linear_blend_skinning(points, weight, joint_transform, return_vT=False, inverse=False):
    """Deform points with linear blend skinning (LBS).

    Each point's 4x4 transform is the weight-blended sum of the per-joint
    transforms. The point is then multiplied by that transform's upper-left
    3x3 block and offset by its last column (first three rows).

    Args:
        points: FloatTensor [batch, N, 3]. Non-tensor inputs (lists, arrays)
            are converted to float32 tensors.
        weight: FloatTensor [batch, N, K] skinning weights.
        joint_transform: FloatTensor [batch, K, 4, 4] per-joint transforms.
        return_vT: if True, also return the blended per-vertex transforms.
        inverse: if True, apply the inverse of each blended transform
            (inverse LBS).

    Returns:
        points_deformed: Tensor [batch, N, 3], in ``joint_transform``'s dtype.
        vT: Tensor [batch, N, 4, 4], returned only when ``return_vT`` is True.

    Raises:
        AssertionError: if ``weight`` and ``joint_transform`` have different
            batch sizes.
    """
    # Convert before validating so the batch check also works for lists/arrays.
    if not torch.is_tensor(points):
        points = torch.as_tensor(points).float()
    if not torch.is_tensor(weight):
        weight = torch.as_tensor(weight).float()
    if not torch.is_tensor(joint_transform):
        joint_transform = torch.as_tensor(joint_transform).float()

    if weight.shape[0] != joint_transform.shape[0]:
        raise AssertionError('batch should be same,', weight.shape, joint_transform.shape)

    # torch.bmm / torch.matmul require matching dtypes; follow joint_transform's.
    weight = weight.to(joint_transform.dtype)

    batch = joint_transform.size(0)
    # Blend the flattened 4x4 joint matrices: [B, N, K] @ [B, K, 16] -> [B, N, 16].
    vT = torch.bmm(weight, joint_transform.contiguous().view(batch, -1, 16)).view(batch, -1, 4, 4)
    if inverse:
        vT = torch.inverse(vT.view(-1, 4, 4)).view(batch, -1, 4, 4)

    R, T = vT[:, :, :3, :3], vT[:, :, :3, 3]
    deformed_points = torch.matmul(R, points.to(R.dtype).unsqueeze(-1)).squeeze(-1) + T
    if return_vT:
        return deformed_points, vT
    return deformed_points


def warp_points(points, skin_weights, joint_transform, inverse=False):
    """Warp a point cloud into every pose of a batch via linear blend skinning.

    Args:
        points: [N, 3] points, broadcast across the batch. An already-batched
            [B, N, 3] tensor is passed through unchanged.
        skin_weights: [N, J] skinning weights of the points (broadcast across
            the batch with ``expand``).
        joint_transform: [B, J, 4, 4] joint transforms for a batch of poses.
        inverse: if True, apply inverse LBS instead of forward LBS.

    Returns:
        posed_points: [B, N, 3] warped points.
    """
    if not torch.is_tensor(points):
        points = torch.as_tensor(points).float()
    if not torch.is_tensor(joint_transform):
        joint_transform = torch.as_tensor(joint_transform).float()
    if not torch.is_tensor(skin_weights):
        skin_weights = torch.as_tensor(skin_weights).float()

    batch = joint_transform.shape[0]
    if points.dim() == 2:
        points = points.expand(batch, -1, -1)

    # Only the warped points are needed; skip the per-vertex transforms.
    return linear_blend_skinning(points, skin_weights.expand(batch, -1, -1),
                                 joint_transform, inverse=inverse)