Spaces:
Runtime error
Runtime error
import torch | |
def linear_blend_skinning(points, weight, joint_transform, return_vT=False, inverse=False):
    """
    Deform points with linear blend skinning (LBS).

    Each point's 4x4 transform is the skinning-weighted sum of the per-joint
    transforms. The point is then mapped by that matrix's upper-left 3x3 block
    plus its translation column. The matrix's bottom row is not used.

    Args:
        points: FloatTensor [batch, N, 3]; non-tensors are converted with
            ``torch.as_tensor(...).float()``
        weight: FloatTensor [batch, N, K] skinning weights
        joint_transform: FloatTensor [batch, K, 4, 4]
        return_vT: if true, also return the per-point transform matrices
        inverse: if true, apply the matrix inverse of each blended per-point
            transform instead of the transform itself
    Returns:
        points_deformed: FloatTensor [batch, N, 3]; when ``return_vT`` is
        true, the tuple ``(points_deformed, vT)`` with vT [batch, N, 4, 4].
    Raises:
        AssertionError: if ``weight`` and ``joint_transform`` have different
            batch sizes.
    """
    # Convert first so the batch check below also works for list / numpy inputs.
    if not torch.is_tensor(points):
        points = torch.as_tensor(points).float()
    if not torch.is_tensor(weight):
        weight = torch.as_tensor(weight).float()
    if not torch.is_tensor(joint_transform):
        joint_transform = torch.as_tensor(joint_transform).float()
    if weight.shape[0] != joint_transform.shape[0]:
        raise AssertionError(
            f"batch should be same, got weight {tuple(weight.shape)} "
            f"and joint_transform {tuple(joint_transform.shape)}"
        )

    batch = joint_transform.size(0)
    # Blend flattened joint matrices: [B, N, K] @ [B, K, 16] -> [B, N, 16] -> [B, N, 4, 4].
    vT = torch.bmm(weight, joint_transform.contiguous().view(batch, -1, 16)).view(batch, -1, 4, 4)
    if inverse:
        # Invert after blending: the inverse of the blended matrix, not a blend of inverses.
        vT = torch.inverse(vT.view(-1, 4, 4)).view(batch, -1, 4, 4)
    R, T = vT[:, :, :3, :3], vT[:, :, :3, 3]
    deformed_points = torch.matmul(R, points.unsqueeze(-1)).squeeze(-1) + T
    if return_vT:
        return deformed_points, vT
    return deformed_points
def warp_points(points, skin_weights, joint_transform, inverse=False):
    """
    Warp a canonical point cloud into every pose of a batch with linear blend skinning.

    No image-space projection is performed; only the 3D warp is returned.

    Args:
        points: [N, 3] tensor of 3D points (a 3-D tensor is used as given)
        skin_weights: [N, J] skinning weights corresponding to ``points``
        joint_transform: [B, J, 4, 4] joint transform matrices of a batch of poses
        inverse: if true, apply the inverse of each blended per-point transform
    Returns:
        posed_points: [B, N, 3] warped points
    """
    if not torch.is_tensor(points):
        points = torch.as_tensor(points).float()
    if not torch.is_tensor(joint_transform):
        joint_transform = torch.as_tensor(joint_transform).float()
    if not torch.is_tensor(skin_weights):
        skin_weights = torch.as_tensor(skin_weights).float()
    batch = joint_transform.shape[0]
    # Share the single canonical cloud across all poses; expand() avoids a copy.
    if points.dim() == 2:
        points = points.expand(batch, -1, -1)
    return linear_blend_skinning(
        points,
        skin_weights.expand(batch, -1, -1),
        joint_transform,
        inverse=inverse,
    )