# Spaces: Running on Zero
import math
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import torch_scatter
from flash_attn.modules.mha import MHA
from torch import nn, FloatTensor, LongTensor, Tensor
from torch.nn.functional import pad
from transformers import AutoModelForCausalLM, AutoConfig

from .spec import ModelSpec, ModelInput
from .parse_encoder import MAP_MESH_ENCODER, get_mesh_encoder
from ..data.utils import linear_blend_skinning
class FrequencyPositionalEmbedding(nn.Module):
    """Sin/cosine positional embedding applied to the last dimension of a tensor.

    Given an input tensor `x` of shape [..., input_dim], every scalar
    `x[..., i]` is multiplied by each frequency `f_0, ..., f_{N-1}`
    (N = num_freqs) and the result is the concatenation, along the last
    dimension, of:
        [
            x,                       # only present if include_input is True
            sin(f_0 * x[..., 0]), ..., sin(f_{N-1} * x[..., 0]),
            sin(f_0 * x[..., 1]), ..., sin(f_{N-1} * x[..., input_dim - 1]),
            cos(...)                 # same ordering as the sine part
        ]

    Frequencies:
        * logspace=True:  f_j = 2^j for j = 0, ..., num_freqs - 1;
        * logspace=False: num_freqs values linearly spaced in
          [1.0, 2^(num_freqs - 1)];
        * if include_pi is True, every frequency is additionally multiplied
          by pi.

    If num_freqs is 0 the module is the identity, regardless of
    include_input.

    Args:
        num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): choose power-of-two (True) or linearly spaced (False)
            frequencies, default is True;
        input_dim (int): the size of the last input dimension, default is 3;
        include_input (bool): prepend the raw input to the embedding or not,
            default is True;
        include_pi (bool): scale all frequencies by pi, default is True.

    Attributes:
        frequencies (torch.Tensor): float32 buffer of shape [num_freqs],
            registered as non-persistent (it is not stored in the state dict);
        include_input (bool): copy of the constructor argument;
        num_freqs (int): copy of the constructor argument;
        out_dim (int): the embedding size. It is
            input_dim * (num_freqs * 2 + 1) if include_input is True or
            num_freqs is 0, otherwise input_dim * num_freqs * 2.
    """

    def __init__(
        self,
        num_freqs: int = 6,
        logspace: bool = True,
        input_dim: int = 3,
        include_input: bool = True,
        include_pi: bool = True,
    ) -> None:
        """Build the frequency table and pre-compute the output width."""
        super().__init__()
        if logspace:
            frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
        else:
            frequencies = torch.linspace(
                1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
            )
        if include_pi:
            frequencies *= torch.pi
        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs
        self.out_dim = self._get_dims(input_dim)

    def _get_dims(self, input_dim):
        """Return the size of the last output dimension for a given input size."""
        # With zero frequencies forward() returns the input unchanged, so the
        # input always counts towards the output width in that case.
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)
        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape
                [..., dim * (num_freqs * 2 + temp)], where temp is 1 if
                include_input is True (or num_freqs is 0) and 0 otherwise.
        """
        if self.num_freqs <= 0:
            return x

        # [..., dim, 1] * [num_freqs] -> [..., dim, num_freqs], then flatten
        # the two trailing axes so each input channel's frequencies are adjacent.
        embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device)).view(
            *x.shape[:-1], -1
        )
        if self.include_input:
            return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
        return torch.cat((embed.sin(), embed.cos()), dim=-1)
class ResidualCrossAttn(nn.Module):
    """Cross-attention block with residual connections and post-layer-norm.

    The query attends to a separate key/value sequence through flash-attn's
    `MHA` (constructed with `cross_attn=True`); the attention output is added
    to the query and normalised, then a 4x-wide GELU feed-forward network is
    applied with a second residual connection and normalisation.

    Args:
        feat_dim (int): feature width of the query, key/value and output;
            must be divisible by num_heads.
        num_heads (int): number of attention heads.
    """

    def __init__(self, feat_dim: int, num_heads: int):
        super().__init__()
        assert feat_dim % num_heads == 0, "feat_dim must be divisible by num_heads"
        hidden_dim = feat_dim * 4
        self.norm1 = nn.LayerNorm(feat_dim)
        self.norm2 = nn.LayerNorm(feat_dim)
        # flash-attn MHA is used here in place of nn.MultiheadAttention.
        self.attention = MHA(embed_dim=feat_dim, num_heads=num_heads, cross_attn=True)
        self.ffn = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, feat_dim),
        )

    def forward(self, q, kv):
        """Attend `q` to `kv` and return a tensor shaped like `q`.

        Args:
            q: query features (presumably [B, Lq, feat_dim] - confirm with caller).
            kv: key/value features passed to the attention module as `x_kv`.

        Returns:
            norm2(h + ffn(h)) where h = norm1(q + attention(q, x_kv=kv)).
        """
        attended = self.norm1(q + self.attention(q, x_kv=kv))
        return self.norm2(attended + self.ffn(attended))
class BoneEncoder(nn.Module):
    """Embed bone coordinates and refine them with cross-attention.

    Each bone coordinate is shifted by a per-sample minimum coordinate,
    passed through a frequency positional embedding and a three-layer
    Linear/LayerNorm/GELU MLP, then refined by `num_attn`
    `ResidualCrossAttn` layers that attend to the bone embeddings
    concatenated with the global latents.

    Args:
        feat_bone_dim (int): size of the last dimension of `base_bone`.
        feat_dim (int): output feature width (also the attention width).
        embed_dim (int): width of the first MLP layer; the second is 4x wider.
        num_heads (int): attention heads per cross-attention layer.
        num_attn (int): number of cross-attention layers.
    """

    def __init__(
        self,
        feat_bone_dim: int,
        feat_dim: int,
        embed_dim: int,
        num_heads: int,
        num_attn: int,
    ):
        super().__init__()
        self.feat_bone_dim = feat_bone_dim
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        self.num_attn = num_attn

        self.position_embed = FrequencyPositionalEmbedding(input_dim=self.feat_bone_dim)
        hidden_dim = embed_dim * 4
        self.bone_encoder = nn.Sequential(
            self.position_embed,
            nn.Linear(self.position_embed.out_dim, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, feat_dim),
            nn.LayerNorm(feat_dim),
            nn.GELU(),
        )
        self.attn = nn.ModuleList(
            [ResidualCrossAttn(feat_dim, self.num_heads) for _ in range(self.num_attn)]
        )

    def forward(
        self,
        base_bone: FloatTensor,
        num_bones: LongTensor,
        parents: LongTensor,
        min_coord: FloatTensor,
        global_latents: FloatTensor,
    ):
        """Encode bones into per-bone features.

        Args:
            base_bone: bone coordinates, (B, J, C).
            num_bones: accepted for interface compatibility; not read here.
            parents: accepted for interface compatibility; not read here.
            min_coord: per-sample offset subtracted from every bone, indexed
                as `min_coord[:, None, :]` (so presumably (B, C)).
            global_latents: extra tokens concatenated to the bone embeddings
                along dim 1 to form the attention context.

        Returns:
            Bone features of shape (B, J, -1) after all attention layers.
        """
        batch_size, joint_count = base_bone.shape[0], base_bone.shape[1]
        shifted = base_bone - min_coord[:, None, :]
        flat = shifted.reshape(-1, base_bone.shape[-1])
        x = self.bone_encoder(flat).reshape(batch_size, joint_count, -1)

        # The key/value context is assembled once from the initial bone
        # embedding; it is not rebuilt as x is updated by the layers below.
        latents = torch.cat([x, global_latents], dim=1)
        for layer in self.attn:
            x = layer(x, latents)
        return x
class SkinweightPred(nn.Module):
    """MLP head that maps a feature vector to a single scalar.

    The network is four Linear -> LayerNorm -> GELU stages (the first maps
    in_dim to mlp_dim, the rest keep mlp_dim) followed by a final Linear that
    projects to one output value.

    Args:
        in_dim: size of the last input dimension.
        mlp_dim: hidden width of every stage.
    """

    def __init__(self, in_dim, mlp_dim):
        super().__init__()
        widths = [in_dim] + [mlp_dim] * 4
        layers = []
        # Layers are created in forward order so module indices inside
        # `net` (0..12) and parameter initialisation order stay the same.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.LayerNorm(fan_out))
            layers.append(nn.GELU())
        layers.append(nn.Linear(mlp_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the scalar prediction for `x` of shape [..., in_dim] as [..., 1]."""
        return self.net(x)
