kleinhe
init
c3d0293
import torch
def linear_blend_skinning(points, weight, joint_transform, return_vT=False, inverse=False):
    """Deform points with linear blend skinning (LBS).

    Each point's 4x4 transform is the weight-blended sum of the joint
    transforms; the point is then rotated and translated by that matrix.

    Args:
        points: FloatTensor [batch, N, 3], or array-like convertible to one.
        weight: FloatTensor [batch, N, K] skinning weights, or array-like.
        joint_transform: FloatTensor [batch, K, 4, 4], or array-like.
        return_vT: if True, also return the per-point blended transforms.
        inverse: if True, apply the inverse of each blended transform
            (inverse LBS) instead of the transform itself.

    Returns:
        deformed_points: FloatTensor [batch, N, 3].
        vT (only when return_vT is True): [batch, N, 4, 4] per-point
            transforms actually applied (already inverted if inverse=True).

    Raises:
        AssertionError: if weight and joint_transform batch sizes differ.
    """
    # Convert array-like inputs first so the batch check below also works
    # for lists / numpy arrays (it previously accessed .shape on raw inputs).
    if not torch.is_tensor(points):
        points = torch.as_tensor(points).float()
    if not torch.is_tensor(weight):
        weight = torch.as_tensor(weight).float()
    if not torch.is_tensor(joint_transform):
        joint_transform = torch.as_tensor(joint_transform).float()

    if weight.shape[0] != joint_transform.shape[0]:
        raise AssertionError(
            f"batch should be same, got weight {tuple(weight.shape)} "
            f"and joint_transform {tuple(joint_transform.shape)}"
        )

    batch = joint_transform.size(0)
    # Blend flattened 4x4 matrices: [B, N, K] @ [B, K, 16] -> [B, N, 16].
    vT = torch.bmm(weight, joint_transform.reshape(batch, -1, 16)).view(batch, -1, 4, 4)
    if inverse:
        # torch.inverse batches over all leading dimensions.
        vT = torch.inverse(vT)
    R, T = vT[:, :, :3, :3], vT[:, :, :3, 3]
    deformed_points = torch.matmul(R, points.unsqueeze(-1)).squeeze(-1) + T
    if return_vT:
        return deformed_points, vT
    return deformed_points
def warp_points(points, skin_weights, joint_transform, inverse=False):
    """Warp a canonical point cloud into each pose of a batch via LBS.

    No camera projection is performed; the result stays in 3D posed space.

    Args:
        points: [N, 3] tensor of canonical 3D points, shared across all poses
            (a 3-D [B, N, 3] tensor is used as-is), or array-like.
        skin_weights: [N, J] skinning weights of the points, broadcast to
            every pose in the batch.
        joint_transform: [B, J, 4, 4] joint transforms for a batch of poses.
        inverse: if True, apply inverse LBS (forwarded to
            linear_blend_skinning).

    Returns:
        posed_points: [B, N, 3] warped points in posed space.
    """
    if not torch.is_tensor(points):
        points = torch.as_tensor(points).float()
    if not torch.is_tensor(joint_transform):
        joint_transform = torch.as_tensor(joint_transform).float()
    if not torch.is_tensor(skin_weights):
        skin_weights = torch.as_tensor(skin_weights).float()
    batch = joint_transform.shape[0]
    if points.dim() == 2:
        # Share the canonical points across every pose without copying.
        points = points.expand(batch, -1, -1)
    # Only the warped points are needed, so blended transforms are not requested.
    return linear_blend_skinning(
        points,
        skin_weights.expand(batch, -1, -1),
        joint_transform,
        inverse=inverse,
    )