from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
from numpy import ndarray
from typing import Dict, Union, List, final
import lightning.pytorch as pl

from ..data.asset import Asset
from ..data.augment import Augment

# Keys that `ModelSpec._process_fn` writes itself; `process_fn` must not set them.
# `faces` is kept for backward compatibility; `origin_faces` is the key that is
# actually written below.
_RESERVED_KEYS = (
    'cls', 'path', 'data_name', 'joints', 'tails', 'parents', 'num_bones',
    'vertices', 'normals', 'matrix_local', 'pose_matrix', 'num_points',
    'origin_vertices', 'origin_vertex_normals', 'origin_face_normals',
    'num_faces', 'faces', 'origin_faces',
)

# Asset attributes that are padded along the bone axis up to the batch maximum.
_BONE_FIELDS = ('joints', 'tails', 'parents', 'matrix_local', 'pose_matrix')


def _pad_bone_matrices(matrices: ndarray, num_bones: int, max_bones: int) -> ndarray:
    '''
    Pad a stack of per-bone matrices along axis 0 from `num_bones` to `max_bones`.

    The padded entries get ones on their first four diagonal positions (identity
    for 4x4 matrices) instead of all zeros, to prevent singular matrices in lbs.
    '''
    padded = np.pad(
        matrices,
        ((0, max_bones - num_bones), (0, 0), (0, 0)),
        'constant',
        constant_values=0.,
    )
    for d in range(4):
        padded[num_bones:, d, d] = 1.
    return padded


@dataclass
class ModelInput:
    # tokens for ar input
    tokens: Union[ndarray, None]=None

    # pad token
    pad: Union[int, None]=None

    # vertices (usually sampled), (N, 3)
    vertices: Union[ndarray, None]=None

    # normals (usually sampled), (N, 3)
    normals: Union[ndarray, None]=None

    # joints
    joints: Union[ndarray, None]=None

    # tails
    tails: Union[ndarray, None]=None

    # asset, kept for debug usage
    asset: Union[Asset, None]=None

    # augment used on the asset
    augments: Union[Augment, None]=None


class ModelSpec(pl.LightningModule, ABC):
    '''
    Abstract base for models: subclasses implement `process_fn`, and the final
    `_process_fn` adds the asset-derived, batch-padded fields on top of it.
    '''

    @abstractmethod
    def __init__(self):
        super().__init__()

    @final
    def _process_fn(self, batch: List[ModelInput]) -> List[Dict]:
        '''
        Call `process_fn` and add asset-derived fields to every per-sample dict.

        Arrays whose first axis is the bone, point or face count are padded
        along that axis to the maximum found in `batch`, so that all samples
        share the same shape.

        Returns one dict per sample (the dicts returned by `process_fn`,
        modified in place) with these extra keys:
            cls, path, data_name: copied from the asset
            joints: (J, 3), J==max_bones, zero padded (only if asset.joints exists)
            num_bones: the true number of bones (only if asset.joints exists)
            tails: (J, 3), zero padded (only if asset.tails exists)
            parents: (J,), entry 0 and the padding are -1, which represents
                no parent (only if asset.parents exists)
            matrix_local: (J, 4, 4), current matrix_local, identity padded
                (only if asset.matrix_local exists)
            pose_matrix: (J, 4, 4), for motion loss calculation, identity
                padded (only if asset.pose_matrix exists)
            vertices, normals: taken from the ModelInput as they are
            num_points: the true number of asset vertices
            origin_vertices, origin_vertex_normals: asset arrays zero padded
                to max_points along axis 0
            num_faces: the true number of asset faces
            origin_faces, origin_face_normals: asset arrays zero padded to
                max_faces along axis 0

        Raises:
            AssertionError: if `process_fn` already set one of the reserved keys.
        '''
        n_batch = self.process_fn(batch)

        # skin should be in vertex group
        max_bones = 0
        max_points = 0
        max_faces = 0
        for b in batch:
            asset = b.asset
            # Every bone field padded below relies on max_bones >= asset.J, so
            # count the asset whenever any of them is present (not only when
            # the ModelInput itself carries joints); otherwise np.pad would be
            # called with a negative width.
            has_bones = b.joints is not None or any(
                getattr(asset, name) is not None for name in _BONE_FIELDS
            )
            if has_bones:
                max_bones = max(max_bones, asset.J)
            max_faces = max(max_faces, asset.F)
            max_points = max(max_points, asset.N)

        # reset on every call; nothing in this method fills them
        self._augments = []
        self._assets = []

        for idx, b in enumerate(batch):
            asset = b.asset
            sample = n_batch[idx]

            for ban in _RESERVED_KEYS:
                if ban in sample:
                    # raised explicitly so the check survives `python -O`
                    raise AssertionError(f"cannot override `{ban}` in process_fn")

            sample['cls'] = asset.cls
            sample['path'] = asset.path
            sample['data_name'] = asset.data_name

            if asset.joints is not None:
                sample['joints'] = np.pad(asset.joints, ((0, max_bones-asset.J), (0, 0)), mode='constant', constant_values=0.)
                sample['num_bones'] = asset.J

            if asset.tails is not None:
                sample['tails'] = np.pad(asset.tails, ((0, max_bones-asset.J), (0, 0)), mode='constant', constant_values=0.)

            if asset.parents is not None:
                parents = asset.parents.copy()
                # cannot put None into dict, so the first entry is stored as -1
                parents[0] = -1
                parents = np.pad(parents, (0, max_bones-asset.J), 'constant', constant_values=-1)
                sample['parents'] = parents

            if asset.matrix_local is not None:
                sample['matrix_local'] = _pad_bone_matrices(asset.matrix_local, asset.J, max_bones)

            if asset.pose_matrix is not None:
                sample['pose_matrix'] = _pad_bone_matrices(asset.pose_matrix, asset.J, max_bones)

            sample['vertices'] = b.vertices
            sample['normals'] = b.normals

            sample['num_points'] = asset.N
            sample['origin_vertices'] = np.pad(asset.vertices, ((0, max_points-asset.N), (0, 0)))
            sample['origin_vertex_normals'] = np.pad(asset.vertex_normals, ((0, max_points-asset.N), (0, 0)))

            sample['num_faces'] = asset.F
            sample['origin_faces'] = np.pad(asset.faces, ((0, max_faces-asset.F), (0, 0)))
            sample['origin_face_normals'] = np.pad(asset.face_normals, ((0, max_faces-asset.F), (0, 0)))

        return n_batch

    @abstractmethod
    def process_fn(self, batch: List[ModelInput]) -> List[Dict]:
        '''
        Fetch data from dataloader and turn it into Tensor objects.

        Must return one dict per element of `batch`, in the same order, and
        must not set any key that `_process_fn` writes itself.
        '''
        pass