class UniRigSkin(ModelSpec):
    """Per-vertex skinning-weight predictor.

    The model combines three sources of information:
      * per-vertex features from a point-cloud mesh encoder (`mesh_encoder`),
      * global shape latents from a second encoder (`global_encoder`),
      * per-bone features from `BoneEncoder`,
    and scores every (vertex, bone) pair with multi-head dot-product
    attention mixed with an embedded `voxel_skin` prior. A small MLP
    (`SkinweightPred`) turns each pair's features into a logit and a softmax
    over the valid bones of each sample gives the skin weights.
    """

    def process_fn(self, batch: List[ModelInput]) -> List[Dict]:
        """Pad each sample's `voxel_skin` to a common bone count.

        Args:
            batch: samples exposing `asset.J`, `vertices` and
                `asset.sampled_vertex_groups['voxel_skin']` (an (N, J) array).

        Returns:
            One dict per sample with
              * 'voxel_skin': the array zero-padded on the bone axis to the
                largest `asset.J` in the batch;
              * 'offset': running total of `vertices.shape[0]` up to and
                including this sample.
        """
        max_bones = max((b.asset.J for b in batch), default=0)
        res = []
        current_offset = 0
        for b in batch:
            current_offset += b.vertices.shape[0]
            # (N, J) -> (N, max_bones)
            voxel_skin = b.asset.sampled_vertex_groups['voxel_skin']
            voxel_skin = np.pad(
                voxel_skin,
                ((0, 0), (0, max_bones - b.asset.J)),
                'constant',
                constant_values=0.0,
            )
            res.append({
                'voxel_skin': voxel_skin,
                'offset': current_offset,
            })
        return res

    def __init__(self, mesh_encoder, global_encoder, **kwargs):
        """Build encoders, projections and attention stacks.

        Args:
            mesh_encoder: keyword arguments for `get_mesh_encoder`; must build
                a `MAP_MESH_ENCODER.ptv3obj` and contain 'enc_channels'.
            global_encoder: keyword arguments for `get_mesh_encoder`; must
                build a `MAP_MESH_ENCODER.michelangelo_encoder`.
            **kwargs: required keys are 'num_train_vertex', 'feat_dim',
                'num_heads', 'grid_size', 'mlp_dim', 'num_bone_attn',
                'num_mesh_bone_attn' and 'bone_embed_dim'; optional key
                'voxel_mask' (default 2) is the exponent applied to the voxel
                mask at inference.

        Raises:
            KeyError: if a required keyword is missing.
            NotImplementedError: if either encoder is of an unsupported type.
        """
        super().__init__()
        self.num_train_vertex = kwargs['num_train_vertex']
        self.feat_dim = kwargs['feat_dim']
        self.num_heads = kwargs['num_heads']
        self.grid_size = kwargs['grid_size']
        self.mlp_dim = kwargs['mlp_dim']
        self.num_bone_attn = kwargs['num_bone_attn']
        self.num_mesh_bone_attn = kwargs['num_mesh_bone_attn']
        self.bone_embed_dim = kwargs['bone_embed_dim']
        self.voxel_mask = kwargs.get('voxel_mask', 2)

        self.mesh_encoder = get_mesh_encoder(**mesh_encoder)
        self.global_encoder = get_mesh_encoder(**global_encoder)

        if not isinstance(self.mesh_encoder, MAP_MESH_ENCODER.ptv3obj):
            raise NotImplementedError()
        # Project the mesh encoder's last channel width to feat_dim.
        self.feat_map = nn.Sequential(
            nn.Linear(mesh_encoder['enc_channels'][-1], self.feat_dim),
            nn.LayerNorm(self.feat_dim),
            nn.GELU(),
        )

        if not isinstance(self.global_encoder, MAP_MESH_ENCODER.michelangelo_encoder):
            raise NotImplementedError()
        # Project the global encoder's latents to feat_dim.
        self.out_proj = nn.Sequential(
            nn.Linear(self.global_encoder.width, self.feat_dim),
            nn.LayerNorm(self.feat_dim),
            nn.GELU(),
        )

        self.bone_encoder = BoneEncoder(
            feat_bone_dim=3,
            feat_dim=self.feat_dim,
            embed_dim=self.bone_embed_dim,
            num_heads=self.num_heads,
            num_attn=self.num_bone_attn,
        )

        # Fuses [attention scores | embedded voxel skin] (2 * num_heads) back
        # to num_heads channels per (vertex, bone) pair.
        self.downscale = nn.Sequential(
            nn.Linear(2 * self.num_heads, self.num_heads),
            nn.LayerNorm(self.num_heads),
            nn.GELU(),
        )
        self.skinweight_pred = SkinweightPred(
            self.num_heads,
            self.mlp_dim,
        )

        self.mesh_bone_attn = nn.ModuleList()
        self.mesh_bone_attn.extend([
            ResidualCrossAttn(self.feat_dim, self.num_heads)
            for _ in range(self.num_mesh_bone_attn)
        ])

        # Per-head query/key projections: every head keeps the full feat_dim.
        self.qmesh = nn.Linear(self.feat_dim, self.feat_dim * self.num_heads)
        self.kmesh = nn.Linear(self.feat_dim, self.feat_dim * self.num_heads)

        self.voxel_skin_embed = nn.Linear(1, self.num_heads)
        self.voxel_skin_norm = nn.LayerNorm(self.num_heads)
        self.attn_skin_norm = nn.LayerNorm(self.num_heads)

    def encode_mesh_cond(self, vertices: FloatTensor, normals: FloatTensor) -> FloatTensor:
        """Encode the point cloud into global latents projected to feat_dim.

        Args:
            vertices: point coordinates, batched (3-D) or a single cloud
                (a batch axis is then added).
            normals: per-point features passed to the encoder as `feats`.

        Returns:
            `out_proj(latents)` where `latents` comes from
            `global_encoder.encode_latents`.

        Raises:
            AssertionError: if either input contains NaN.
            NotImplementedError: for unsupported global encoders.
        """
        assert not torch.isnan(vertices).any()
        assert not torch.isnan(normals).any()
        if not isinstance(self.global_encoder, MAP_MESH_ENCODER.michelangelo_encoder):
            raise NotImplementedError()

        if vertices.dim() != 3:
            # Single cloud: add the batch axis expected by the encoder.
            vertices = vertices.unsqueeze(0)
            normals = normals.unsqueeze(0)
        shape_embed, latents, token_num, pre_pc = self.global_encoder.encode_latents(
            pc=vertices, feats=normals
        )
        return self.out_proj(latents)

    def _get_predict(self, batch: Dict) -> Tuple[FloatTensor, Tensor]:
        """Predict skin weights for the vertices selected in this pass.

        In training mode a random subset of at most `num_train_vertex`
        vertices is scored. In eval mode all vertices are scored in chunks of
        `num_train_vertex`, and the prediction is additionally multiplied by
        the parent-accumulated voxel mask raised to `voxel_mask` and
        re-normalised over bones.

        Args:
            batch: dict with keys 'num_bones', 'vertices' (B, N, 3),
                'normals', 'joints', 'tails', 'voxel_skin', 'parents' and
                'offset'.

        Returns:
            A tuple `(skin_pred, indices)`: `skin_pred` has shape
            (B, len(indices), J) and is zero for bones at or beyond
            `num_bones[b]`; `indices` are the vertex indices that were scored,
            in the order of `skin_pred`'s second axis.
        """
        num_bones: Tensor = batch['num_bones']
        vertices: FloatTensor = batch['vertices']  # (B, N, 3)
        normals: FloatTensor = batch['normals']
        joints: FloatTensor = batch['joints']
        # 'tails' is not used below; it is still read so the set of required
        # batch keys stays unchanged.
        tails: FloatTensor = batch['tails']
        voxel_skin: FloatTensor = batch['voxel_skin']
        parents: LongTensor = batch['parents']

        # turn inputs' dtype into model's dtype
        dtype = next(self.parameters()).dtype
        vertices = vertices.type(dtype)
        normals = normals.type(dtype)
        joints = joints.type(dtype)
        tails = tails.type(dtype)
        voxel_skin = voxel_skin.type(dtype)

        B = vertices.shape[0]
        N = vertices.shape[1]
        J = joints.shape[1]
        H = self.num_heads
        chunk = self.num_train_vertex

        assert vertices.dim() == 3
        assert normals.dim() == 3

        # Per-sample minimum coordinate: the vertices are flattened to
        # (B * N, 3) and reduced over consecutive segments of N rows.
        idx_ptr = torch.arange(B + 1, dtype=torch.int64, device=vertices.device) * N
        min_coord = torch_scatter.segment_csr(vertices.reshape(-1, 3), idx_ptr, reduce="min")

        # Vertex index groups to score in this pass.
        if self.training:
            pack = [torch.randperm(N)[:chunk]]
        else:
            pack = [torch.arange(start, min(start + chunk, N)) for start in range(0, N, chunk)]

        # (B, seq_len, feat_dim)
        global_latents = self.encode_mesh_cond(vertices, normals)

        bone_feat = self.bone_encoder(
            base_bone=joints,
            num_bones=num_bones,
            parents=parents,
            min_coord=min_coord,
            global_latents=global_latents,
        )

        if not isinstance(self.mesh_encoder, MAP_MESH_ENCODER.ptv3obj):
            raise NotImplementedError()

        # 9 input channels per point: position, normal and a zero block.
        feat = torch.cat([vertices, normals, torch.zeros_like(vertices)], dim=-1)
        ptv3_input = {
            'coord': vertices.reshape(-1, 3),
            'feat': feat.reshape(-1, 9),
            'offset': torch.tensor(batch['offset']),
            'grid_size': self.grid_size,
        }
        if not self.training:
            # must cast to float32 to avoid sparse-conv precision bugs
            with torch.autocast(device_type='cuda', dtype=torch.float32):
                mesh_feat = self.mesh_encoder(ptv3_input).feat
                mesh_feat = self.feat_map(mesh_feat).view(B, N, self.feat_dim)
        else:
            mesh_feat = self.mesh_encoder(ptv3_input).feat
            mesh_feat = self.feat_map(mesh_feat).view(B, N, self.feat_dim)
        mesh_feat = mesh_feat.type(dtype)

        # (B, J + seq_len, feat_dim)
        latents = torch.cat([bone_feat, global_latents], dim=1)
        # (B, N, feat_dim)
        for block in self.mesh_bone_attn:
            mesh_feat = block(
                q=mesh_feat,
                kv=latents,
            )

        # trans to (B, num_heads, J, feat_dim)
        bone_feat = self.kmesh(bone_feat).view(B, J, H, self.feat_dim).transpose(1, 2)
        bone_keys = bone_feat.transpose(-2, -1).reshape(B * H, -1, J)

        skin_mask = None
        if not self.training:
            # Voxel mask: each bone's column is added onto its parent's
            # column, visiting bones in increasing index order.
            skin_mask = voxel_skin.clone()
            for b in range(B):
                for i in range(int(num_bones[b])):
                    p = parents[b, i]
                    if p < 0:
                        continue
                    skin_mask[b, :, p] += skin_mask[b, :, i]

        skin_pred_list = []
        for indices in pack:
            cur_N = len(indices)
            # trans to (B, num_heads, N, feat_dim)
            cur_mesh_feat = (
                self.qmesh(mesh_feat[:, indices])
                .view(B, cur_N, H, self.feat_dim)
                .transpose(1, 2)
            )

            # attn_weight shape : (B * num_heads, N, J)
            scores = torch.bmm(cur_mesh_feat.reshape(B * H, cur_N, -1), bone_keys)
            attn_weight = F.softmax(scores / math.sqrt(self.feat_dim), dim=-1, dtype=dtype)
            # (B, num_heads, N, J) -> (B, N, J, num_heads)
            attn_weight = attn_weight.reshape(B, H, cur_N, J).permute(0, 2, 3, 1)
            attn_weight = self.attn_skin_norm(attn_weight)

            embed_voxel_skin = self.voxel_skin_embed(
                voxel_skin[:, indices].reshape(B, cur_N, J, 1)
            )
            embed_voxel_skin = self.voxel_skin_norm(embed_voxel_skin)

            # (B, N, J, 2 * num_heads) -> (B, N, J, num_heads)
            attn_weight = torch.cat([attn_weight, embed_voxel_skin], dim=-1)
            attn_weight = self.downscale(attn_weight)

            # Padded bone columns stay at zero.
            skin_pred = torch.zeros(B, cur_N, J, device=attn_weight.device, dtype=dtype)
            for i in range(B):
                n = int(num_bones[i])
                # (N * n, C)
                input_features = attn_weight[i, :, :n, :].reshape(-1, attn_weight.shape[-1])
                pred = self.skinweight_pred(input_features).reshape(cur_N, n)
                # Explicit dim: the implicit-dim form of softmax is deprecated.
                skin_pred[i, :, :n] = F.softmax(pred, dim=-1)
            skin_pred_list.append(skin_pred)

        skin_pred_list = torch.cat(skin_pred_list, dim=1)

        # The voxel mask only exists (and only matches the full vertex set)
        # at inference, so the re-weighting is skipped while training.
        if skin_mask is not None:
            for i in range(B):
                n = int(num_bones[i])
                masked = skin_pred_list[i, :, :n] * torch.pow(skin_mask[i, :, :n], self.voxel_mask)
                # NOTE(review): a vertex whose masked weights are all zero
                # yields 0/0 here - confirm the mask is never all-zero.
                skin_pred_list[i, :, :n] = masked / masked.sum(dim=-1, keepdim=True)

        return skin_pred_list, torch.cat(pack, dim=0)

    def predict_step(self, batch: Dict):
        """Run inference without gradients.

        Args:
            batch: same dict as accepted by `_get_predict`.

        Returns:
            A list with one tensor per sample: the predicted weights with the
            bone axis trimmed to `num_bones[i]`.
        """
        with torch.no_grad():
            num_bones: Tensor = batch['num_bones']
            skin_pred, _ = self._get_predict(batch=batch)
            return [
                skin_pred[i, :, :num_bones[i]]
                for i in range(skin_pred.shape[0])
            ]