diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c55fb47be2a0466a397746f84456655a7794143f --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +body_models +results +weights +tada-extend \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..3a5e77e15ceaf521b112fcb095272450ded3fe60 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 xin he + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
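The attribute rules above route every large binary format (archives, checkpoints, serialized tensors) through Git LFS, so a plain git clone can leave pointer stubs in place of the real files until git lfs pull is run. A minimal helper sketch for spotting such stubs (not part of this repository; the prefix check follows the Git LFS pointer format and the size cutoff is a heuristic):

    from pathlib import Path

    LFS_PREFIX = b"version https://git-lfs.github.com/spec/v1"

    def unfetched_lfs_stubs(root="."):
        # Pointer stubs are tiny text files that begin with the LFS spec line.
        for path in Path(root).rglob("*"):
            if path.is_file() and path.stat().st_size < 200:
                with open(path, "rb") as f:
                    if f.read(len(LFS_PREFIX)) == LFS_PREFIX:
                        yield path

    if __name__ == "__main__":
        for stub in unfetched_lfs_stubs():
            print("still an LFS pointer:", stub)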
diff --git a/SMPLX/__pycache__/joints2smpl.cpython-310.pyc b/SMPLX/__pycache__/joints2smpl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e8ffd7951e7757ba6cd122a67a1b148b7f74cad Binary files /dev/null and b/SMPLX/__pycache__/joints2smpl.cpython-310.pyc differ diff --git a/SMPLX/__pycache__/joints2smpl.cpython-39.pyc b/SMPLX/__pycache__/joints2smpl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aac5e3cdcc349b88bff8503727d48dbf268ffc70 Binary files /dev/null and b/SMPLX/__pycache__/joints2smpl.cpython-39.pyc differ diff --git a/SMPLX/__pycache__/read_from_npy.cpython-310.pyc b/SMPLX/__pycache__/read_from_npy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19bad8d4c041d593fd78dafd0b5c6d13cf60a5a2 Binary files /dev/null and b/SMPLX/__pycache__/read_from_npy.cpython-310.pyc differ diff --git a/SMPLX/__pycache__/read_from_npy.cpython-311.pyc b/SMPLX/__pycache__/read_from_npy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70a3202cc160420db82442152ef0c70e34def89f Binary files /dev/null and b/SMPLX/__pycache__/read_from_npy.cpython-311.pyc differ diff --git a/SMPLX/__pycache__/read_from_npy.cpython-39.pyc b/SMPLX/__pycache__/read_from_npy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aed1567a85adc7789cd6861283f113a9ba62800e Binary files /dev/null and b/SMPLX/__pycache__/read_from_npy.cpython-39.pyc differ diff --git a/SMPLX/__pycache__/read_joints_from_pose.cpython-39.pyc b/SMPLX/__pycache__/read_joints_from_pose.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d694b1796e715889e8230083175030d2d42d78b3 Binary files /dev/null and b/SMPLX/__pycache__/read_joints_from_pose.cpython-39.pyc differ diff --git a/SMPLX/__pycache__/rotation_conversions.cpython-310.pyc b/SMPLX/__pycache__/rotation_conversions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93bc0e4243a02af2c4cec303b7588c9f99eb7e1c Binary files /dev/null and b/SMPLX/__pycache__/rotation_conversions.cpython-310.pyc differ diff --git a/SMPLX/__pycache__/rotation_conversions.cpython-311.pyc b/SMPLX/__pycache__/rotation_conversions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34df4b8d96156ce4116ab63c17fe76febbbee9d3 Binary files /dev/null and b/SMPLX/__pycache__/rotation_conversions.cpython-311.pyc differ diff --git a/SMPLX/__pycache__/rotation_conversions.cpython-39.pyc b/SMPLX/__pycache__/rotation_conversions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c63bd5c35a93d92f1f1671cde2cd2e721f2bc84d Binary files /dev/null and b/SMPLX/__pycache__/rotation_conversions.cpython-39.pyc differ diff --git a/SMPLX/__pycache__/transfer_smpls.cpython-39.pyc b/SMPLX/__pycache__/transfer_smpls.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..985c617cfa29311721eccfad93da7725f4e0d155 Binary files /dev/null and b/SMPLX/__pycache__/transfer_smpls.cpython-39.pyc differ diff --git a/SMPLX/__pycache__/visual_amass.cpython-39.pyc b/SMPLX/__pycache__/visual_amass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e8f8da623b0247d5c65ca14f453d90b321ca5ea Binary files /dev/null and b/SMPLX/__pycache__/visual_amass.cpython-39.pyc differ diff --git a/SMPLX/__pycache__/visualize.cpython-38.pyc b/SMPLX/__pycache__/visualize.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..74d10aac1d60b794bb7d481c43e9aa9c41c085ba Binary files /dev/null and b/SMPLX/__pycache__/visualize.cpython-38.pyc differ diff --git a/SMPLX/config_files/smpl2smplh.yaml b/SMPLX/config_files/smpl2smplh.yaml new file mode 100644 index 0000000000000000000000000000000000000000..98cc874831b65a71dcb15f4f4e6018841e80c08b --- /dev/null +++ b/SMPLX/config_files/smpl2smplh.yaml @@ -0,0 +1,25 @@ +datasets: + mesh_folder: + data_folder: 'transfer_data/meshes/smpl' +deformation_transfer_path: 'transfer_data/smpl2smplh_def_transfer.pkl' +mask_ids_fname: '' +summary_steps: 100 + +edge_fitting: + per_part: False + +optim: + type: 'trust-ncg' + maxiters: 100 + gtol: 1e-06 + +body_model: + model_type: "smplh" + # SMPL+H has no neutral model, so we have to manually select the gender + gender: "female" + # gender: "male" + folder: "transfer_data/body_models" + use_compressed: False + smplh: + betas: + num: 10 diff --git a/SMPLX/config_files/smpl2smplx.yaml b/SMPLX/config_files/smpl2smplx.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aad7ac50b203a9efc5182025fc88f3c6f4fa43f3 --- /dev/null +++ b/SMPLX/config_files/smpl2smplx.yaml @@ -0,0 +1,26 @@ +datasets: + mesh_folder: + data_folder: 'transfer_data/meshes/smpl' +deformation_transfer_path: 'transfer_data/smpl2smplx_deftrafo_setup.pkl' +mask_ids_fname: 'smplx_mask_ids.npy' +summary_steps: 100 + +edge_fitting: + per_part: False + +optim: + type: 'trust-ncg' + maxiters: 100 + gtol: 1e-06 + +body_model: + model_type: "smplx" + gender: "neutral" + folder: "transfer_data/body_models" + use_compressed: False + use_face_contour: True + smplx: + betas: + num: 10 + expression: + num: 10 diff --git a/SMPLX/config_files/smplh2smpl.yaml b/SMPLX/config_files/smplh2smpl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..223d33736f9ea74a6b3450073008b0488ac3ad52 --- /dev/null +++ b/SMPLX/config_files/smplh2smpl.yaml @@ -0,0 +1,24 @@ +datasets: + mesh_folder: + data_folder: 'transfer_data/meshes/smplh' +deformation_transfer_path: 'transfer_data/smplh2smpl_def_transfer.pkl' +mask_ids_fname: '' +summary_steps: 100 + +edge_fitting: + per_part: False + +optim: + type: 'trust-ncg' + maxiters: 100 + gtol: 1e-06 + +body_model: + model_type: "smpl" + gender: "neutral" + folder: "transfer_data/body_models" + use_compressed: False + use_face_contour: True + smpl: + betas: + num: 10 diff --git a/SMPLX/config_files/smplh2smplx.yaml b/SMPLX/config_files/smplh2smplx.yaml new file mode 100644 index 0000000000000000000000000000000000000000..682d0e665dc084b72e936be43371c0c9a3df299c --- /dev/null +++ b/SMPLX/config_files/smplh2smplx.yaml @@ -0,0 +1,26 @@ +datasets: + mesh_folder: + data_folder: 'transfer_data/meshes/smplh' +deformation_transfer_path: 'transfer_data/smplh2smplx_deftrafo_setup.pkl' +mask_ids_fname: 'smplx_mask_ids.npy' +summary_steps: 100 + +edge_fitting: + per_part: False + +optim: + type: 'trust-ncg' + maxiters: 100 + gtol: 1e-06 + +body_model: + model_type: "smplx" + gender: "neutral" + folder: "transfer_data/body_models" + use_compressed: False + use_face_contour: True + smplx: + betas: + num: 10 + expression: + num: 10 diff --git a/SMPLX/config_files/smplh2smplx_as.yaml b/SMPLX/config_files/smplh2smplx_as.yaml new file mode 100644 index 0000000000000000000000000000000000000000..422eeaa86fb424ba9607cae6ddf4ae43e421c2d0 --- /dev/null +++ b/SMPLX/config_files/smplh2smplx_as.yaml @@ -0,0 +1,26 @@ +datasets: + mesh_folder: + data_folder: 'transfer_data/meshes/amass_sample' 
+deformation_transfer_path: 'transfer_data/smplh2smplx_deftrafo_setup.pkl' +mask_ids_fname: 'smplx_mask_ids.npy' +summary_steps: 100 + +edge_fitting: + per_part: False + +optim: + type: 'trust-ncg' + maxiters: 100 + gtol: 1e-06 + +body_model: + model_type: "smplx" + gender: "male" + folder: "/data/TTA/data/body_models" + use_compressed: False + use_face_contour: True + smplx: + betas: + num: 10 + expression: + num: 10 diff --git a/SMPLX/config_files/smplh2smplx_onepose.yaml b/SMPLX/config_files/smplh2smplx_onepose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1a592b83bbb5e34099d1705927ead4c9b88011d5 --- /dev/null +++ b/SMPLX/config_files/smplh2smplx_onepose.yaml @@ -0,0 +1,27 @@ +datasets: + mesh_folder: + data_folder: 'transfer_data/meshes/amass_onepose' +deformation_transfer_path: 'transfer_data/smplh2smplx_deftrafo_setup.pkl' +mask_ids_fname: 'smplx_mask_ids.npy' +summary_steps: 100 + +edge_fitting: + per_part: False + +optim: + type: 'adam' + lr: 0.1 + maxiters: 10000 + gtol: 1e-06 + +body_model: + model_type: "smplx" + gender: "neutral" + folder: "models" + use_compressed: False + use_face_contour: True + smplx: + betas: + num: 10 + expression: + num: 10 diff --git a/SMPLX/config_files/smplx2smpl.yaml b/SMPLX/config_files/smplx2smpl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7012fefda9ac367df616bdb6cb6f72567516e86a --- /dev/null +++ b/SMPLX/config_files/smplx2smpl.yaml @@ -0,0 +1,25 @@ +datasets: + mesh_folder: + data_folder: 'meshes/smplx' +deformation_transfer_path: 'transfer_data/smplx2smpl_deftrafo_setup.pkl' +mask_ids_fname: '' +summary_steps: 100 + +edge_fitting: + per_part: False + +optim: + type: 'lbfgs' + maxiters: 200 + gtol: 1e-06 + +body_model: + model_type: "smpl" + gender: "neutral" + ext: 'pkl' + folder: "transfer_data/body_models" + use_compressed: False + use_face_contour: True + smpl: + betas: + num: 10 diff --git a/SMPLX/config_files/smplx2smplh.yaml b/SMPLX/config_files/smplx2smplh.yaml new file mode 100644 index 0000000000000000000000000000000000000000..76275e011880f1ede64329f253bf236ab423abde --- /dev/null +++ b/SMPLX/config_files/smplx2smplh.yaml @@ -0,0 +1,27 @@ +datasets: + mesh_folder: + data_folder: 'meshes/smplx' +deformation_transfer_path: 'transfer_data/smplx2smplh_deftrafo_setup.pkl' +mask_ids_fname: '' +summary_steps: 100 + +edge_fitting: + per_part: False + +optim: + type: 'lbfgs' + maxiters: 200 + gtol: 1e-06 + +body_model: + model_type: "smplh" + # SMPL+H has no neutral model, so we have to manually select the gender + gender: "female" + # gender: "male" + ext: 'pkl' + folder: "transfer_data/body_models" + use_compressed: False + use_face_contour: True + smplh: + betas: + num: 10 diff --git a/SMPLX/joints2smpl.py b/SMPLX/joints2smpl.py new file mode 100644 index 0000000000000000000000000000000000000000..2419e3ec518a1bca9b032502a4ccf372d195a0a2 --- /dev/null +++ b/SMPLX/joints2smpl.py @@ -0,0 +1,59 @@ +import torch +from SMPLX.visualize_joint2smpl.simplify_loc2rot import joints2smpl +import argparse +import numpy as np +import os +from tqdm import tqdm + +parser = argparse.ArgumentParser(description='transfer joint3d to smpls') +parser.add_argument("--model_path", default="/data/TTA/data/body_models") +parser.add_argument('--source_path', default="/data/TTA/data/humanact12/group_000") +parser.add_argument("--target_path", default="/data/TTA/data/humanact_smplh/group_000") +parser.add_argument("--mode", default="joints", choices=["t2m", "joints"]) +args = parser.parse_args() +device = "cuda" + 
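+# I/O layout handled by this script:
+#   --mode t2m:    each source .npy holds HumanML3D features [nframes, 263];
+#                  recover_from_ric() converts them to 22 joint positions per frame.
+#   --mode joints: each source .npy already holds joint positions [nframes, njoints, 3].
+# The motion is shifted so its vertical minimum sits on the floor before fitting, and
+# each output .npy stores the fitted SMPL pose concatenated with the root translation.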
+if os.path.isdir(args.source_path): + os.makedirs(args.target_path, exist_ok=True) + files = os.listdir(args.source_path) + target_files = files +else: + files = [args.source_path] + args.source_path = "" + + if args.target_path.split(".")[-1] != "npy": + os.makedirs(args.target_path) + target_files = [files[0].split("/")[-1]] + else: + target_files = [args.target_path] + args.target_path = "" + +for i in range(len(files)): + curr_path = os.path.join(args.source_path, files[i]) + target_path = os.path.join(args.target_path, target_files[i]) + if os.path.exists(target_path): + continue + + curr_file = np.load(curr_path) #### [nframe, 263] + curr_file = torch.from_numpy(curr_file) + + if args.mode == "t2m": + from dataset.t2m.recover_joints import recover_from_ric + motions = recover_from_ric(curr_file, 22) #### [nframes, 22, 3] + motions = motions.detach().cpu().numpy() + else: + motions = curr_file.detach().cpu().numpy() + + frames, njoints, nfeats = motions.shape + MINS = motions.min(axis=0).min(axis=0) + MAXS = motions.max(axis=0).max(axis=0) + height_offset = MINS[1] + motions[:, :, 1] -= height_offset + model = joints2smpl(frames, 0, True, model_path=args.model_path) + target, trans = model.joint2smpl(motions) + + target = np.concatenate([target, trans], axis=1) + + np.save(target_path, target) + if i % 10 == 0: + print("save %d npys"%(i)) \ No newline at end of file diff --git a/SMPLX/read_from_npy.py b/SMPLX/read_from_npy.py new file mode 100644 index 0000000000000000000000000000000000000000..438061a3ee2b47c112d8e9fbbd0f534e3d007e08 --- /dev/null +++ b/SMPLX/read_from_npy.py @@ -0,0 +1,108 @@ +import numpy as np +import torch + +def npy2info(motions, num_shapes=10): + if isinstance(motions, str): + motions = np.load(motions) + + trans = None + gnum = 2 + + if isinstance(motions, np.ndarray): + betas = np.zeros([motions.shape[0], num_shapes]).astype(motions.dtype) + else: + betas = torch.zeros([motions.shape[0], num_shapes], dtype=motions.dtype) + + if len(motions.shape) == 3: + motions = motions.reshape(motions.shape[0], -1) + + if motions.shape[1] in [73, 157, 166]: + gnum = motions[:, -1:][0] + motions = motions[:, :-1] + elif motions.shape[1] in [75, 159, 168]: + gnum = 2 + trans = motions[:, -3::] + motions = motions[:, :-3] + elif motions.shape[1] in [76, 160, 169]: + gnum = motions[:, -1:][0] + trans = motions[:, -4:-1:] + motions = motions[:, :-4] + elif motions.shape[1] in [72 + num_shapes, 156 + num_shapes, 165 + num_shapes]: + betas = motions[:, -num_shapes::] + gnum = 2 + motions = motions[:, :-num_shapes] + elif motions.shape[1] in [73 + num_shapes, 157 + num_shapes, 166 + num_shapes]: + betas = motions[:, -num_shapes::] + gnum = motions[:, -num_shapes-1:-num_shapes:][0] + motions = motions[:, :-num_shapes-1] + elif motions.shape[1] in [75 + num_shapes, 159 + num_shapes, 168 + num_shapes]: + betas = motions[:, -num_shapes::] + gnum = 2 + trans = motions[:, -num_shapes-3:-num_shapes:] + motions = motions[:, :-num_shapes-3] + elif motions.shape[1] in [76 + num_shapes, 160 + num_shapes, 169 + num_shapes]: + betas = motions[:, -num_shapes::] + gnum = motions[:, -num_shapes-1:-num_shapes:][0] + trans = motions[:, -num_shapes-4:-num_shapes-1:] + motions = motions[:, :-num_shapes-4] + + if gnum == 0: + gender = "female" + elif gnum == 1: + gender = "male" + else: + gender = "neutral" + + return motions, trans, gender, betas + +def info2dict(pose, trans=None, betas=None, mode="smpl", device="cuda", index=-1): + if isinstance(pose, np.ndarray): + pose = torch.from_numpy(pose) + + 
if trans is not None and isinstance(trans, np.ndarray): + trans = torch.from_numpy(trans) + + if betas is not None and isinstance(betas, np.ndarray): + betas = torch.from_numpy(betas) + elif betas is None: + betas = torch.zeros([pose.shape[0], 10]) + + if index != -1: + pose = pose[index:index+1] + + if trans is not None: + trans = trans[index:index+1] + + betas = betas[index:index+1] + + if mode == "smplx": + inputs = { + "global_orient": pose[:, :3].float().to(device), + "body_pose": pose[:, 3:66].float().to(device), + "jaw_pose": pose[:, 66:69].float().to(device), + "leye_pose": pose[:, 69:72].float().to(device), + "reye_pose": pose[:, 72:75].float().to(device), + "left_hand_pose":pose[:, 75:120].float().to(device), + "right_hand_pose":pose[:, 120:].float().to(device), + } + elif mode == "smplh": + inputs = { + "global_orient": pose[:, :3].float().to(device), + "body_pose": pose[:, 3:66].float().to(device), + "left_hand_pose":pose[:, 66:111].float().to(device), + "right_hand_pose":pose[:, 111:].float().to(device), + } + elif mode == "smpl": + inputs = { + "global_orient": pose[:, :3].float().to(device), + "body_pose": pose[:, 3:].float().to(device), + } + + if trans is not None: + inputs["transl"] = trans[:, :].float().to(device) + else: + print("No Translation Information") + + inputs["betas"] = betas[:, :].float().to(device) + + return inputs \ No newline at end of file diff --git a/SMPLX/read_joints_from_pose.py b/SMPLX/read_joints_from_pose.py new file mode 100644 index 0000000000000000000000000000000000000000..33dd195a64f245e6e9429d3a88a72f33733698b1 --- /dev/null +++ b/SMPLX/read_joints_from_pose.py @@ -0,0 +1,110 @@ +import torch +import numpy as np +from torch import nn +import pickle as pkl +import torch.nn.functional as F + +class Struct(object): + def __init__(self, **kwargs): + for key, val in kwargs.items(): + setattr(self, key, val) + + +def to_np(array, dtype=np.float32): + if 'scipy.sparse' in str(type(array)): + array = array.todense() + return np.array(array, dtype=dtype) + + +class Get_Joints(nn.Module): + def __init__(self, path, batch_size=300) -> None: + super().__init__() + self.betas = nn.parameter.Parameter(torch.zeros([batch_size, 10], dtype=torch.float32), requires_grad=False) + with open(path, "rb") as f: + smpl_prior = pkl.load(f, encoding="latin1") + data_struct = Struct(**smpl_prior) + + self.v_template = nn.parameter.Parameter(torch.from_numpy(to_np(data_struct.v_template)), requires_grad=False) + self.shapedirs = nn.parameter.Parameter(torch.from_numpy(to_np(data_struct.shapedirs)), requires_grad=False) + self.J_regressor = nn.parameter.Parameter(torch.from_numpy(to_np(data_struct.J_regressor)), requires_grad=False) + posedirs = torch.from_numpy(to_np(data_struct.posedirs)) + num_pose_basis = posedirs.shape[-1] + posedirs = posedirs.reshape([-1, num_pose_basis]).permute(1, 0) + self.posedirs = nn.parameter.Parameter(posedirs, requires_grad=False) + self.parents = nn.parameter.Parameter(torch.from_numpy(to_np(data_struct.kintree_table)[0]).long(), requires_grad=False) + self.parents[0] = -1 + + self.ident = nn.parameter.Parameter(torch.eye(3), requires_grad=False) + self.K = nn.parameter.Parameter(torch.zeros([1, 3, 3]), requires_grad=False) + self.zeros = nn.parameter.Parameter(torch.zeros([1, 1]), requires_grad=False) + + def blend_shapes(self, betas, shape_disps): + blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) + return blend_shape + + def vertices2joints(self, J_regressor, vertices): + return torch.einsum('bik,ji->bjk', [vertices, 
J_regressor]) + + def batch_rodrigues( + self, + rot_vecs, + epsilon = 1e-8, + ): + batch_size = rot_vecs.shape[0] + angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True) + rot_dir = rot_vecs / angle + cos = torch.unsqueeze(torch.cos(angle), dim=1) + sin = torch.unsqueeze(torch.sin(angle), dim=1) + # Bx1 arrays + rx, ry, rz = torch.split(rot_dir, 1, dim=1) + K = self.K.repeat(batch_size, 1, 1) + zeros = self.zeros.repeat(batch_size, 1) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view((batch_size, 3, 3)) + ident = self.ident.unsqueeze(0) + rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) + return rot_mat + + def transform_mat(self, R, t): + return torch.cat([F.pad(R, [0, 0, 0, 1]), + F.pad(t, [0, 0, 0, 1], value=1)], dim=2) + + def batch_rigid_transform( + self, + rot_mats, + joints, + parents, + ): + joints = torch.unsqueeze(joints, dim=-1) + + rel_joints = joints.clone() + rel_joints[:, 1:] -= joints[:, parents[1:]] + + transforms_mat = self.transform_mat( + rot_mats.reshape(-1, 3, 3), + rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4) + + transform_chain = [transforms_mat[:, 0]] + for i in range(1, parents.shape[0]): + # Subtract the joint location at the rest pose + # No need for rotation, since it's identity when at rest + curr_res = torch.matmul(transform_chain[parents[i]], + transforms_mat[:, i]) + transform_chain.append(curr_res) + + transforms = torch.stack(transform_chain, dim=1) + + # The last column of the transformations contains the posed joints + posed_joints = transforms[:, :, :3, 3] + return posed_joints + + def forward(self, pose, trans=None): + pose = pose.float() + batch = pose.shape[0] + betas = self.betas[:batch] + v_shaped = self.v_template + self.blend_shapes(betas, self.shapedirs) + J = self.vertices2joints(self.J_regressor, v_shaped) + rot_mats = self.batch_rodrigues(pose.view(-1, 3)).view([batch, -1, 3, 3]) + J_transformed = self.batch_rigid_transform(rot_mats, J, self.parents) + if trans is not None: + J_transformed += trans.unsqueeze(dim=1) + return J_transformed \ No newline at end of file diff --git a/SMPLX/rotation_conversions.py b/SMPLX/rotation_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..1006e8a3117b231a7a456d5b826e76347fe0bfd4 --- /dev/null +++ b/SMPLX/rotation_conversions.py @@ -0,0 +1,532 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# Check PYTORCH3D_LICENCE before use + +import functools +from typing import Optional + +import torch +import torch.nn.functional as F + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) +This matrix can be applied to column vectors by post multiplication +by the points e.g. + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a, b): + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + """ + Convert rotations given as rotation matrices to quaternions. + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + +def _axis_angle_rotation(axis: str, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + if axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + if axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles, convention: str): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) + return functools.reduce(torch.matmul, matrices) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +): + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str): + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + + +def matrix_to_euler_angles(matrix, convention: str): + """ + Convert rotations given as rotation matrices to Euler angles in radians. + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + Returns: + Euler angles in radians as tensor of shape (..., 3). + """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. 
Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + requires_grad: Whether the resulting tensor should have the gradient + flag set + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device, requires_grad)[0] + + +def standardize_quaternion(quaternions): + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a, b): + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a, b): + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion): + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + Returns: + The inverse, a tensor of quaternions of shape (..., 4). 
+ """ + + return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + + +def quaternion_apply(quaternion, point): + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, f{point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix): + """ + Convert rotations given as rotation matrices to axis/angle. + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. 
+ """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + Returns: + batch of rotation matrices of size (*, 3, 3) + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + Returns: + 6D rotation representation, of size (*, 6) + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) + +def canonicalize_smplh(poses, trans = None): + bs, nframes, njoints = poses.shape[:3] + + global_orient = poses[:, :, 0] + + # first global rotations + rot2d = matrix_to_axis_angle(global_orient[:, 0]) + #rot2d[:, :2] = 0 # Remove the rotation along the vertical axis + rot2d = axis_angle_to_matrix(rot2d) + + # Rotate the global rotation to eliminate Z rotations + global_orient = torch.einsum("ikj,imkl->imjl", rot2d, global_orient) + + # Construct canonicalized version of x + xc = torch.cat((global_orient[:, :, None], poses[:, :, 1:]), dim=2) + + if trans is not None: + vel = trans[:, 1:] - trans[:, :-1] + # Turn the translation as well + vel = torch.einsum("ikj,ilk->ilj", rot2d, vel) + trans = torch.cat((torch.zeros(bs, 1, 3, device=vel.device), + torch.cumsum(vel, 1)), 1) + return xc, trans + else: + return xc + + \ No newline at end of file diff --git a/SMPLX/smplx/__init__.py b/SMPLX/smplx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..886949df670691d1ef5995737cafa285224826c4 --- /dev/null +++ b/SMPLX/smplx/__init__.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. 
+# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from .body_models import ( + create, + SMPL, + SMPLH, + SMPLX, + MANO, + FLAME, + build_layer, + SMPLLayer, + SMPLHLayer, + SMPLXLayer, + MANOLayer, + FLAMELayer, +) diff --git a/SMPLX/smplx/__pycache__/__init__.cpython-310.pyc b/SMPLX/smplx/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98d68d27786ff094987f33339ba83f90a9eb7d8c Binary files /dev/null and b/SMPLX/smplx/__pycache__/__init__.cpython-310.pyc differ diff --git a/SMPLX/smplx/__pycache__/__init__.cpython-311.pyc b/SMPLX/smplx/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83a38ac34ff6fc7659b66aa909a8343793108590 Binary files /dev/null and b/SMPLX/smplx/__pycache__/__init__.cpython-311.pyc differ diff --git a/SMPLX/smplx/__pycache__/__init__.cpython-39.pyc b/SMPLX/smplx/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dbeff4d0b240340352c12d56e374f063ff942da Binary files /dev/null and b/SMPLX/smplx/__pycache__/__init__.cpython-39.pyc differ diff --git a/SMPLX/smplx/__pycache__/body_models.cpython-310.pyc b/SMPLX/smplx/__pycache__/body_models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed7d25f3a9421ed79e12e77358d67434d7320efc Binary files /dev/null and b/SMPLX/smplx/__pycache__/body_models.cpython-310.pyc differ diff --git a/SMPLX/smplx/__pycache__/body_models.cpython-311.pyc b/SMPLX/smplx/__pycache__/body_models.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d61cdfed842a9159be6cdfd8c9b91e31c693916 Binary files /dev/null and b/SMPLX/smplx/__pycache__/body_models.cpython-311.pyc differ diff --git a/SMPLX/smplx/__pycache__/body_models.cpython-39.pyc b/SMPLX/smplx/__pycache__/body_models.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad2f952eeb4c82db21439e5dc328c41a36b8e06d Binary files /dev/null and b/SMPLX/smplx/__pycache__/body_models.cpython-39.pyc differ diff --git a/SMPLX/smplx/__pycache__/joint_names.cpython-39.pyc b/SMPLX/smplx/__pycache__/joint_names.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3703f8bd86558fe6d5f16292b9134ec0cea9e8f Binary files /dev/null and b/SMPLX/smplx/__pycache__/joint_names.cpython-39.pyc differ diff --git a/SMPLX/smplx/__pycache__/lbs.cpython-310.pyc b/SMPLX/smplx/__pycache__/lbs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11fe3359b3a2f1ff841e2ada7662cb53ec3a3ff4 Binary files /dev/null and b/SMPLX/smplx/__pycache__/lbs.cpython-310.pyc differ diff --git a/SMPLX/smplx/__pycache__/lbs.cpython-311.pyc b/SMPLX/smplx/__pycache__/lbs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c322f797e2e3d05b5c6cc8494533b286446c872 Binary files /dev/null and b/SMPLX/smplx/__pycache__/lbs.cpython-311.pyc differ diff --git a/SMPLX/smplx/__pycache__/lbs.cpython-39.pyc 
b/SMPLX/smplx/__pycache__/lbs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59e9d19c9e72f769b22db8990a9fb3e23793de75 Binary files /dev/null and b/SMPLX/smplx/__pycache__/lbs.cpython-39.pyc differ diff --git a/SMPLX/smplx/__pycache__/utils.cpython-310.pyc b/SMPLX/smplx/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c9c7483121ececd6492ff34ee537f3bbdb96ea8 Binary files /dev/null and b/SMPLX/smplx/__pycache__/utils.cpython-310.pyc differ diff --git a/SMPLX/smplx/__pycache__/utils.cpython-311.pyc b/SMPLX/smplx/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f73f6493dc88285649d7d75b68d73adc4501069 Binary files /dev/null and b/SMPLX/smplx/__pycache__/utils.cpython-311.pyc differ diff --git a/SMPLX/smplx/__pycache__/utils.cpython-39.pyc b/SMPLX/smplx/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4815fb1005bd6d6917ddcb14c9f0e7736709cd8f Binary files /dev/null and b/SMPLX/smplx/__pycache__/utils.cpython-39.pyc differ diff --git a/SMPLX/smplx/__pycache__/vertex_ids.cpython-310.pyc b/SMPLX/smplx/__pycache__/vertex_ids.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9860d1c960adcec2b00ed78e8cb2333c092e0076 Binary files /dev/null and b/SMPLX/smplx/__pycache__/vertex_ids.cpython-310.pyc differ diff --git a/SMPLX/smplx/__pycache__/vertex_ids.cpython-311.pyc b/SMPLX/smplx/__pycache__/vertex_ids.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..254abbf8ba2a63480e8e0a9c4074587763c558d1 Binary files /dev/null and b/SMPLX/smplx/__pycache__/vertex_ids.cpython-311.pyc differ diff --git a/SMPLX/smplx/__pycache__/vertex_ids.cpython-39.pyc b/SMPLX/smplx/__pycache__/vertex_ids.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc16d7cb16fbdae3873c9b8c28ccfd069f2ed0fe Binary files /dev/null and b/SMPLX/smplx/__pycache__/vertex_ids.cpython-39.pyc differ diff --git a/SMPLX/smplx/__pycache__/vertex_joint_selector.cpython-310.pyc b/SMPLX/smplx/__pycache__/vertex_joint_selector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9974ffcbeaec12d1a4e283c02e63ebc4b3c76e6 Binary files /dev/null and b/SMPLX/smplx/__pycache__/vertex_joint_selector.cpython-310.pyc differ diff --git a/SMPLX/smplx/__pycache__/vertex_joint_selector.cpython-311.pyc b/SMPLX/smplx/__pycache__/vertex_joint_selector.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91202b1354d134dce0babb749a271349f8503eba Binary files /dev/null and b/SMPLX/smplx/__pycache__/vertex_joint_selector.cpython-311.pyc differ diff --git a/SMPLX/smplx/__pycache__/vertex_joint_selector.cpython-39.pyc b/SMPLX/smplx/__pycache__/vertex_joint_selector.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c1f29bb259f660bdb84d2a352d128541b70d9ca Binary files /dev/null and b/SMPLX/smplx/__pycache__/vertex_joint_selector.cpython-39.pyc differ diff --git a/SMPLX/smplx/body_models.py b/SMPLX/smplx/body_models.py new file mode 100644 index 0000000000000000000000000000000000000000..5f651ba45c11d8877b7114ea72e83e8f83d5434a --- /dev/null +++ b/SMPLX/smplx/body_models.py @@ -0,0 +1,2440 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. 
+# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from typing import Optional, Dict, Union +import os +import os.path as osp + +import pickle + +import numpy as np + +import torch +import torch.nn as nn + +from .lbs import ( + lbs, vertices2landmarks, find_dynamic_lmk_idx_and_bcoords, blend_shapes) + +from .vertex_ids import vertex_ids as VERTEX_IDS +from .utils import ( + Struct, to_np, to_tensor, Tensor, Array, + SMPLOutput, + SMPLHOutput, + SMPLXOutput, + MANOOutput, + FLAMEOutput, + find_joint_kin_chain) +from .vertex_joint_selector import VertexJointSelector +from collections import namedtuple + +TensorOutput = namedtuple('TensorOutput', + ['vertices', 'joints', 'betas', 'expression', 'global_orient', 'body_pose', 'left_hand_pose', + 'right_hand_pose', 'jaw_pose', 'transl', 'full_pose']) + + +class SMPL(nn.Module): + + NUM_JOINTS = 23 + NUM_BODY_JOINTS = 23 + SHAPE_SPACE_DIM = 300 + + def __init__( + self, model_path: str, + kid_template_path: str = '', + data_struct: Optional[Struct] = None, + create_betas: bool = True, + betas: Optional[Tensor] = None, + num_betas: int = 10, + create_global_orient: bool = True, + global_orient: Optional[Tensor] = None, + create_body_pose: bool = True, + body_pose: Optional[Tensor] = None, + create_transl: bool = True, + transl: Optional[Tensor] = None, + dtype=torch.float32, + batch_size: int = 1, + joint_mapper=None, + gender: str = 'neutral', + age: str = 'adult', + vertex_ids: Dict[str, int] = None, + v_template: Optional[Union[Tensor, Array]] = None, + **kwargs + ) -> None: + ''' SMPL model constructor + + Parameters + ---------- + model_path: str + The path to the folder or to the file where the model + parameters are stored + data_struct: Strct + A struct object. If given, then the parameters of the model are + read from the object. Otherwise, the model tries to read the + parameters from the given `model_path`. (default = None) + create_global_orient: bool, optional + Flag for creating a member variable for the global orientation + of the body. (default = True) + global_orient: torch.tensor, optional, Bx3 + The default value for the global orientation variable. + (default = None) + create_body_pose: bool, optional + Flag for creating a member variable for the pose of the body. + (default = True) + body_pose: torch.tensor, optional, Bx(Body Joints * 3) + The default value for the body pose variable. + (default = None) + num_betas: int, optional + Number of shape components to use + (default = 10). + create_betas: bool, optional + Flag for creating a member variable for the shape space + (default = True). + betas: torch.tensor, optional, Bx10 + The default value for the shape member variable. + (default = None) + create_transl: bool, optional + Flag for creating a member variable for the translation + of the body. (default = True) + transl: torch.tensor, optional, Bx3 + The default value for the transl variable. 
+ (default = None) + dtype: torch.dtype, optional + The data type for the created variables + batch_size: int, optional + The batch size used for creating the member variables + joint_mapper: object, optional + An object that re-maps the joints. Useful if one wants to + re-order the SMPL joints to some other convention (e.g. MSCOCO) + (default = None) + gender: str, optional + Which gender to load + vertex_ids: dict, optional + A dictionary containing the indices of the extra vertices that + will be selected + ''' + + self.gender = gender + self.age = age + + if data_struct is None: + if osp.isdir(model_path): + model_fn = 'SMPL_{}.{ext}'.format(gender.upper(), ext='pkl') + smpl_path = os.path.join(model_path, model_fn) + else: + smpl_path = model_path + assert osp.exists(smpl_path), 'Path {} does not exist!'.format( + smpl_path) + + with open(smpl_path, 'rb') as smpl_file: + data_struct = Struct(**pickle.load(smpl_file,encoding='latin1')) + + super(SMPL, self).__init__() + self.batch_size = batch_size + shapedirs = data_struct.shapedirs + if (shapedirs.shape[-1] < self.SHAPE_SPACE_DIM): + # print(f'WARNING: You are using a {self.name()} model, with only' + # f' {shapedirs.shape[-1]} shape coefficients.\n' + # f'num_betas={num_betas}, shapedirs.shape={shapedirs.shape}, ' + # f'self.SHAPE_SPACE_DIM={self.SHAPE_SPACE_DIM}') + num_betas = min(num_betas, shapedirs.shape[-1]) + else: + num_betas = min(num_betas, self.SHAPE_SPACE_DIM) + + if self.age == 'kid': + v_template_smil = np.load(kid_template_path) + v_template_smil -= np.mean(v_template_smil, axis=0) + v_template_diff = np.expand_dims( + v_template_smil - data_struct.v_template, axis=2) + shapedirs = np.concatenate( + (shapedirs[:, :, :num_betas], v_template_diff), axis=2) + num_betas = num_betas + 1 + + self._num_betas = num_betas + shapedirs = shapedirs[:, :, :num_betas] + # The shape components + self.register_buffer( + 'shapedirs', + to_tensor(to_np(shapedirs), dtype=dtype)) + + if vertex_ids is None: + # SMPL and SMPL-H share the same topology, so any extra joints can + # be drawn from the same place + vertex_ids = VERTEX_IDS['smplh'] + + self.dtype = dtype + + self.joint_mapper = joint_mapper + + self.vertex_joint_selector = VertexJointSelector( + vertex_ids=vertex_ids, **kwargs) + + self.faces = data_struct.f + self.register_buffer('faces_tensor', + to_tensor(to_np(self.faces, dtype=np.int64), + dtype=torch.long)) + + if create_betas: + if betas is None: + default_betas = torch.zeros( + [batch_size, self.num_betas], dtype=dtype) + else: + if torch.is_tensor(betas): + default_betas = betas.clone().detach() + else: + default_betas = torch.tensor(betas, dtype=dtype) + + self.register_parameter( + 'betas', nn.Parameter(default_betas, requires_grad=True)) + + # The tensor that contains the global rotation of the model + # It is separated from the pose of the joints in case we wish to + # optimize only over one of them + if create_global_orient: + if global_orient is None: + default_global_orient = torch.zeros( + [batch_size, 3], dtype=dtype) + else: + if torch.is_tensor(global_orient): + default_global_orient = global_orient.clone().detach() + else: + default_global_orient = torch.tensor( + global_orient, dtype=dtype) + + global_orient = nn.Parameter(default_global_orient, + requires_grad=True) + self.register_parameter('global_orient', global_orient) + + if create_body_pose: + if body_pose is None: + default_body_pose = torch.zeros( + [batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype) + else: + if torch.is_tensor(body_pose): + 
default_body_pose = body_pose.clone().detach() + else: + default_body_pose = torch.tensor(body_pose, + dtype=dtype) + self.register_parameter( + 'body_pose', + nn.Parameter(default_body_pose, requires_grad=True)) + + if create_transl: + if transl is None: + default_transl = torch.zeros([batch_size, 3], + dtype=dtype, + requires_grad=True) + else: + default_transl = torch.tensor(transl, dtype=dtype) + self.register_parameter( + 'transl', nn.Parameter(default_transl, requires_grad=True)) + + if v_template is None: + v_template = data_struct.v_template + if not torch.is_tensor(v_template): + v_template = to_tensor(to_np(v_template), dtype=dtype) + # The vertices of the template model + self.register_buffer('v_template', v_template) + + j_regressor = to_tensor(to_np( + data_struct.J_regressor), dtype=dtype) + self.register_buffer('J_regressor', j_regressor) + + # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207 + num_pose_basis = data_struct.posedirs.shape[-1] + # 207 x 20670 + posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T + self.register_buffer('posedirs', + to_tensor(to_np(posedirs), dtype=dtype)) + + # indices of parents for each joints + parents = to_tensor(to_np(data_struct.kintree_table[0])).long() + parents[0] = -1 + self.register_buffer('parents', parents) + + lbs_weights = to_tensor(to_np(data_struct.weights), dtype=dtype) + self.register_buffer('lbs_weights', lbs_weights) + + @property + def num_betas(self): + return self._num_betas + + @property + def num_expression_coeffs(self): + return 0 + + def create_mean_pose(self, data_struct) -> Tensor: + pass + + def name(self) -> str: + return 'SMPL' + + @torch.no_grad() + def reset_params(self, **params_dict) -> None: + for param_name, param in self.named_parameters(): + if param_name in params_dict: + param[:] = torch.tensor(params_dict[param_name]) + else: + param.fill_(0) + + def get_num_verts(self) -> int: + return self.v_template.shape[0] + + def get_num_faces(self) -> int: + return self.faces.shape[0] + + def extra_repr(self) -> str: + msg = [ + f'Gender: {self.gender.upper()}', + f'Number of joints: {self.J_regressor.shape[0]}', + f'Betas: {self.num_betas}', + ] + return '\n'.join(msg) + + def forward_shape( + self, + betas: Optional[Tensor] = None, + ) -> SMPLOutput: + betas = betas if betas is not None else self.betas + v_shaped = self.v_template + blend_shapes(betas, self.shapedirs) + return SMPLOutput(vertices=v_shaped, betas=betas, v_shaped=v_shaped) + + def forward( + self, + betas: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts=True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs + ) -> SMPLOutput: + ''' Forward pass for the SMPL model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3 + If given, ignore the member variable and use it as the global + rotation of the body. Useful if someone wishes to predicts this + with an external model. (default=None) + betas: torch.tensor, optional, shape BxN_b + If given, ignore the member variable `betas` and use it + instead. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + body_pose: torch.tensor, optional, shape Bx(J*3) + If given, ignore the member variable `body_pose` and use it + instead. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. 
+ It should be a tensor that contains joint rotations in + axis-angle format. (default=None) + transl: torch.tensor, optional, shape Bx3 + If given, ignore the member variable `transl` and use it + instead. For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. (default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + ''' + # If no shape and pose parameters are passed along, then use the + # ones from the module + global_orient = (global_orient if global_orient is not None else + self.global_orient) + body_pose = body_pose if body_pose is not None else self.body_pose + betas = betas if betas is not None else self.betas + + apply_trans = transl is not None or hasattr(self, 'transl') + if transl is None and hasattr(self, 'transl'): + transl = self.transl + + full_pose = torch.cat([global_orient, body_pose], dim=1) + + batch_size = max(betas.shape[0], global_orient.shape[0], + body_pose.shape[0]) + + if betas.shape[0] != batch_size: + num_repeats = int(batch_size / betas.shape[0]) + betas = betas.expand(num_repeats, -1) + + vertices, joints = lbs(betas, full_pose, self.v_template, + self.shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, pose2rot=pose2rot) + + joints = self.vertex_joint_selector(vertices, joints) + # Map the joints to the current dataset + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if apply_trans: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = SMPLOutput(vertices=vertices if return_verts else None, + global_orient=global_orient, + body_pose=body_pose, + joints=joints, + betas=betas, + full_pose=full_pose if return_full_pose else None) + + return output + + +class SMPLLayer(SMPL): + def __init__( + self, + *args, + **kwargs + ) -> None: + # Just create a SMPL module without any member variables + super(SMPLLayer, self).__init__( + create_body_pose=False, + create_betas=False, + create_global_orient=False, + create_transl=False, + *args, + **kwargs, + ) + + def forward( + self, + betas: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts=True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs + ) -> SMPLOutput: + ''' Forward pass for the SMPL model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3x3 + Global rotation of the body. Useful if someone wishes to + predicts this with an external model. It is expected to be in + rotation matrix format. (default=None) + betas: torch.tensor, optional, shape BxN_b + Shape parameters. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + body_pose: torch.tensor, optional, shape BxJx3x3 + Body pose. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + transl: torch.tensor, optional, shape Bx3 + Translation vector of the body. + For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. 
(default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + ''' + model_vars = [betas, global_orient, body_pose, transl] + batch_size = 1 + for var in model_vars: + if var is None: + continue + batch_size = max(batch_size, len(var)) + device, dtype = self.shapedirs.device, self.shapedirs.dtype + if global_orient is None: + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if body_pose is None: + body_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand( + batch_size, self.NUM_BODY_JOINTS, -1, -1).contiguous() + if betas is None: + betas = torch.zeros([batch_size, self.num_betas], + dtype=dtype, device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + full_pose = torch.cat( + [global_orient.reshape(-1, 1, 3, 3), + body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3)], + dim=1) + + vertices, joints = lbs(betas, full_pose, self.v_template, + self.shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, + pose2rot=False) + + joints = self.vertex_joint_selector(vertices, joints) + # Map the joints to the current dataset + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if transl is not None: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = SMPLOutput(vertices=vertices if return_verts else None, + global_orient=global_orient, + body_pose=body_pose, + joints=joints, + betas=betas, + full_pose=full_pose if return_full_pose else None) + + return output + + +class SMPLH(SMPL): + + # The hand joints are replaced by MANO + NUM_BODY_JOINTS = SMPL.NUM_JOINTS - 2 + NUM_HAND_JOINTS = 15 + NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + + def __init__( + self, model_path, + kid_template_path: str = '', + data_struct: Optional[Struct] = None, + create_left_hand_pose: bool = True, + left_hand_pose: Optional[Tensor] = None, + create_right_hand_pose: bool = True, + right_hand_pose: Optional[Tensor] = None, + use_pca: bool = True, + num_pca_comps: int = 6, + num_betas=16, + flat_hand_mean: bool = False, + batch_size: int = 1, + gender: str = 'neutral', + age: str = 'adult', + dtype=torch.float32, + vertex_ids=None, + use_compressed: bool = True, + ext: str = 'pkl', + **kwargs + ) -> None: + ''' SMPLH model constructor + + Parameters + ---------- + model_path: str + The path to the folder or to the file where the model + parameters are stored + data_struct: Strct + A struct object. If given, then the parameters of the model are + read from the object. Otherwise, the model tries to read the + parameters from the given `model_path`. (default = None) + create_left_hand_pose: bool, optional + Flag for creating a member variable for the pose of the left + hand. (default = True) + left_hand_pose: torch.tensor, optional, BxP + The default value for the left hand pose member variable. + (default = None) + create_right_hand_pose: bool, optional + Flag for creating a member variable for the pose of the right + hand. (default = True) + right_hand_pose: torch.tensor, optional, BxP + The default value for the right hand pose member variable. + (default = None) + num_pca_comps: int, optional + The number of PCA components to use for each hand. + (default = 6) + flat_hand_mean: bool, optional + If False, then the pose of the hand is initialized to False. 
+ batch_size: int, optional + The batch size used for creating the member variables + gender: str, optional + Which gender to load + dtype: torch.dtype, optional + The data type for the created variables + vertex_ids: dict, optional + A dictionary containing the indices of the extra vertices that + will be selected + ''' + + self.num_pca_comps = num_pca_comps + # If no data structure is passed, then load the data from the given + # model folder + if data_struct is None: + # Load the model + if osp.isdir(model_path): + model_fn = 'SMPLH_{}.{ext}'.format(gender.upper(), ext=ext) + smplh_path = os.path.join(model_path, model_fn) + else: + smplh_path = model_path + assert osp.exists(smplh_path), 'Path {} does not exist!'.format( + smplh_path) + + if ext == 'pkl': + with open(smplh_path, 'rb') as smplh_file: + model_data = pickle.load(smplh_file, encoding='latin1') + elif ext == 'npz': + model_data = np.load(smplh_path, allow_pickle=True) + else: + raise ValueError('Unknown extension: {}'.format(ext)) + data_struct = Struct(**model_data) + + if vertex_ids is None: + vertex_ids = VERTEX_IDS['smplh'] + + super(SMPLH, self).__init__( + model_path=model_path, + kid_template_path=kid_template_path, + data_struct=data_struct, + num_betas=num_betas, + batch_size=batch_size, vertex_ids=vertex_ids, gender=gender, age=age, + use_compressed=use_compressed, dtype=dtype, ext=ext, **kwargs) + + self.use_pca = use_pca + self.num_pca_comps = num_pca_comps + self.flat_hand_mean = flat_hand_mean + + try: + left_hand_components = data_struct.hands_componentsl[:num_pca_comps] + right_hand_components = data_struct.hands_componentsr[:num_pca_comps] + except: + curr_path = model_path.replace("smplh", "") + handleft_dir = os.path.join(curr_path, "mano") + with open(os.path.join(handleft_dir, "MANO_LEFT.pkl"), 'rb') as smplh_file: + handleft = pickle.load(smplh_file, encoding='latin1') + with open(os.path.join(handleft_dir, "MANO_RIGHT.pkl"), 'rb') as smplh_file: + handright = pickle.load(smplh_file, encoding='latin1') + + setattr(data_struct, "hands_componentsl", handleft["hands_components"]) + setattr(data_struct, "hands_componentsr", handright["hands_components"]) + setattr(data_struct, "hands_meanl", handleft["hands_mean"]) + setattr(data_struct, "hands_meanr", handright["hands_mean"]) + + left_hand_components = data_struct.hands_componentsl[:num_pca_comps] + right_hand_components = data_struct.hands_componentsr[:num_pca_comps] + + self.np_left_hand_components = left_hand_components + self.np_right_hand_components = right_hand_components + if self.use_pca: + self.register_buffer( + 'left_hand_components', + torch.tensor(left_hand_components, dtype=dtype)) + self.register_buffer( + 'right_hand_components', + torch.tensor(right_hand_components, dtype=dtype)) + + if self.flat_hand_mean: + left_hand_mean = np.zeros_like(data_struct.hands_meanl) + else: + left_hand_mean = data_struct.hands_meanl + + if self.flat_hand_mean: + right_hand_mean = np.zeros_like(data_struct.hands_meanr) + else: + right_hand_mean = data_struct.hands_meanr + + self.register_buffer('left_hand_mean', + to_tensor(left_hand_mean, dtype=self.dtype)) + self.register_buffer('right_hand_mean', + to_tensor(right_hand_mean, dtype=self.dtype)) + + # Create the buffers for the pose of the left hand + hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS + if create_left_hand_pose: + if left_hand_pose is None: + default_lhand_pose = torch.zeros([batch_size, hand_pose_dim], + dtype=dtype) + else: + default_lhand_pose = 
torch.tensor(left_hand_pose, dtype=dtype) + + left_hand_pose_param = nn.Parameter(default_lhand_pose, + requires_grad=True) + self.register_parameter('left_hand_pose', + left_hand_pose_param) + + if create_right_hand_pose: + if right_hand_pose is None: + default_rhand_pose = torch.zeros([batch_size, hand_pose_dim], + dtype=dtype) + else: + default_rhand_pose = torch.tensor(right_hand_pose, dtype=dtype) + + right_hand_pose_param = nn.Parameter(default_rhand_pose, + requires_grad=True) + self.register_parameter('right_hand_pose', + right_hand_pose_param) + + # Create the buffer for the mean pose. + pose_mean_tensor = self.create_mean_pose( + data_struct, flat_hand_mean=flat_hand_mean) + if not torch.is_tensor(pose_mean_tensor): + pose_mean_tensor = torch.tensor(pose_mean_tensor, dtype=dtype) + self.register_buffer('pose_mean', pose_mean_tensor) + + def create_mean_pose(self, data_struct, flat_hand_mean=False): + # Create the array for the mean pose. If flat_hand is false, then use + # the mean that is given by the data, rather than the flat open hand + global_orient_mean = torch.zeros([3], dtype=self.dtype) + body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3], + dtype=self.dtype) + + pose_mean = torch.cat([global_orient_mean, body_pose_mean, + self.left_hand_mean, + self.right_hand_mean], dim=0) + return pose_mean + + def name(self) -> str: + return 'SMPL+H' + + def extra_repr(self): + msg = super(SMPLH, self).extra_repr() + msg = [msg] + if self.use_pca: + msg.append(f'Number of PCA components: {self.num_pca_comps}') + msg.append(f'Flat hand mean: {self.flat_hand_mean}') + return '\n'.join(msg) + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + left_hand_pose: Optional[Tensor] = None, + right_hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs + ) -> SMPLHOutput: + ''' + ''' + + # If no shape and pose parameters are passed along, then use the + # ones from the module + global_orient = (global_orient if global_orient is not None else + self.global_orient) + body_pose = body_pose if body_pose is not None else self.body_pose + betas = betas if betas is not None else self.betas + left_hand_pose = (left_hand_pose if left_hand_pose is not None else + self.left_hand_pose) + right_hand_pose = (right_hand_pose if right_hand_pose is not None else + self.right_hand_pose) + + apply_trans = transl is not None or hasattr(self, 'transl') + if transl is None: + if hasattr(self, 'transl'): + transl = self.transl + + if self.use_pca: + left_hand_pose = torch.einsum( + 'bi,ij->bj', [left_hand_pose, self.left_hand_components]) + right_hand_pose = torch.einsum( + 'bi,ij->bj', [right_hand_pose, self.right_hand_components]) + + full_pose = torch.cat([global_orient, body_pose, + left_hand_pose, + right_hand_pose], dim=1) + full_pose += self.pose_mean + + vertices, joints = lbs(betas, full_pose, self.v_template, + self.shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, pose2rot=pose2rot) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if apply_trans: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = SMPLHOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + 
global_orient=global_orient, + body_pose=body_pose, + left_hand_pose=left_hand_pose, + right_hand_pose=right_hand_pose, + full_pose=full_pose if return_full_pose else None) + + return output + + +class SMPLHLayer(SMPLH): + + def __init__( + self, *args, **kwargs + ) -> None: + ''' SMPL+H as a layer model constructor + ''' + super(SMPLHLayer, self).__init__( + create_global_orient=False, + create_body_pose=False, + create_left_hand_pose=False, + create_right_hand_pose=False, + create_betas=False, + create_transl=False, + *args, + **kwargs) + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + left_hand_pose: Optional[Tensor] = None, + right_hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs + ) -> SMPLHOutput: + ''' Forward pass for the SMPL+H model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3x3 + Global rotation of the body. Useful if someone wishes to + predicts this with an external model. It is expected to be in + rotation matrix format. (default=None) + betas: torch.tensor, optional, shape BxN_b + Shape parameters. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + body_pose: torch.tensor, optional, shape BxJx3x3 + If given, ignore the member variable `body_pose` and use it + instead. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + left_hand_pose: torch.tensor, optional, shape Bx15x3x3 + If given, contains the pose of the left hand. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + right_hand_pose: torch.tensor, optional, shape Bx15x3x3 + If given, contains the pose of the right hand. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + transl: torch.tensor, optional, shape Bx3 + Translation vector of the body. + For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. 
(default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + ''' + model_vars = [betas, global_orient, body_pose, transl, left_hand_pose, + right_hand_pose] + batch_size = 1 + for var in model_vars: + if var is None: + continue + batch_size = max(batch_size, len(var)) + device, dtype = self.shapedirs.device, self.shapedirs.dtype + if global_orient is None: + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if body_pose is None: + body_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 21, -1, -1).contiguous() + if left_hand_pose is None: + left_hand_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + if right_hand_pose is None: + right_hand_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + if betas is None: + betas = torch.zeros([batch_size, self.num_betas], + dtype=dtype, device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + # Concatenate all pose vectors + full_pose = torch.cat( + [global_orient.reshape(-1, 1, 3, 3), + body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3), + left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3), + right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)], + dim=1) + + vertices, joints = lbs(betas, full_pose, self.v_template, + self.shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, pose2rot=False) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if transl is not None: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = SMPLHOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + global_orient=global_orient, + body_pose=body_pose, + left_hand_pose=left_hand_pose, + right_hand_pose=right_hand_pose, + full_pose=full_pose if return_full_pose else None) + + return output + + +class SMPLX(SMPLH): + ''' + SMPL-X (SMPL eXpressive) is a unified body model, with shape parameters + trained jointly for the face, hands and body. + SMPL-X uses standard vertex based linear blend skinning with learned + corrective blend shapes, has N=10475 vertices and K=54 joints, + which includes joints for the neck, jaw, eyeballs and fingers. 
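+
+    A minimal usage sketch (the import path and the model directory below are
+    placeholders, and a model file such as SMPLX_NEUTRAL.npz is assumed to be
+    present under `model_path`):
+
+        import torch
+        from smplx import SMPLX  # assumed package/module name
+
+        model = SMPLX(model_path='/path/to/smplx', gender='neutral',
+                      num_betas=10, use_pca=True, num_pca_comps=6)
+        betas = torch.zeros([1, model.num_betas])                # shape coefficients
+        body_pose = torch.zeros([1, model.NUM_BODY_JOINTS * 3])  # axis-angle pose
+        output = model(betas=betas, body_pose=body_pose, return_verts=True)
+        vertices, joints = output.vertices, output.joints        # vertices: (1, 10475, 3)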
+ ''' + + NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS + NUM_HAND_JOINTS = 15 + NUM_FACE_JOINTS = 3 + NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS + EXPRESSION_SPACE_DIM = 100 + NECK_IDX = 12 + + def __init__( + self, model_path: str, + kid_template_path: str = '', + num_expression_coeffs: int = 10, + create_expression: bool = True, + expression: Optional[Tensor] = None, + create_jaw_pose: bool = True, + jaw_pose: Optional[Tensor] = None, + create_leye_pose: bool = True, + leye_pose: Optional[Tensor] = None, + create_reye_pose=True, + reye_pose: Optional[Tensor] = None, + use_face_contour: bool = False, + batch_size: int = 1, + gender: str = 'neutral', + age: str = 'adult', + dtype=torch.float32, + ext: str = 'npz', + **kwargs + ) -> None: + ''' SMPLX model constructor + + Parameters + ---------- + model_path: str + The path to the folder or to the file where the model + parameters are stored + num_expression_coeffs: int, optional + Number of expression components to use + (default = 10). + create_expression: bool, optional + Flag for creating a member variable for the expression space + (default = True). + expression: torch.tensor, optional, Bx10 + The default value for the expression member variable. + (default = None) + create_jaw_pose: bool, optional + Flag for creating a member variable for the jaw pose. + (default = False) + jaw_pose: torch.tensor, optional, Bx3 + The default value for the jaw pose variable. + (default = None) + create_leye_pose: bool, optional + Flag for creating a member variable for the left eye pose. + (default = False) + leye_pose: torch.tensor, optional, Bx10 + The default value for the left eye pose variable. + (default = None) + create_reye_pose: bool, optional + Flag for creating a member variable for the right eye pose. + (default = False) + reye_pose: torch.tensor, optional, Bx10 + The default value for the right eye pose variable. 
+ (default = None) + use_face_contour: bool, optional + Whether to compute the keypoints that form the facial contour + batch_size: int, optional + The batch size used for creating the member variables + gender: str, optional + Which gender to load + dtype: torch.dtype + The data type for the created variables + ''' + + # Load the model + if osp.isdir(model_path): + model_fn = 'SMPLX_{}.{ext}'.format(gender.upper(), ext=ext) + smplx_path = os.path.join(model_path, model_fn) + else: + smplx_path = model_path + assert osp.exists(smplx_path), 'Path {} does not exist!'.format( + smplx_path) + + if ext == 'pkl': + with open(smplx_path, 'rb') as smplx_file: + model_data = pickle.load(smplx_file, encoding='latin1') + elif ext == 'npz': + model_data = np.load(smplx_path, allow_pickle=True) + else: + raise ValueError('Unknown extension: {}'.format(ext)) + + data_struct = Struct(**model_data) + + super(SMPLX, self).__init__( + model_path=model_path, + kid_template_path=kid_template_path, + data_struct=data_struct, + dtype=dtype, + batch_size=batch_size, + vertex_ids=VERTEX_IDS['smplx'], + gender=gender, age=age, ext=ext, + **kwargs) + + lmk_faces_idx = data_struct.lmk_faces_idx + self.register_buffer('lmk_faces_idx', + torch.tensor(lmk_faces_idx, dtype=torch.long)) + lmk_bary_coords = data_struct.lmk_bary_coords + self.register_buffer('lmk_bary_coords', + torch.tensor(lmk_bary_coords, dtype=dtype)) + + self.use_face_contour = use_face_contour + if self.use_face_contour: + dynamic_lmk_faces_idx = data_struct.dynamic_lmk_faces_idx + dynamic_lmk_faces_idx = torch.tensor( + dynamic_lmk_faces_idx, + dtype=torch.long) + self.register_buffer('dynamic_lmk_faces_idx', + dynamic_lmk_faces_idx) + + dynamic_lmk_bary_coords = data_struct.dynamic_lmk_bary_coords + dynamic_lmk_bary_coords = torch.tensor( + dynamic_lmk_bary_coords, dtype=dtype) + self.register_buffer('dynamic_lmk_bary_coords', + dynamic_lmk_bary_coords) + + neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents) + self.register_buffer( + 'neck_kin_chain', + torch.tensor(neck_kin_chain, dtype=torch.long)) + + if create_jaw_pose: + if jaw_pose is None: + default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype) + jaw_pose_param = nn.Parameter(default_jaw_pose, + requires_grad=True) + self.register_parameter('jaw_pose', jaw_pose_param) + + if create_leye_pose: + if leye_pose is None: + default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_leye_pose = torch.tensor(leye_pose, dtype=dtype) + leye_pose_param = nn.Parameter(default_leye_pose, + requires_grad=True) + self.register_parameter('leye_pose', leye_pose_param) + + if create_reye_pose: + if reye_pose is None: + default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_reye_pose = torch.tensor(reye_pose, dtype=dtype) + reye_pose_param = nn.Parameter(default_reye_pose, + requires_grad=True) + self.register_parameter('reye_pose', reye_pose_param) + + shapedirs = data_struct.shapedirs + if len(shapedirs.shape) < 3: + shapedirs = shapedirs[:, :, None] + if (shapedirs.shape[-1] < self.SHAPE_SPACE_DIM + + self.EXPRESSION_SPACE_DIM): + # print(f'WARNING: You are using a {self.name()} model, with only' + # ' 10 shape and 10 expression coefficients.') + expr_start_idx = 10 + expr_end_idx = 20 + num_expression_coeffs = min(num_expression_coeffs, 10) + else: + expr_start_idx = self.SHAPE_SPACE_DIM + expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs + num_expression_coeffs = 
min( + num_expression_coeffs, self.EXPRESSION_SPACE_DIM) + + self._num_expression_coeffs = num_expression_coeffs + + expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx] + self.register_buffer( + 'expr_dirs', to_tensor(to_np(expr_dirs), dtype=dtype)) + + if create_expression: + if expression is None: + default_expression = torch.zeros( + [batch_size, self.num_expression_coeffs], dtype=dtype) + else: + default_expression = torch.tensor(expression, dtype=dtype) + expression_param = nn.Parameter(default_expression, + requires_grad=True) + self.register_parameter('expression', expression_param) + + def name(self) -> str: + return 'SMPL-X' + + @property + def num_expression_coeffs(self): + return self._num_expression_coeffs + + def create_mean_pose(self, data_struct, flat_hand_mean=False): + # Create the array for the mean pose. If flat_hand is false, then use + # the mean that is given by the data, rather than the flat open hand + global_orient_mean = torch.zeros([3], dtype=self.dtype) + body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3], + dtype=self.dtype) + jaw_pose_mean = torch.zeros([3], dtype=self.dtype) + leye_pose_mean = torch.zeros([3], dtype=self.dtype) + reye_pose_mean = torch.zeros([3], dtype=self.dtype) + + pose_mean = np.concatenate([global_orient_mean, body_pose_mean, + jaw_pose_mean, + leye_pose_mean, reye_pose_mean, + self.left_hand_mean, self.right_hand_mean], + axis=0) + + return pose_mean + + def extra_repr(self): + msg = super(SMPLX, self).extra_repr() + msg = [ + msg, + f'Number of Expression Coefficients: {self.num_expression_coeffs}' + ] + return '\n'.join(msg) + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + left_hand_pose: Optional[Tensor] = None, + right_hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + expression: Optional[Tensor] = None, + jaw_pose: Optional[Tensor] = None, + leye_pose: Optional[Tensor] = None, + reye_pose: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + pose2rot: bool = True, + return_shaped: bool = True, + **kwargs + ) -> SMPLXOutput: + ''' + Forward pass for the SMPLX model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3 + If given, ignore the member variable and use it as the global + rotation of the body. Useful if someone wishes to predicts this + with an external model. (default=None) + betas: torch.tensor, optional, shape BxN_b + If given, ignore the member variable `betas` and use it + instead. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + expression: torch.tensor, optional, shape BxN_e + If given, ignore the member variable `expression` and use it + instead. For example, it can used if expression parameters + `expression` are predicted from some external model. + body_pose: torch.tensor, optional, shape Bx(J*3) + If given, ignore the member variable `body_pose` and use it + instead. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + axis-angle format. (default=None) + left_hand_pose: torch.tensor, optional, shape BxP + If given, ignore the member variable `left_hand_pose` and + use this instead. It should either contain PCA coefficients or + joint rotations in axis-angle format. 
+ right_hand_pose: torch.tensor, optional, shape BxP + If given, ignore the member variable `right_hand_pose` and + use this instead. It should either contain PCA coefficients or + joint rotations in axis-angle format. + jaw_pose: torch.tensor, optional, shape Bx3 + If given, ignore the member variable `jaw_pose` and + use this instead. It should either joint rotations in + axis-angle format. + transl: torch.tensor, optional, shape Bx3 + If given, ignore the member variable `transl` and use it + instead. For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. (default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + output: ModelOutput + A named tuple of type `ModelOutput` + ''' + + # If no shape and pose parameters are passed along, then use the + # ones from the module + global_orient = (global_orient if global_orient is not None else + self.global_orient) + body_pose = body_pose if body_pose is not None else self.body_pose + betas = betas if betas is not None else self.betas + + left_hand_pose = (left_hand_pose if left_hand_pose is not None else + self.left_hand_pose) + right_hand_pose = (right_hand_pose if right_hand_pose is not None else + self.right_hand_pose) + jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose + leye_pose = leye_pose if leye_pose is not None else self.leye_pose + reye_pose = reye_pose if reye_pose is not None else self.reye_pose + expression = expression if expression is not None else self.expression + + apply_trans = transl is not None or hasattr(self, 'transl') + if transl is None: + if hasattr(self, 'transl'): + transl = self.transl + + if self.use_pca: + left_hand_pose = torch.einsum( + 'bi,ij->bj', [left_hand_pose, self.left_hand_components]) + right_hand_pose = torch.einsum( + 'bi,ij->bj', [right_hand_pose, self.right_hand_components]) + + full_pose = torch.cat([global_orient.reshape(-1, 1, 3), + body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3), + jaw_pose.reshape(-1, 1, 3), + leye_pose.reshape(-1, 1, 3), + reye_pose.reshape(-1, 1, 3), + left_hand_pose.reshape(-1, 15, 3), + right_hand_pose.reshape(-1, 15, 3)], + dim=1).reshape(-1, 165) + + # Add the mean pose of the model. 
Does not affect the body, only the + # hands when flat_hand_mean == False + full_pose += self.pose_mean + + batch_size = max(betas.shape[0], global_orient.shape[0], + body_pose.shape[0]) + # Concatenate the shape and expression coefficients + scale = int(batch_size / betas.shape[0]) + if scale > 1: + betas = betas.expand(scale, -1) + + shape_components = torch.cat([betas, expression], dim=-1) + + shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) + + vertices, joints = lbs(shape_components, full_pose, self.v_template, + shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, pose2rot=pose2rot, + ) + + lmk_faces_idx = self.lmk_faces_idx.unsqueeze( + dim=0).expand(batch_size, -1).contiguous() + lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat( + self.batch_size, 1, 1) + if self.use_face_contour: + lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( + vertices, full_pose, self.dynamic_lmk_faces_idx, + self.dynamic_lmk_bary_coords, + self.neck_kin_chain, + pose2rot=True, + ) + dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords + + lmk_faces_idx = torch.cat([lmk_faces_idx, + dyn_lmk_faces_idx], 1) + lmk_bary_coords = torch.cat( + [lmk_bary_coords.expand(batch_size, -1, -1), + dyn_lmk_bary_coords], 1) + + landmarks = vertices2landmarks(vertices, self.faces_tensor, + lmk_faces_idx, + lmk_bary_coords) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + # Add the landmarks to the joints + joints = torch.cat([joints, landmarks], dim=1) + # Map the joints to the current dataset + + if self.joint_mapper is not None: + joints = self.joint_mapper(joints=joints, vertices=vertices) + + if apply_trans: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + v_shaped = None + if return_shaped: + v_shaped = self.v_template + blend_shapes(betas, self.shapedirs) + else: + v_shaped = Tensor(0) + output = SMPLXOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + expression=expression, + global_orient=global_orient, + body_pose=body_pose, + left_hand_pose=left_hand_pose, + right_hand_pose=right_hand_pose, + jaw_pose=jaw_pose, + v_shaped=v_shaped, + full_pose=full_pose if return_full_pose else None) + return output + + +class SMPLXLayer(SMPLX): + def __init__( + self, + *args, + **kwargs + ) -> None: + # Just create a SMPLX module without any member variables + super(SMPLXLayer, self).__init__( + create_global_orient=False, + create_body_pose=False, + create_left_hand_pose=False, + create_right_hand_pose=False, + create_jaw_pose=False, + create_leye_pose=False, + create_reye_pose=False, + create_betas=False, + create_expression=False, + create_transl=False, + *args, **kwargs, + ) + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + left_hand_pose: Optional[Tensor] = None, + right_hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + expression: Optional[Tensor] = None, + jaw_pose: Optional[Tensor] = None, + leye_pose: Optional[Tensor] = None, + reye_pose: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = True, + **kwargs + ) -> TensorOutput: + ''' + Forward pass for the SMPLX model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3x3 + If given, ignore the member variable and use it as the global + rotation of the body. 
Useful if someone wishes to predicts this + with an external model. It is expected to be in rotation matrix + format. (default=None) + betas: torch.tensor, optional, shape BxN_b + If given, ignore the member variable `betas` and use it + instead. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + expression: torch.tensor, optional, shape BxN_e + Expression coefficients. + For example, it can used if expression parameters + `expression` are predicted from some external model. + body_pose: torch.tensor, optional, shape BxJx3x3 + If given, ignore the member variable `body_pose` and use it + instead. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + left_hand_pose: torch.tensor, optional, shape Bx15x3x3 + If given, contains the pose of the left hand. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + right_hand_pose: torch.tensor, optional, shape Bx15x3x3 + If given, contains the pose of the right hand. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + jaw_pose: torch.tensor, optional, shape Bx3x3 + Jaw pose. It should either joint rotations in + rotation matrix format. + transl: torch.tensor, optional, shape Bx3 + Translation vector of the body. + For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. (default=True) + return_full_pose: bool, optional + Returns the full pose vector (default=False) + Returns + ------- + output: ModelOutput + A data class that contains the posed vertices and joints + ''' + device, dtype = self.shapedirs.device, self.shapedirs.dtype + + model_vars = [betas, global_orient, body_pose, transl, + expression, left_hand_pose, right_hand_pose, jaw_pose] + batch_size = 1 + for var in model_vars: + if var is None: + continue + batch_size = max(batch_size, len(var)) + + if global_orient is None: + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if body_pose is None: + body_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand( + batch_size, self.NUM_BODY_JOINTS, -1, -1).contiguous() + if left_hand_pose is None: + left_hand_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + if right_hand_pose is None: + right_hand_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + if jaw_pose is None: + jaw_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if leye_pose is None: + leye_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if reye_pose is None: + reye_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if expression is None: + expression = torch.zeros([batch_size, self.num_expression_coeffs], + dtype=dtype, device=device) + if betas is None: + betas = torch.zeros([batch_size, self.num_betas], + dtype=dtype, device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + # Concatenate all pose vectors + 
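+        # Layout of the concatenated pose, per batch element: 55 rotation
+        # matrices in the model's kinematic order (1 global orientation,
+        # 21 body joints, jaw, left eye, right eye, 15 joints per hand).
+        # Since every input here is already a 3x3 rotation matrix, lbs()
+        # below is called with pose2rot=False.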
full_pose = torch.cat( + [global_orient.reshape(-1, 1, 3, 3), + body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3), + jaw_pose.reshape(-1, 1, 3, 3), + leye_pose.reshape(-1, 1, 3, 3), + reye_pose.reshape(-1, 1, 3, 3), + left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3), + right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)], + dim=1) + shape_components = torch.cat([betas, expression], dim=-1) + + shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) + + vertices, joints = lbs(shape_components, full_pose, self.v_template, + shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, + pose2rot=False, + ) + + lmk_faces_idx = self.lmk_faces_idx.unsqueeze( + dim=0).expand(batch_size, -1).contiguous() + lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat( + batch_size, 1, 1) + if self.use_face_contour: + lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( + vertices, full_pose, + self.dynamic_lmk_faces_idx, + self.dynamic_lmk_bary_coords, + self.neck_kin_chain, + pose2rot=False, + ) + dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords + + lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) + lmk_bary_coords = torch.cat( + [lmk_bary_coords.expand(batch_size, -1, -1), + dyn_lmk_bary_coords], 1) + + landmarks = vertices2landmarks(vertices, self.faces_tensor, + lmk_faces_idx, + lmk_bary_coords) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + # Add the landmarks to the joints + joints = torch.cat([joints, landmarks], dim=1) + # Map the joints to the current dataset + + if self.joint_mapper is not None: + joints = self.joint_mapper(joints=joints, vertices=vertices) + + if transl is not None: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = TensorOutput(vertices=vertices if return_verts else Tensor(0), + joints=joints, + betas=betas, + expression=expression, + global_orient=global_orient, + body_pose=body_pose, + left_hand_pose=left_hand_pose, + right_hand_pose=right_hand_pose, + jaw_pose=jaw_pose, + transl=transl if transl != None else Tensor(0), + full_pose=full_pose if return_full_pose else Tensor(0)) + + return output + + +class MANO(SMPL): + # The hand joints are replaced by MANO + NUM_BODY_JOINTS = 1 + NUM_HAND_JOINTS = 15 + NUM_JOINTS = NUM_BODY_JOINTS + NUM_HAND_JOINTS + + def __init__( + self, + model_path: str, + is_rhand: bool = True, + data_struct: Optional[Struct] = None, + create_hand_pose: bool = True, + hand_pose: Optional[Tensor] = None, + use_pca: bool = True, + num_pca_comps: int = 6, + flat_hand_mean: bool = False, + batch_size: int = 1, + dtype=torch.float32, + vertex_ids=None, + use_compressed: bool = True, + ext: str = 'pkl', + **kwargs + ) -> None: + ''' MANO model constructor + + Parameters + ---------- + model_path: str + The path to the folder or to the file where the model + parameters are stored + data_struct: Strct + A struct object. If given, then the parameters of the model are + read from the object. Otherwise, the model tries to read the + parameters from the given `model_path`. (default = None) + create_hand_pose: bool, optional + Flag for creating a member variable for the pose of the right + hand. (default = True) + hand_pose: torch.tensor, optional, BxP + The default value for the right hand pose member variable. + (default = None) + num_pca_comps: int, optional + The number of PCA components to use for each hand. 
+ (default = 6) + flat_hand_mean: bool, optional + If False, then the pose of the hand is initialized to False. + batch_size: int, optional + The batch size used for creating the member variables + dtype: torch.dtype, optional + The data type for the created variables + vertex_ids: dict, optional + A dictionary containing the indices of the extra vertices that + will be selected + ''' + + self.num_pca_comps = num_pca_comps + self.is_rhand = is_rhand + # If no data structure is passed, then load the data from the given + # model folder + if data_struct is None: + # Load the model + if osp.isdir(model_path): + model_fn = 'MANO_{}.{ext}'.format( + 'RIGHT' if is_rhand else 'LEFT', ext=ext) + mano_path = os.path.join(model_path, model_fn) + else: + mano_path = model_path + self.is_rhand = True if 'RIGHT' in os.path.basename( + model_path) else False + assert osp.exists(mano_path), 'Path {} does not exist!'.format( + mano_path) + + if ext == 'pkl': + with open(mano_path, 'rb') as mano_file: + model_data = pickle.load(mano_file, encoding='latin1') + elif ext == 'npz': + model_data = np.load(mano_path, allow_pickle=True) + else: + raise ValueError('Unknown extension: {}'.format(ext)) + data_struct = Struct(**model_data) + + if vertex_ids is None: + vertex_ids = VERTEX_IDS['smplh'] + + super(MANO, self).__init__( + model_path=model_path, data_struct=data_struct, + batch_size=batch_size, vertex_ids=vertex_ids, + use_compressed=use_compressed, dtype=dtype, ext=ext, **kwargs) + + # add only MANO tips to the extra joints + self.vertex_joint_selector.extra_joints_idxs = to_tensor( + list(VERTEX_IDS['mano'].values()), dtype=torch.long) + + self.use_pca = use_pca + self.num_pca_comps = num_pca_comps + if self.num_pca_comps == 45: + self.use_pca = False + self.flat_hand_mean = flat_hand_mean + + hand_components = data_struct.hands_components[:num_pca_comps] + + self.np_hand_components = hand_components + + if self.use_pca: + self.register_buffer( + 'hand_components', + torch.tensor(hand_components, dtype=dtype)) + + if self.flat_hand_mean: + hand_mean = np.zeros_like(data_struct.hands_mean) + else: + hand_mean = data_struct.hands_mean + + self.register_buffer('hand_mean', + to_tensor(hand_mean, dtype=self.dtype)) + + # Create the buffers for the pose of the left hand + hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS + if create_hand_pose: + if hand_pose is None: + default_hand_pose = torch.zeros([batch_size, hand_pose_dim], + dtype=dtype) + else: + default_hand_pose = torch.tensor(hand_pose, dtype=dtype) + + hand_pose_param = nn.Parameter(default_hand_pose, + requires_grad=True) + self.register_parameter('hand_pose', + hand_pose_param) + + # Create the buffer for the mean pose. + pose_mean = self.create_mean_pose( + data_struct, flat_hand_mean=flat_hand_mean) + pose_mean_tensor = pose_mean.clone().to(dtype) + # pose_mean_tensor = torch.tensor(pose_mean, dtype=dtype) + self.register_buffer('pose_mean', pose_mean_tensor) + + def name(self) -> str: + return 'MANO' + + def create_mean_pose(self, data_struct, flat_hand_mean=False): + # Create the array for the mean pose. 
If flat_hand is false, then use + # the mean that is given by the data, rather than the flat open hand + global_orient_mean = torch.zeros([3], dtype=self.dtype) + pose_mean = torch.cat([global_orient_mean, self.hand_mean], dim=0) + return pose_mean + + def extra_repr(self): + msg = [super(MANO, self).extra_repr()] + if self.use_pca: + msg.append(f'Number of PCA components: {self.num_pca_comps}') + msg.append(f'Flat hand mean: {self.flat_hand_mean}') + return '\n'.join(msg) + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + **kwargs + ) -> MANOOutput: + ''' Forward pass for the MANO model + ''' + # If no shape and pose parameters are passed along, then use the + # ones from the module + global_orient = (global_orient if global_orient is not None else + self.global_orient) + betas = betas if betas is not None else self.betas + hand_pose = (hand_pose if hand_pose is not None else + self.hand_pose) + + apply_trans = transl is not None or hasattr(self, 'transl') + if transl is None: + if hasattr(self, 'transl'): + transl = self.transl + + if self.use_pca: + hand_pose = torch.einsum( + 'bi,ij->bj', [hand_pose, self.hand_components]) + + full_pose = torch.cat([global_orient, hand_pose], dim=1) + full_pose += self.pose_mean + + vertices, joints = lbs(betas, full_pose, self.v_template, + self.shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, pose2rot=True, + ) + + # # Add pre-selected extra joints that might be needed + # joints = self.vertex_joint_selector(vertices, joints) + + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if apply_trans: + joints = joints + transl.unsqueeze(dim=1) + vertices = vertices + transl.unsqueeze(dim=1) + + output = MANOOutput(vertices=vertices if return_verts else None, + joints=joints if return_verts else None, + betas=betas, + global_orient=global_orient, + hand_pose=hand_pose, + full_pose=full_pose if return_full_pose else None) + + return output + + +class MANOLayer(MANO): + def __init__(self, *args, **kwargs) -> None: + ''' MANO as a layer model constructor + ''' + super(MANOLayer, self).__init__( + create_global_orient=False, + create_hand_pose=False, + create_betas=False, + create_transl=False, + *args, **kwargs) + + def name(self) -> str: + return 'MANO' + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + **kwargs + ) -> MANOOutput: + ''' Forward pass for the MANO model + ''' + device, dtype = self.shapedirs.device, self.shapedirs.dtype + if global_orient is None: + batch_size = 1 + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + else: + batch_size = global_orient.shape[0] + if hand_pose is None: + hand_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + if betas is None: + betas = torch.zeros( + [batch_size, self.num_betas], dtype=dtype, device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + full_pose = torch.cat([global_orient, hand_pose], dim=1) + vertices, joints = lbs(betas, full_pose, self.v_template, + self.shapedirs, self.posedirs, + 
self.J_regressor, self.parents, + self.lbs_weights, pose2rot=False) + + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if transl is not None: + joints = joints + transl.unsqueeze(dim=1) + vertices = vertices + transl.unsqueeze(dim=1) + + output = MANOOutput( + vertices=vertices if return_verts else None, + joints=joints if return_verts else None, + betas=betas, + global_orient=global_orient, + hand_pose=hand_pose, + full_pose=full_pose if return_full_pose else None) + + return output + + +class FLAME(SMPL): + NUM_JOINTS = 5 + SHAPE_SPACE_DIM = 300 + EXPRESSION_SPACE_DIM = 100 + NECK_IDX = 0 + + def __init__( + self, + model_path: str, + data_struct=None, + num_expression_coeffs=10, + create_expression: bool = True, + expression: Optional[Tensor] = None, + create_neck_pose: bool = True, + neck_pose: Optional[Tensor] = None, + create_jaw_pose: bool = True, + jaw_pose: Optional[Tensor] = None, + create_leye_pose: bool = True, + leye_pose: Optional[Tensor] = None, + create_reye_pose=True, + reye_pose: Optional[Tensor] = None, + use_face_contour=False, + batch_size: int = 1, + gender: str = 'neutral', + dtype: torch.dtype = torch.float32, + ext='pkl', + **kwargs + ) -> None: + ''' FLAME model constructor + + Parameters + ---------- + model_path: str + The path to the folder or to the file where the model + parameters are stored + num_expression_coeffs: int, optional + Number of expression components to use + (default = 10). + create_expression: bool, optional + Flag for creating a member variable for the expression space + (default = True). + expression: torch.tensor, optional, Bx10 + The default value for the expression member variable. + (default = None) + create_neck_pose: bool, optional + Flag for creating a member variable for the neck pose. + (default = False) + neck_pose: torch.tensor, optional, Bx3 + The default value for the neck pose variable. + (default = None) + create_jaw_pose: bool, optional + Flag for creating a member variable for the jaw pose. + (default = False) + jaw_pose: torch.tensor, optional, Bx3 + The default value for the jaw pose variable. + (default = None) + create_leye_pose: bool, optional + Flag for creating a member variable for the left eye pose. + (default = False) + leye_pose: torch.tensor, optional, Bx10 + The default value for the left eye pose variable. + (default = None) + create_reye_pose: bool, optional + Flag for creating a member variable for the right eye pose. + (default = False) + reye_pose: torch.tensor, optional, Bx10 + The default value for the right eye pose variable. 
+ (default = None) + use_face_contour: bool, optional + Whether to compute the keypoints that form the facial contour + batch_size: int, optional + The batch size used for creating the member variables + gender: str, optional + Which gender to load + dtype: torch.dtype + The data type for the created variables + ''' + model_fn = f'FLAME_{gender.upper()}.{ext}' + flame_path = os.path.join(model_path, model_fn) + assert osp.exists(flame_path), 'Path {} does not exist!'.format( + flame_path) + if ext == 'npz': + file_data = np.load(flame_path, allow_pickle=True) + elif ext == 'pkl': + with open(flame_path, 'rb') as smpl_file: + file_data = pickle.load(smpl_file, encoding='latin1') + else: + raise ValueError('Unknown extension: {}'.format(ext)) + data_struct = Struct(**file_data) + + super(FLAME, self).__init__( + model_path=model_path, + data_struct=data_struct, + dtype=dtype, + batch_size=batch_size, + gender=gender, + ext=ext, + **kwargs) + + self.use_face_contour = use_face_contour + + self.vertex_joint_selector.extra_joints_idxs = to_tensor( + [], dtype=torch.long) + + if create_neck_pose: + if neck_pose is None: + default_neck_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_neck_pose = torch.tensor(neck_pose, dtype=dtype) + neck_pose_param = nn.Parameter( + default_neck_pose, requires_grad=True) + self.register_parameter('neck_pose', neck_pose_param) + + if create_jaw_pose: + if jaw_pose is None: + default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype) + jaw_pose_param = nn.Parameter(default_jaw_pose, + requires_grad=True) + self.register_parameter('jaw_pose', jaw_pose_param) + + if create_leye_pose: + if leye_pose is None: + default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_leye_pose = torch.tensor(leye_pose, dtype=dtype) + leye_pose_param = nn.Parameter(default_leye_pose, + requires_grad=True) + self.register_parameter('leye_pose', leye_pose_param) + + if create_reye_pose: + if reye_pose is None: + default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_reye_pose = torch.tensor(reye_pose, dtype=dtype) + reye_pose_param = nn.Parameter(default_reye_pose, + requires_grad=True) + self.register_parameter('reye_pose', reye_pose_param) + + shapedirs = data_struct.shapedirs + if len(shapedirs.shape) < 3: + shapedirs = shapedirs[:, :, None] + if (shapedirs.shape[-1] < self.SHAPE_SPACE_DIM + + self.EXPRESSION_SPACE_DIM): + # print(f'WARNING: You are using a {self.name()} model, with only' + # ' 10 shape and 10 expression coefficients.') + expr_start_idx = 10 + expr_end_idx = 20 + num_expression_coeffs = min(num_expression_coeffs, 10) + else: + expr_start_idx = self.SHAPE_SPACE_DIM + expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs + num_expression_coeffs = min( + num_expression_coeffs, self.EXPRESSION_SPACE_DIM) + + self._num_expression_coeffs = num_expression_coeffs + + expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx] + self.register_buffer( + 'expr_dirs', to_tensor(to_np(expr_dirs), dtype=dtype)) + + if create_expression: + if expression is None: + default_expression = torch.zeros( + [batch_size, self.num_expression_coeffs], dtype=dtype) + else: + default_expression = torch.tensor(expression, dtype=dtype) + expression_param = nn.Parameter(default_expression, + requires_grad=True) + self.register_parameter('expression', expression_param) + + # The pickle file that contains the barycentric coordinates for + # regressing the landmarks + 
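+        # Each entry of this static embedding pairs a face (triangle) index on
+        # the template mesh with barycentric weights; vertices2landmarks() uses
+        # them in forward() to regress the fixed facial landmarks from the
+        # posed vertices. When use_face_contour is True, a dynamic embedding is
+        # additionally loaded below for the pose-dependent contour landmarks,
+        # which are selected via the neck kinematic chain.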
landmark_bcoord_filename = osp.join( + model_path, 'flame_static_embedding.pkl') + + with open(landmark_bcoord_filename, 'rb') as fp: + landmarks_data = pickle.load(fp, encoding='latin1') + + lmk_faces_idx = landmarks_data['lmk_face_idx'].astype(np.int64) + self.register_buffer('lmk_faces_idx', + torch.tensor(lmk_faces_idx, dtype=torch.long)) + lmk_bary_coords = landmarks_data['lmk_b_coords'] + self.register_buffer('lmk_bary_coords', + torch.tensor(lmk_bary_coords, dtype=dtype)) + if self.use_face_contour: + face_contour_path = os.path.join( + model_path, 'flame_dynamic_embedding.npy') + contour_embeddings = np.load(face_contour_path, + allow_pickle=True, + encoding='latin1')[()] + + dynamic_lmk_faces_idx = np.array( + contour_embeddings['lmk_face_idx'], dtype=np.int64) + dynamic_lmk_faces_idx = torch.tensor( + dynamic_lmk_faces_idx, + dtype=torch.long) + self.register_buffer('dynamic_lmk_faces_idx', + dynamic_lmk_faces_idx) + + dynamic_lmk_b_coords = torch.tensor( + contour_embeddings['lmk_b_coords'], dtype=dtype) + self.register_buffer( + 'dynamic_lmk_bary_coords', dynamic_lmk_b_coords) + + neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents) + self.register_buffer( + 'neck_kin_chain', + torch.tensor(neck_kin_chain, dtype=torch.long)) + + @property + def num_expression_coeffs(self): + return self._num_expression_coeffs + + def name(self) -> str: + return 'FLAME' + + def extra_repr(self): + msg = [ + super(FLAME, self).extra_repr(), + f'Number of Expression Coefficients: {self.num_expression_coeffs}', + f'Use face contour: {self.use_face_contour}', + ] + return '\n'.join(msg) + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + neck_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + expression: Optional[Tensor] = None, + jaw_pose: Optional[Tensor] = None, + leye_pose: Optional[Tensor] = None, + reye_pose: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs + ) -> FLAMEOutput: + ''' + Forward pass for the SMPLX model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3 + If given, ignore the member variable and use it as the global + rotation of the body. Useful if someone wishes to predicts this + with an external model. (default=None) + betas: torch.tensor, optional, shape Bx10 + If given, ignore the member variable `betas` and use it + instead. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + expression: torch.tensor, optional, shape Bx10 + If given, ignore the member variable `expression` and use it + instead. For example, it can used if expression parameters + `expression` are predicted from some external model. + jaw_pose: torch.tensor, optional, shape Bx3 + If given, ignore the member variable `jaw_pose` and + use this instead. It should either joint rotations in + axis-angle format. + jaw_pose: torch.tensor, optional, shape Bx3 + If given, ignore the member variable `jaw_pose` and + use this instead. It should either joint rotations in + axis-angle format. + transl: torch.tensor, optional, shape Bx3 + If given, ignore the member variable `transl` and use it + instead. For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. 
(default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + output: ModelOutput + A named tuple of type `ModelOutput` + ''' + + # If no shape and pose parameters are passed along, then use the + # ones from the module + global_orient = (global_orient if global_orient is not None else + self.global_orient) + jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose + neck_pose = neck_pose if neck_pose is not None else self.neck_pose + + leye_pose = leye_pose if leye_pose is not None else self.leye_pose + reye_pose = reye_pose if reye_pose is not None else self.reye_pose + + betas = betas if betas is not None else self.betas + expression = expression if expression is not None else self.expression + + apply_trans = transl is not None or hasattr(self, 'transl') + if transl is None: + if hasattr(self, 'transl'): + transl = self.transl + + full_pose = torch.cat( + [global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1) + + batch_size = max(betas.shape[0], global_orient.shape[0], + jaw_pose.shape[0]) + # Concatenate the shape and expression coefficients + scale = int(batch_size / betas.shape[0]) + if scale > 1: + betas = betas.expand(scale, -1) + shape_components = torch.cat([betas, expression], dim=-1) + shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) + + vertices, joints = lbs(shape_components, full_pose, self.v_template, + shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, pose2rot=pose2rot, + ) + + lmk_faces_idx = self.lmk_faces_idx.unsqueeze( + dim=0).expand(batch_size, -1).contiguous() + lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat( + batch_size, 1, 1) + if self.use_face_contour: + lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( + vertices, full_pose, self.dynamic_lmk_faces_idx, + self.dynamic_lmk_bary_coords, + self.neck_kin_chain, + pose2rot=True, + ) + dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords + lmk_faces_idx = torch.cat([lmk_faces_idx, + dyn_lmk_faces_idx], 1) + lmk_bary_coords = torch.cat( + [lmk_bary_coords.expand(batch_size, -1, -1), + dyn_lmk_bary_coords], 1) + + landmarks = vertices2landmarks(vertices, self.faces_tensor, + lmk_faces_idx, + lmk_bary_coords) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + # Add the landmarks to the joints + joints = torch.cat([joints, landmarks], dim=1) + + # Map the joints to the current dataset + if self.joint_mapper is not None: + joints = self.joint_mapper(joints=joints, vertices=vertices) + + if apply_trans: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = FLAMEOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + expression=expression, + global_orient=global_orient, + neck_pose=neck_pose, + jaw_pose=jaw_pose, + full_pose=full_pose if return_full_pose else None) + return output + + +class FLAMELayer(FLAME): + def __init__(self, *args, **kwargs) -> None: + ''' FLAME as a layer model constructor ''' + super(FLAMELayer, self).__init__( + create_betas=False, + create_expression=False, + create_global_orient=False, + create_neck_pose=False, + create_jaw_pose=False, + create_leye_pose=False, + create_reye_pose=False, + *args, + **kwargs) + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + neck_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + expression: Optional[Tensor] = None, + 
jaw_pose: Optional[Tensor] = None, + leye_pose: Optional[Tensor] = None, + reye_pose: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs + ) -> FLAMEOutput: + ''' + Forward pass for the SMPLX model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3x3 + Global rotation of the body. Useful if someone wishes to + predicts this with an external model. It is expected to be in + rotation matrix format. (default=None) + betas: torch.tensor, optional, shape BxN_b + Shape parameters. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + expression: torch.tensor, optional, shape BxN_e + If given, ignore the member variable `expression` and use it + instead. For example, it can used if expression parameters + `expression` are predicted from some external model. + jaw_pose: torch.tensor, optional, shape Bx3x3 + Jaw pose. It should either joint rotations in + rotation matrix format. + transl: torch.tensor, optional, shape Bx3 + Translation vector of the body. + For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. (default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + output: ModelOutput + A named tuple of type `ModelOutput` + ''' + device, dtype = self.shapedirs.device, self.shapedirs.dtype + if global_orient is None: + batch_size = 1 + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + else: + batch_size = global_orient.shape[0] + if neck_pose is None: + neck_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 1, -1, -1).contiguous() + if jaw_pose is None: + jaw_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if leye_pose is None: + leye_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if reye_pose is None: + reye_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if betas is None: + betas = torch.zeros([batch_size, self.num_betas], + dtype=dtype, device=device) + if expression is None: + expression = torch.zeros([batch_size, self.num_expression_coeffs], + dtype=dtype, device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + full_pose = torch.cat( + [global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1) + + shape_components = torch.cat([betas, expression], dim=-1) + shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) + + vertices, joints = lbs(shape_components, full_pose, self.v_template, + shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, pose2rot=False, + ) + + lmk_faces_idx = self.lmk_faces_idx.unsqueeze( + dim=0).expand(batch_size, -1).contiguous() + lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat( + batch_size, 1, 1) + if self.use_face_contour: + lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( + vertices, full_pose, self.dynamic_lmk_faces_idx, + self.dynamic_lmk_bary_coords, + self.neck_kin_chain, + pose2rot=False, + ) + dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords + lmk_faces_idx = torch.cat([lmk_faces_idx, + 
dyn_lmk_faces_idx], 1) + lmk_bary_coords = torch.cat( + [lmk_bary_coords.expand(batch_size, -1, -1), + dyn_lmk_bary_coords], 1) + + landmarks = vertices2landmarks(vertices, self.faces_tensor, + lmk_faces_idx, + lmk_bary_coords) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + # Add the landmarks to the joints + joints = torch.cat([joints, landmarks], dim=1) + + # Map the joints to the current dataset + if self.joint_mapper is not None: + joints = self.joint_mapper(joints=joints, vertices=vertices) + + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = FLAMEOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + expression=expression, + global_orient=global_orient, + neck_pose=neck_pose, + jaw_pose=jaw_pose, + full_pose=full_pose if return_full_pose else None) + return output + + +def build_layer( + model_path: str, + model_type: str = 'smpl', + **kwargs +) -> Union[SMPLLayer, SMPLHLayer, SMPLXLayer, MANOLayer, FLAMELayer]: + ''' Method for creating a model from a path and a model type + + Parameters + ---------- + model_path: str + Either the path to the model you wish to load or a folder, + where each subfolder contains the differents types, i.e.: + model_path: + | + |-- smpl + |-- SMPL_FEMALE + |-- SMPL_NEUTRAL + |-- SMPL_MALE + |-- smplh + |-- SMPLH_FEMALE + |-- SMPLH_MALE + |-- smplx + |-- SMPLX_FEMALE + |-- SMPLX_NEUTRAL + |-- SMPLX_MALE + |-- mano + |-- MANO RIGHT + |-- MANO LEFT + |-- flame + |-- FLAME_FEMALE + |-- FLAME_MALE + |-- FLAME_NEUTRAL + + model_type: str, optional + When model_path is a folder, then this parameter specifies the + type of model to be loaded + **kwargs: dict + Keyword arguments + + Returns + ------- + body_model: nn.Module + The PyTorch module that implements the corresponding body model + Raises + ------ + ValueError: In case the model type is not one of SMPL, SMPLH, + SMPLX, MANO or FLAME + ''' + + if osp.isdir(model_path): + model_path = os.path.join(model_path, model_type) + else: + model_type = osp.basename(model_path).split('_')[0].lower() + + if model_type.lower() == 'smpl': + return SMPLLayer(model_path, **kwargs) + elif model_type.lower() == 'smplh': + return SMPLHLayer(model_path, **kwargs) + elif model_type.lower() == 'smplx': + return SMPLXLayer(model_path, **kwargs) + elif 'mano' in model_type.lower(): + return MANOLayer(model_path, **kwargs) + elif 'flame' in model_type.lower(): + return FLAMELayer(model_path, **kwargs) + else: + raise ValueError(f'Unknown model type {model_type}, exiting!') + + +def create( + model_path: str, + model_type: str = 'smpl', + **kwargs +) -> Union[SMPL, SMPLH, SMPLX, MANO, FLAME]: + ''' Method for creating a model from a path and a model type + + Parameters + ---------- + model_path: str + Either the path to the model you wish to load or a folder, + where each subfolder contains the differents types, i.e.: + model_path: + | + |-- smpl + |-- SMPL_FEMALE + |-- SMPL_NEUTRAL + |-- SMPL_MALE + |-- smplh + |-- SMPLH_FEMALE + |-- SMPLH_MALE + |-- smplx + |-- SMPLX_FEMALE + |-- SMPLX_NEUTRAL + |-- SMPLX_MALE + |-- mano + |-- MANO RIGHT + |-- MANO LEFT + + model_type: str, optional + When model_path is a folder, then this parameter specifies the + type of model to be loaded + **kwargs: dict + Keyword arguments + + Returns + ------- + body_model: nn.Module + The PyTorch module that implements the corresponding body model + Raises + ------ + ValueError: In case the model type is not one of SMPL, SMPLH, + 
SMPLX, MANO or FLAME + ''' + + # If it's a folder, assume + if osp.isdir(model_path): + model_path = os.path.join(model_path, model_type) + else: + model_type = osp.basename(model_path).split('_')[0].lower() + + if model_type.lower() == 'smpl': + return SMPL(model_path, **kwargs) + elif model_type.lower() == 'smplh': + return SMPLH(model_path, **kwargs) + elif model_type.lower() == 'smplx': + return SMPLX(model_path, **kwargs) + elif 'mano' in model_type.lower(): + return MANO(model_path, **kwargs) + elif 'flame' in model_type.lower(): + return FLAME(model_path, **kwargs) + else: + raise ValueError(f'Unknown model type {model_type}, exiting!') diff --git a/SMPLX/smplx/joint_names.py b/SMPLX/smplx/joint_names.py new file mode 100644 index 0000000000000000000000000000000000000000..fcdb2e131850592c1c5ff7bb9eaf36841d21bcd6 --- /dev/null +++ b/SMPLX/smplx/joint_names.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import numpy as np + +JOINT_NAMES = [ + "pelvis", + "left_hip", + "right_hip", + "spine1", + "left_knee", + "right_knee", + "spine2", + "left_ankle", + "right_ankle", + "spine3", + "left_foot", + "right_foot", + "neck", + "left_collar", + "right_collar", + "head", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "jaw", + "left_eye_smplhf", + "right_eye_smplhf", + "left_index1", + "left_index2", + "left_index3", + "left_middle1", + "left_middle2", + "left_middle3", + "left_pinky1", + "left_pinky2", + "left_pinky3", + "left_ring1", + "left_ring2", + "left_ring3", + "left_thumb1", + "left_thumb2", + "left_thumb3", + "right_index1", + "right_index2", + "right_index3", + "right_middle1", + "right_middle2", + "right_middle3", + "right_pinky1", + "right_pinky2", + "right_pinky3", + "right_ring1", + "right_ring2", + "right_ring3", + "right_thumb1", + "right_thumb2", + "right_thumb3", + "nose", + "right_eye", + "left_eye", + "right_ear", + "left_ear", + "left_big_toe", + "left_small_toe", + "left_heel", + "right_big_toe", + "right_small_toe", + "right_heel", + "left_thumb", + "left_index", + "left_middle", + "left_ring", + "left_pinky", + "right_thumb", + "right_index", + "right_middle", + "right_ring", + "right_pinky", + "right_eye_brow1", + "right_eye_brow2", + "right_eye_brow3", + "right_eye_brow4", + "right_eye_brow5", + "left_eye_brow5", + "left_eye_brow4", + "left_eye_brow3", + "left_eye_brow2", + "left_eye_brow1", + "nose1", + "nose2", + "nose3", + "nose4", + "right_nose_2", + "right_nose_1", + "nose_middle", + "left_nose_1", + "left_nose_2", + "right_eye1", + "right_eye2", + "right_eye3", + "right_eye4", + "right_eye5", + "right_eye6", + "left_eye4", + "left_eye3", + "left_eye2", + "left_eye1", + "left_eye6", + "left_eye5", + "right_mouth_1", + "right_mouth_2", + "right_mouth_3", + "mouth_top", + "left_mouth_3", + "left_mouth_2", + "left_mouth_1", + "left_mouth_5", # 
59 in OpenPose output + "left_mouth_4", # 58 in OpenPose output + "mouth_bottom", + "right_mouth_4", + "right_mouth_5", + "right_lip_1", + "right_lip_2", + "lip_top", + "left_lip_2", + "left_lip_1", + "left_lip_3", + "lip_bottom", + "right_lip_3", + # Face contour + "right_contour_1", + "right_contour_2", + "right_contour_3", + "right_contour_4", + "right_contour_5", + "right_contour_6", + "right_contour_7", + "right_contour_8", + "contour_middle", + "left_contour_8", + "left_contour_7", + "left_contour_6", + "left_contour_5", + "left_contour_4", + "left_contour_3", + "left_contour_2", + "left_contour_1", +] + + +SMPLH_JOINT_NAMES = [ + "pelvis", + "left_hip", + "right_hip", + "spine1", + "left_knee", + "right_knee", + "spine2", + "left_ankle", + "right_ankle", + "spine3", + "left_foot", + "right_foot", + "neck", + "left_collar", + "right_collar", + "head", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "left_index1", + "left_index2", + "left_index3", + "left_middle1", + "left_middle2", + "left_middle3", + "left_pinky1", + "left_pinky2", + "left_pinky3", + "left_ring1", + "left_ring2", + "left_ring3", + "left_thumb1", + "left_thumb2", + "left_thumb3", + "right_index1", + "right_index2", + "right_index3", + "right_middle1", + "right_middle2", + "right_middle3", + "right_pinky1", + "right_pinky2", + "right_pinky3", + "right_ring1", + "right_ring2", + "right_ring3", + "right_thumb1", + "right_thumb2", + "right_thumb3", + "nose", + "right_eye", + "left_eye", + "right_ear", + "left_ear", + "left_big_toe", + "left_small_toe", + "left_heel", + "right_big_toe", + "right_small_toe", + "right_heel", + "left_thumb", + "left_index", + "left_middle", + "left_ring", + "left_pinky", + "right_thumb", + "right_index", + "right_middle", + "right_ring", + "right_pinky", +] + +SMPL_JOINT_NAMES = [ + "pelvis", + "left_hip", + "right_hip", + "spine1", + "left_knee", + "right_knee", + "spine2", + "left_ankle", + "right_ankle", + "spine3", + "left_foot", + "right_foot", + "neck", + "left_collar", + "right_collar", + "head", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "left_hand", + "right_hand", +] + + +class Body: + """ + Class for storing a single body pose. + """ + + def __init__(self, joints, joint_names): + assert joints.ndim > 1 + assert joints.shape[0] == len(joint_names) + self.joints = {} + for i, j in enumerate(joint_names): + self.joints[j] = joints[i] + + @staticmethod + def from_smpl(joints): + """ + Create a Body object from SMPL joints. + """ + return Body(joints, SMPL_JOINT_NAMES) + + @staticmethod + def from_smplh(joints): + """ + Create a Body object from SMPLH joints. + """ + return Body(joints, SMPLH_JOINT_NAMES) + + def _as(self, joint_names): + """ + Return a Body object with the specified joint names. + """ + joint_list = [] + for j in joint_names: + if j not in self.joints: + joint_list.append(np.zeros_like(self.joints["spine1"])) + else: + joint_list.append(self.joints[j]) + return np.stack(joint_list, axis=0) + + def as_smpl(self): + """ + Convert the body to SMPL joints. + """ + return self._as(SMPL_JOINT_NAMES) + + def as_smplh(self): + """ + Convert the body to SMPLH joints. 
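+
+        Joints that exist in SMPL+H but not in the source layout are filled
+        with zeros. Illustrative example:
+
+        >>> import numpy as np
+        >>> body = Body.from_smpl(np.zeros((len(SMPL_JOINT_NAMES), 3)))
+        >>> body.as_smplh().shape
+        (73, 3)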
+ """ + return self._as(SMPLH_JOINT_NAMES) diff --git a/SMPLX/smplx/lbs.py b/SMPLX/smplx/lbs.py new file mode 100644 index 0000000000000000000000000000000000000000..cbff706620d151d9441ccf8bc6ccadfb7b1e7410 --- /dev/null +++ b/SMPLX/smplx/lbs.py @@ -0,0 +1,405 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +from typing import Tuple, List +import numpy as np + +import torch +import torch.nn.functional as F + +from .utils import rot_mat_to_euler, Tensor + + +def find_dynamic_lmk_idx_and_bcoords( + vertices: Tensor, + pose: Tensor, + dynamic_lmk_faces_idx: Tensor, + dynamic_lmk_b_coords: Tensor, + neck_kin_chain: List[int], + pose2rot: bool = True, +) -> Tuple[Tensor, Tensor]: + ''' Compute the faces, barycentric coordinates for the dynamic landmarks + + + To do so, we first compute the rotation of the neck around the y-axis + and then use a pre-computed look-up table to find the faces and the + barycentric coordinates that will be used. + + Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de) + for providing the original TensorFlow implementation and for the LUT. + + Parameters + ---------- + vertices: torch.tensor BxVx3, dtype = torch.float32 + The tensor of input vertices + pose: torch.tensor Bx(Jx3), dtype = torch.float32 + The current pose of the body model + dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long + The look-up table from neck rotation to faces + dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32 + The look-up table from neck rotation to barycentric coordinates + neck_kin_chain: list + A python list that contains the indices of the joints that form the + kinematic chain of the neck. + dtype: torch.dtype, optional + + Returns + ------- + dyn_lmk_faces_idx: torch.tensor, dtype = torch.long + A tensor of size BxL that contains the indices of the faces that + will be used to compute the current dynamic landmarks. + dyn_lmk_b_coords: torch.tensor, dtype = torch.float32 + A tensor of size BxL that contains the indices of the faces that + will be used to compute the current dynamic landmarks. 
+ ''' + + dtype = vertices.dtype + batch_size = vertices.shape[0] + + if pose2rot: + aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, + neck_kin_chain) + rot_mats = batch_rodrigues( + aa_pose.view(-1, 3)).view(batch_size, -1, 3, 3) + else: + rot_mats = torch.index_select( + pose.view(batch_size, -1, 3, 3), 1, neck_kin_chain) + + rel_rot_mat = torch.eye( + 3, device=vertices.device, dtype=dtype).unsqueeze_(dim=0).repeat( + batch_size, 1, 1) + for idx in range(len(neck_kin_chain)): + rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) + + y_rot_angle = torch.round( + torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, + max=39)).to(dtype=torch.long) + neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) + mask = y_rot_angle.lt(-39).to(dtype=torch.long) + neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) + y_rot_angle = (neg_mask * neg_vals + + (1 - neg_mask) * y_rot_angle) + + dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, + 0, y_rot_angle) + dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, + 0, y_rot_angle) + + return dyn_lmk_faces_idx, dyn_lmk_b_coords + + +def vertices2landmarks( + vertices: Tensor, + faces: Tensor, + lmk_faces_idx: Tensor, + lmk_bary_coords: Tensor +) -> Tensor: + ''' Calculates landmarks by barycentric interpolation + + Parameters + ---------- + vertices: torch.tensor BxVx3, dtype = torch.float32 + The tensor of input vertices + faces: torch.tensor Fx3, dtype = torch.long + The faces of the mesh + lmk_faces_idx: torch.tensor L, dtype = torch.long + The tensor with the indices of the faces used to calculate the + landmarks. + lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32 + The tensor of barycentric coordinates that are used to interpolate + the landmarks + + Returns + ------- + landmarks: torch.tensor BxLx3, dtype = torch.float32 + The coordinates of the landmarks for each mesh in the batch + ''' + # Extract the indices of the vertices for each face + # BxLx3 + batch_size, num_verts = vertices.shape[:2] + device = vertices.device + + lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1).to(torch.long)).view( + batch_size, -1, 3) + #The '.to(torch.long)'. 
+ # added to make the trace work in c++, + # otherwise you get a runtime error in c++: + # 'index_select(): Expected dtype int32 or int64 for index' + + lmk_faces += torch.arange( + batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts + + lmk_vertices = vertices.view(-1, 3)[lmk_faces].view( + batch_size, -1, 3, 3) + + landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) + return landmarks + + +def lbs( + betas: Tensor, + pose: Tensor, + v_template: Tensor, + shapedirs: Tensor, + posedirs: Tensor, + J_regressor: Tensor, + parents: Tensor, + lbs_weights: Tensor, + pose2rot: bool = True, +) -> Tuple[Tensor, Tensor]: + ''' Performs Linear Blend Skinning with the given shape and pose parameters + + Parameters + ---------- + betas : torch.tensor BxNB + The tensor of shape parameters + pose : torch.tensor Bx(J + 1) * 3 + The pose parameters in axis-angle format + v_template torch.tensor BxVx3 + The template mesh that will be deformed + shapedirs : torch.tensor 1xNB + The tensor of PCA shape displacements + posedirs : torch.tensor Px(V * 3) + The pose PCA coefficients + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from + the position of the vertices + parents: torch.tensor J + The array that describes the kinematic tree for the model + lbs_weights: torch.tensor N x V x (J + 1) + The linear blend skinning weights that represent how much the + rotation matrix of each part affects each vertex + pose2rot: bool, optional + Flag on whether to convert the input pose tensor to rotation + matrices. The default value is True. If False, then the pose tensor + should already contain rotation matrices and have a size of + Bx(J + 1)x9 + dtype: torch.dtype, optional + + Returns + ------- + verts: torch.tensor BxVx3 + The vertices of the mesh after applying the shape and pose + displacements. + joints: torch.tensor BxJx3 + The joints of the model + ''' + + batch_size = max(betas.shape[0], pose.shape[0]) + device, dtype = betas.device, betas.dtype + + # Add shape contribution + v_shaped = v_template + blend_shapes(betas, shapedirs) + + # Get the joints + # NxJx3 array + J = vertices2joints(J_regressor, v_shaped) + + # 3. Add pose blend shapes + # N x J x 3 x 3 + ident = torch.eye(3, dtype=dtype, device=device) + if pose2rot: + rot_mats = batch_rodrigues(pose.view(-1, 3)).view( + [batch_size, -1, 3, 3]) + + pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) + # (N x P) x (P, V * 3) -> N x V x 3 + pose_offsets = torch.matmul( + pose_feature, posedirs).view(batch_size, -1, 3) + else: + pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident + rot_mats = pose.view(batch_size, -1, 3, 3) + + pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), + posedirs).view(batch_size, -1, 3) + + v_posed = pose_offsets + v_shaped + # 4. Get the global joint location + J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) + + # 5. 
Do skinning: + # W is N x V x (J + 1) + W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) + # (N x V x (J + 1)) x (N x (J + 1) x 16) + num_joints = J_regressor.shape[0] + T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \ + .view(batch_size, -1, 4, 4) + + homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], + dtype=dtype, device=device) + v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) + v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) + + verts = v_homo[:, :, :3, 0] + + return verts, J_transformed + + +def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor: + ''' Calculates the 3D joint locations from the vertices + + Parameters + ---------- + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from the + position of the vertices + vertices : torch.tensor BxVx3 + The tensor of mesh vertices + + Returns + ------- + torch.tensor BxJx3 + The location of the joints + ''' + + return torch.einsum('bik,ji->bjk', [vertices, J_regressor]) + + +def blend_shapes(betas: Tensor, shape_disps: Tensor) -> Tensor: + ''' Calculates the per vertex displacement due to the blend shapes + + + Parameters + ---------- + betas : torch.tensor Bx(num_betas) + Blend shape coefficients + shape_disps: torch.tensor Vx3x(num_betas) + Blend shapes + + Returns + ------- + torch.tensor BxVx3 + The per-vertex displacement due to shape deformation + ''' + + # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] + # i.e. Multiply each shape displacement by its corresponding beta and + # then sum them. + blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) + return blend_shape + + +def batch_rodrigues( + rot_vecs: Tensor, + epsilon: float = 1e-8, +) -> Tensor: + ''' Calculates the rotation matrices for a batch of rotation vectors + Parameters + ---------- + rot_vecs: torch.tensor Nx3 + array of N axis-angle vectors + Returns + ------- + R: torch.tensor Nx3x3 + The rotation matrices for the given axis-angle parameters + ''' + + batch_size = rot_vecs.shape[0] + device, dtype = rot_vecs.device, rot_vecs.dtype + + angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) + rot_dir = rot_vecs / angle + + cos = torch.unsqueeze(torch.cos(angle), dim=1) + sin = torch.unsqueeze(torch.sin(angle), dim=1) + + # Bx1 arrays + rx, ry, rz = torch.split(rot_dir, 1, dim=1) + K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) + + zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((batch_size, 3, 3)) + + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) + return rot_mat + + +def transform_mat(R: Tensor, t: Tensor) -> Tensor: + ''' Creates a batch of transformation matrices + Args: + - R: Bx3x3 array of a batch of rotation matrices + - t: Bx3x1 array of a batch of translation vectors + Returns: + - T: Bx4x4 Transformation matrix + ''' + # No padding left or right, only add an extra row + return torch.cat([F.pad(R, [0, 0, 0, 1]), + F.pad(t, [0, 0, 0, 1], value=1)], dim=2) + + +def batch_rigid_transform( + rot_mats: Tensor, + joints: Tensor, + parents: Tensor, + dtype=torch.float32 +) -> Tensor: + """ + Applies a batch of rigid transformations to the joints + + Parameters + ---------- + rot_mats : torch.tensor BxNx3x3 + Tensor of rotation matrices + joints : torch.tensor BxNx3 + Locations of joints + parents : torch.tensor BxN + The kinematic 
tree of each object + dtype : torch.dtype, optional: + The data type of the created tensors, the default is torch.float32 + + Returns + ------- + posed_joints : torch.tensor BxNx3 + The locations of the joints after applying the pose rotations + rel_transforms : torch.tensor BxNx4x4 + The relative (with respect to the root joint) rigid transformations + for all the joints + """ + + joints = torch.unsqueeze(joints, dim=-1) + + rel_joints = joints.clone() + rel_joints[:, 1:] -= joints[:, parents[1:]] + + transforms_mat = transform_mat( + rot_mats.reshape(-1, 3, 3), + rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4) + + transform_chain = [transforms_mat[:, 0]] + for i in range(1, parents.shape[0]): + # Subtract the joint location at the rest pose + # No need for rotation, since it's identity when at rest + curr_res = torch.matmul(transform_chain[parents[i]], + transforms_mat[:, i]) + transform_chain.append(curr_res) + + transforms = torch.stack(transform_chain, dim=1) + + # The last column of the transformations contains the posed joints + posed_joints = transforms[:, :, :3, 3] + + joints_homogen = F.pad(joints, [0, 0, 0, 1]) + + rel_transforms = transforms - F.pad( + torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) + + return posed_joints, rel_transforms diff --git a/SMPLX/smplx/utils.py b/SMPLX/smplx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a014698ddecdbbdd4cfa04119ec29ffb2211e6e2 --- /dev/null +++ b/SMPLX/smplx/utils.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from typing import NewType, Union, Optional +from dataclasses import dataclass, asdict, fields +import numpy as np +import torch + +Tensor = NewType('Tensor', torch.Tensor) +Array = NewType('Array', np.ndarray) + + +@dataclass +class ModelOutput: + vertices: Optional[Tensor] = None + joints: Optional[Tensor] = None + full_pose: Optional[Tensor] = None + global_orient: Optional[Tensor] = None + transl: Optional[Tensor] = None + v_shaped: Optional[Tensor] = None + + def __getitem__(self, key): + return getattr(self, key) + + def get(self, key, default=None): + return getattr(self, key, default) + + def __iter__(self): + return self.keys() + + def keys(self): + keys = [t.name for t in fields(self)] + return iter(keys) + + def values(self): + values = [getattr(self, t.name) for t in fields(self)] + return iter(values) + + def items(self): + data = [(t.name, getattr(self, t.name)) for t in fields(self)] + return iter(data) + + +@dataclass +class SMPLOutput(ModelOutput): + betas: Optional[Tensor] = None + body_pose: Optional[Tensor] = None + + +@dataclass +class SMPLHOutput(SMPLOutput): + left_hand_pose: Optional[Tensor] = None + right_hand_pose: Optional[Tensor] = None + transl: Optional[Tensor] = None + + +@dataclass +class SMPLXOutput(SMPLHOutput): + expression: Optional[Tensor] = None + jaw_pose: Optional[Tensor] = None + + +@dataclass +class MANOOutput(ModelOutput): + betas: Optional[Tensor] = None + hand_pose: Optional[Tensor] = None + + +@dataclass +class FLAMEOutput(ModelOutput): + betas: Optional[Tensor] = None + expression: Optional[Tensor] = None + jaw_pose: Optional[Tensor] = None + neck_pose: Optional[Tensor] = None + + +def find_joint_kin_chain(joint_id, kinematic_tree): + kin_chain = [] + curr_idx = joint_id + while curr_idx != -1: + kin_chain.append(curr_idx) + curr_idx = kinematic_tree[curr_idx] + return kin_chain + + +def to_tensor( + array: Union[Array, Tensor], dtype=torch.float32 +) -> Tensor: + if torch.is_tensor(array): + return array + else: + return torch.tensor(array, dtype=dtype) + + +class Struct(object): + def __init__(self, **kwargs): + for key, val in kwargs.items(): + setattr(self, key, val) + + +def to_np(array, dtype=np.float32): + if 'scipy.sparse' in str(type(array)): + array = array.todense() + return np.array(array, dtype=dtype) + + +def rot_mat_to_euler(rot_mats): + # Calculates rotation matrix to euler angles + # Careful for extreme cases of eular angles like [0.0, pi, 0.0] + + sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + + rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) + return torch.atan2(-rot_mats[:, 2, 0], sy) diff --git a/SMPLX/smplx/vertex_ids.py b/SMPLX/smplx/vertex_ids.py new file mode 100644 index 0000000000000000000000000000000000000000..0e7a4c36700f002da54a9e181eabbd47af2a95bc --- /dev/null +++ b/SMPLX/smplx/vertex_ids.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from __future__ import print_function +from __future__ import absolute_import +from __future__ import division + +# Joint name to vertex mapping. SMPL/SMPL-H/SMPL-X vertices that correspond to +# MSCOCO and OpenPose joints +vertex_ids = { + 'smplh': { + 'nose': 332, + 'reye': 6260, + 'leye': 2800, + 'rear': 4071, + 'lear': 583, + 'rthumb': 6191, + 'rindex': 5782, + 'rmiddle': 5905, + 'rring': 6016, + 'rpinky': 6133, + 'lthumb': 2746, + 'lindex': 2319, + 'lmiddle': 2445, + 'lring': 2556, + 'lpinky': 2673, + 'LBigToe': 3216, + 'LSmallToe': 3226, + 'LHeel': 3387, + 'RBigToe': 6617, + 'RSmallToe': 6624, + 'RHeel': 6787 + }, + 'smplx': { + 'nose': 9120, + 'reye': 9929, + 'leye': 9448, + 'rear': 616, + 'lear': 6, + 'rthumb': 8079, + 'rindex': 7669, + 'rmiddle': 7794, + 'rring': 7905, + 'rpinky': 8022, + 'lthumb': 5361, + 'lindex': 4933, + 'lmiddle': 5058, + 'lring': 5169, + 'lpinky': 5286, + 'LBigToe': 5770, + 'LSmallToe': 5780, + 'LHeel': 8846, + 'RBigToe': 8463, + 'RSmallToe': 8474, + 'RHeel': 8635 + }, + 'mano': { + 'thumb': 744, + 'index': 320, + 'middle': 443, + 'ring': 554, + 'pinky': 671, + } +} diff --git a/SMPLX/smplx/vertex_joint_selector.py b/SMPLX/smplx/vertex_joint_selector.py new file mode 100644 index 0000000000000000000000000000000000000000..17449726a45647709580c9b08f932a2e82cce2cc --- /dev/null +++ b/SMPLX/smplx/vertex_joint_selector.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import numpy as np + +import torch +import torch.nn as nn + +from .utils import to_tensor + + +class VertexJointSelector(nn.Module): + + def __init__(self, vertex_ids=None, + use_hands=True, + use_feet_keypoints=True, **kwargs): + super(VertexJointSelector, self).__init__() + + extra_joints_idxs = [] + + face_keyp_idxs = np.array([ + vertex_ids['nose'], + vertex_ids['reye'], + vertex_ids['leye'], + vertex_ids['rear'], + vertex_ids['lear']], dtype=np.int64) + + extra_joints_idxs = np.concatenate([extra_joints_idxs, + face_keyp_idxs]) + + if use_feet_keypoints: + feet_keyp_idxs = np.array([vertex_ids['LBigToe'], + vertex_ids['LSmallToe'], + vertex_ids['LHeel'], + vertex_ids['RBigToe'], + vertex_ids['RSmallToe'], + vertex_ids['RHeel']], dtype=np.int32) + + extra_joints_idxs = np.concatenate( + [extra_joints_idxs, feet_keyp_idxs]) + + if use_hands: + self.tip_names = ['thumb', 'index', 'middle', 'ring', 'pinky'] + + tips_idxs = [] + for hand_id in ['l', 'r']: + for tip_name in self.tip_names: + tips_idxs.append(vertex_ids[hand_id + tip_name]) + + extra_joints_idxs = np.concatenate( + [extra_joints_idxs, tips_idxs]) + + self.register_buffer('extra_joints_idxs', + to_tensor(extra_joints_idxs, dtype=torch.long)) + + def forward(self, vertices, joints): + extra_joints = torch.index_select(vertices, 1, self.extra_joints_idxs.to(torch.long)) #The '.to(torch.long)'. + # added to make the trace work in c++, + # otherwise you get a runtime error in c++: + # 'index_select(): Expected dtype int32 or int64 for index' + joints = torch.cat([joints, extra_joints], dim=1) + + return joints diff --git a/SMPLX/transfer_model/README.md b/SMPLX/transfer_model/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e7f14c4f7758c90c402d7c2155e536c3840c5eae --- /dev/null +++ b/SMPLX/transfer_model/README.md @@ -0,0 +1,253 @@ +# Model parameter transfer + +## Table of Contents + * [License](#license) + * [Description](#description) + * [Using the code](#using-the-code) + * [Data](#data) + * [Steps](#steps) + * [SMPL to SMPL-X](#smpl-to-smpl-x) + * [SMPL-X to SMPL](#smpl-x-to-smpl) + * [SMPL+H to SMPL](#smpl%2Bh-to-smpl) + * [SMPL to SMPL+H](#smpl-to-smpl%2Bh) + * [SMPL+H to SMPL-X](#smpl%2Bh-to-smpl-x) + * [SMPL-X to SMPL+H](#smpl-x-to-smpl%2Bh) + * [Visualize correspondences](visualize-correspondences) + * [Citation](#citation) + * [Acknowledgments](#acknowledgments) + * [Contact](#contact) + +## License + +Software Copyright License for **non-commercial scientific research purposes**. +Please read carefully the [terms and conditions](https://github.com/vchoutas/smplx/blob/master/LICENSE) and any accompanying documentation before you download and/or use the SMPL-X/SMPLify-X model, data and software, (the "Model & Software"), including 3D meshes, blend weights, blend shapes, textures, software, scripts, and animations. By downloading and/or using the Model & Software (including downloading, cloning, installing, and any other use of this github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Model & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this [License](./LICENSE). 
+ +## Description + +The repository contains code for converting model parameters of one model to +another. **Never** copy parameters between the models. You will not get the +same poses. SMPL, SMPL+H and SMPL-X shape spaces are **NOT** compatible, since +each model is the result of a different training process. +A more detailed explanation on how we extract correspondences +between the models and the loss function used to estimate the parameters can be +found [here](./docs/transfer.md). + +## Requirements + +1. Install [mesh](https://github.com/MPI-IS/mesh) +2. Start by cloning the SMPL-X repo: +```Shell +git clone https://github.com/vchoutas/smplx.git +``` +3. Run the following command to install all necessary requirements +```Shell + pip install -r requirements.txt +``` +4. Install the Torch Trust Region optimizer by following the instructions [here](https://github.com/vchoutas/torch-trust-ncg) +5. Install loguru +6. Install open3d +7. Install omegaconf + +## Using the code + +### Data + +Register on the [SMPL-X website](http://smpl-x.is.tue.mpg.de/), go to the +downloads section to get the correspondences and sample data, +by clicking on the *Model correspondences* button. +Create a folder +named `transfer_data` and extract the downloaded zip there. You should have the +following folder structure now: + +```bash +transfer_data +├── meshes +│   ├── smpl +│   ├── smplx +├── smpl2smplh_def_transfer.pkl +├── smpl2smplx_deftrafo_setup.pkl +├── smplh2smpl_def_transfer.pkl +├── smplh2smplx_deftrafo_setup.pkl +├── smplx2smpl_deftrafo_setup.pkl +├── smplx2smplh_deftrafo_setup.pkl +├── smplx_mask_ids.npy +``` + +### Steps + +First, break the motion into a set of pose `.obj` files. Depending on how the +SMPL-* parameters are stored this code will differ. For the example AMASS data +in this repository you can use the example code here: + +``` +python write_obj.py --model-folder ../models/ --motion-file ../transfer_data/support_data/github_data/amass_sample.npz --output-folder ../transfer_data/meshes/amass_sample/ +``` + +To run the `transfer_model` utility you will require a `.yaml` config file, +which can point to the location the output `.obj` files have been saved. Use the +templates in `config_files` in the root of this repository. To convert the +sample AMASS code to SMPL-X: + +``` +python -m transfer_model --exp-cfg config_files/smplh2smplx_as.yaml +``` + +Finally, the output `.obj` files have to be merged into a single motion +sequence. Example code to do this in a way that matches `SMPL-X` AMASS archives +can be found in `merge_output.py` and run as follows: + +``` +python merge_output.py --gender neutral ../output +``` + +Debug notes describing common problems encountered during this can be found +[here](https://github.com/gngdb/smplx/blob/debug/transfer_model/DEBUG_NOTES.md). +Problems are also discussed in +[two](https://github.com/vchoutas/smplx/issues/82) +[issues](https://github.com/vchoutas/smplx/issues/75). + +### SMPL to SMPL-X + +To run the code to convert SMPL meshes to SMPL-X parameters use the following command: + ```Shell + python -m transfer_model --exp-cfg config_files/smpl2smplx.yaml + ``` +This should be run from the top directory of the repository. + +The file *smpl2smplx.yaml* contains a sample configuration that reads meshes from a folder, +processes them and returns pkl files with SMPL-X parameters. To run on your own data create a folder +with SMPL meshes, in either ply or obj format, change the path in the config file and run the code. 
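+
+If you want to sanity-check a conversion, the snippet below is a minimal,
+illustrative sketch (not part of the released tooling) that loads one fitted
+frame and prints the stored parameter arrays. Depending on the configuration,
+the fitted parameters are stored as `pkl` files or, with the `transfer_model`
+entry point bundled in this repository, as `.npz` archives; the sketch assumes
+the `.npz` case and an example file name, so adapt the path to the
+`output_folder` set in your config:
+
+```python
+import numpy as np
+
+# Illustrative path: replace it with a file from your configured output_folder.
+fitted = np.load("../output/mesh_0000.npz", allow_pickle=True)
+for name in fitted.files:
+    value = fitted[name]
+    print(name, getattr(value, "shape", type(value)))
+```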
+ +### SMPL-X to SMPL + +To run the code to convert SMPL-X meshes to SMPL parameters use the following command: + ```Shell + python main.py --exp-cfg config_files/smplx2smpl.yaml + ``` + +The file *smplx2smpl.yaml* contains a sample configuration that reads meshes from a folder, +processes them and returns pkl files with SMPL parameters. To run on your own data create a folder +with SMPL-X meshes, in either ply or obj format, change the path in the config file and run the code. +When creating the SMPL-X meshes, do not use the hand and face parameters. +Naturally, you will lose all hand and face information if you choose this, since +SMPL cannot model them. + + +### SMPL+H to SMPL + +To run the code to convert SMPL+H meshes to SMPL parameters use the following command from the root `smplx` directory: + ```Shell + python -m transfer_model --exp-cfg config_files/smplh2smpl.yaml + ``` +This should be run from the top directory of the repository. + +The file *smplh2smpl.yaml* contains a sample configuration that reads meshes from a folder, +processes them and returns pkl files with SMPL parameters. To run on your own data create a folder +with SMPL+H meshes, in either ply or obj format, change the path in the config file and run the code. +Note that using this direction means that you will lose information on the +hands. + + +### SMPL to SMPL+H + +To run the code to convert SMPL meshes to SMPL+H parameters use the following command: + ```Shell + python -m transfer_model --exp-cfg config_files/smpl2smplh.yaml + ``` +This should be run from the top directory of the repository. + +The file *smpl2smplh.yaml* contains a sample configuration that reads meshes from a folder, +processes them and returns pkl files with SMPL parameters. To run on your own data create a folder +with SMPL meshes, in either ply or obj format, change the path in the config file and run the code. + +### SMPL+H to SMPL-X + +To run the code to convert SMPL+H meshes to SMPL-X parameters use the following command: + ```Shell + python -m transfer_model --exp-cfg config_files/smplh2smplx.yaml + ``` +This should be run from the top directory of the repository. + +The file *smplh2smplx.yaml* contains a sample configuration that reads meshes from a folder, +processes them and returns pkl files with SMPL-X parameters. To run on your own data create a folder +with SMPL+H meshes, in either ply or obj format, change the path in the config file and run the code. + + +### SMPL-X to SMPL+H + +To run the code to convert SMPL-X meshes to SMPL+H parameters use the following command: + ```Shell + python -m transfer_model --exp-cfg config_files/smplx2smplh.yaml + ``` +This should be run from the top directory of the repository. + +The file *smplx2smpl.yaml* contains a sample configuration that reads meshes from a folder, +processes them and returns pkl files with SMPL+H parameters. To run on your own data create a folder +with SMPL-X meshes, in either ply or obj format, change the path in the config file and run the code. +Make sure that you do not use the jaw pose and expression parameters to generate +the meshes. + + +## Visualize correspondences + +To visualize correspondences: +```Shell +python vis_correspondences.py --exp-cfg configs/smpl2smplx.yaml --exp-opts colors_path PATH_TO_SMPL_COLORS +``` +You should then see the following image. Points with similar color are in +correspondence. +![Correspondence example](./docs/images/smpl_smplx_correspondence.png) + +## Citation + +Depending on which model is loaded for your project, i.e. 
SMPL-X or SMPL+H or SMPL, please cite the most relevant work: + +``` +@article{SMPL:2015, + author = {Loper, Matthew and Mahmood, Naureen and Romero, Javier and Pons-Moll, Gerard and Black, Michael J.}, + title = {{SMPL}: A Skinned Multi-Person Linear Model}, + journal = {ACM Transactions on Graphics, (Proc. SIGGRAPH Asia)}, + month = oct, + number = {6}, + pages = {248:1--248:16}, + publisher = {ACM}, + volume = {34}, + year = {2015} +} +``` + +``` +@article{MANO:SIGGRAPHASIA:2017, + title = {Embodied Hands: Modeling and Capturing Hands and Bodies Together}, + author = {Romero, Javier and Tzionas, Dimitrios and Black, Michael J.}, + journal = {ACM Transactions on Graphics, (Proc. SIGGRAPH Asia)}, + volume = {36}, + number = {6}, + pages = {245:1--245:17}, + series = {245:1--245:17}, + publisher = {ACM}, + month = nov, + year = {2017}, + url = {http://doi.acm.org/10.1145/3130800.3130883}, + month_numeric = {11} + } +``` + + +``` +@inproceedings{SMPL-X:2019, + title = {Expressive Body Capture: 3D Hands, Face, and Body from a Single Image}, + author = {Pavlakos, Georgios and Choutas, Vasileios and Ghorbani, Nima and Bolkart, Timo and Osman, Ahmed A. A. and Tzionas, Dimitrios and Black, Michael J.}, + booktitle = {Proceedings IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)}, + year = {2019} +} +``` + + +## Acknowledgments +The code of this repository was implemented by [Vassilis Choutas](vassilis.choutas@tuebingen.mpg.de), +based on a Chumpy implementation from [Timo Bolkart](timo.bolkart@tuebingen.mpg.de). + +## Contact + +For questions, please contact [smplx@tue.mpg.de](smplx@tue.mpg.de). diff --git a/SMPLX/transfer_model/__init__.py b/SMPLX/transfer_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d3c2dcee27aae54d7972ab9a4b41c2a4bb42c38 --- /dev/null +++ b/SMPLX/transfer_model/__init__.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +import os +import os.path as osp +import sys +import pickle + +import numpy as np +import open3d as o3d +import torch +from loguru import logger +from tqdm import tqdm + +from SMPLX.smplx import build_layer + +from SMPLX.transfer_model.config import parse_args +from SMPLX.transfer_model.data import build_dataloader +from SMPLX.transfer_model.transfer_model import run_fitting +from SMPLX.transfer_model.utils import read_deformation_transfer, np_mesh_to_o3d + + +def trans_one_sequence(): + exp_cfg = parse_args() + + if torch.cuda.is_available() and exp_cfg["use_cuda"]: + device = torch.device('cuda') + else: + device = torch.device('cpu') + if exp_cfg["use_cuda"]: + if input("use_cuda=True and GPU is not available, using CPU instead," + " would you like to continue? 
(y/n)") != "y": + sys.exit(3) + + logger.remove() + logger.add( + lambda x: tqdm.write(x, end=''), level=exp_cfg.logger_level.upper(), + colorize=True) + + output_folder = osp.expanduser(osp.expandvars(exp_cfg.output_folder)) + logger.info(f'Saving output to: {output_folder}') + os.makedirs(output_folder, exist_ok=True) + + model_path = exp_cfg.body_model.folder + body_model = build_layer(model_path, **exp_cfg.body_model) + logger.info(body_model) + body_model = body_model.to(device=device) + + deformation_transfer_path = exp_cfg.get('deformation_transfer_path', '') + def_matrix = read_deformation_transfer( + deformation_transfer_path, device=device) + + # Read mask for valid vertex ids + mask_ids_fname = osp.expandvars(exp_cfg.mask_ids_fname) + mask_ids = None + if osp.exists(mask_ids_fname): + logger.info(f'Loading mask ids from: {mask_ids_fname}') + mask_ids = np.load(mask_ids_fname) + mask_ids = torch.from_numpy(mask_ids).to(device=device) + else: + logger.warning(f'Mask ids fname not found: {mask_ids_fname}') + + data_obj_dict = build_dataloader(exp_cfg) + + dataloader = data_obj_dict['dataloader'] + + for ii, batch in enumerate(tqdm(dataloader)): + for key in batch: + if torch.is_tensor(batch[key]): + batch[key] = batch[key].to(device=device) + var_dict = run_fitting( + exp_cfg, batch, body_model, def_matrix, mask_ids) + paths = batch['paths'] + + for ii, path in enumerate(paths): + _, fname = osp.split(path) + + output_path = osp.join( + output_folder, f'{osp.splitext(fname)[0]}.npz') + + for key in var_dict.keys(): + try: + var_dict[key] = var_dict[key].detach().cpu().numpy() + except: + pass + np.savez(output_path, **var_dict) \ No newline at end of file diff --git a/SMPLX/transfer_model/__main__.py b/SMPLX/transfer_model/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..13f5fc63843b7b47c182df42f82af10b78d43b43 --- /dev/null +++ b/SMPLX/transfer_model/__main__.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +import os +import os.path as osp +import sys +import pickle + +import numpy as np +import open3d as o3d +import torch +from loguru import logger +from tqdm import tqdm + +from smplx import build_layer + +from .config import parse_args +from .data import build_dataloader +from .transfer_model import run_fitting +from .utils import read_deformation_transfer, np_mesh_to_o3d + + +def main() -> None: + exp_cfg = parse_args() + + if torch.cuda.is_available() and exp_cfg["use_cuda"]: + device = torch.device('cuda') + else: + device = torch.device('cpu') + if exp_cfg["use_cuda"]: + if input("use_cuda=True and GPU is not available, using CPU instead," + " would you like to continue? 
(y/n)") != "y": + sys.exit(3) + + logger.remove() + logger.add( + lambda x: tqdm.write(x, end=''), level=exp_cfg.logger_level.upper(), + colorize=True) + + output_folder = osp.expanduser(osp.expandvars(exp_cfg.output_folder)) + logger.info(f'Saving output to: {output_folder}') + os.makedirs(output_folder, exist_ok=True) + + model_path = exp_cfg.body_model.folder + body_model = build_layer(model_path, **exp_cfg.body_model) + logger.info(body_model) + body_model = body_model.to(device=device) + + deformation_transfer_path = exp_cfg.get('deformation_transfer_path', '') + def_matrix = read_deformation_transfer( + deformation_transfer_path, device=device) + + # Read mask for valid vertex ids + mask_ids_fname = osp.expandvars(exp_cfg.mask_ids_fname) + mask_ids = None + if osp.exists(mask_ids_fname): + logger.info(f'Loading mask ids from: {mask_ids_fname}') + mask_ids = np.load(mask_ids_fname) + mask_ids = torch.from_numpy(mask_ids).to(device=device) + else: + logger.warning(f'Mask ids fname not found: {mask_ids_fname}') + + data_obj_dict = build_dataloader(exp_cfg) + + dataloader = data_obj_dict['dataloader'] + + for ii, batch in enumerate(tqdm(dataloader)): + for key in batch: + if torch.is_tensor(batch[key]): + batch[key] = batch[key].to(device=device) + var_dict = run_fitting( + exp_cfg, batch, body_model, def_matrix, mask_ids) + paths = batch['paths'] + + for ii, path in enumerate(paths): + _, fname = osp.split(path) + + output_path = osp.join( + output_folder, f'{osp.splitext(fname)[0]}.npz') + + for key in var_dict.keys(): + try: + var_dict[key] = var_dict[key].detach().cpu().numpy() + except: + pass + np.savez(output_path, **var_dict) + + + +if __name__ == '__main__': + main() diff --git a/SMPLX/transfer_model/__pycache__/__init__.cpython-38.pyc b/SMPLX/transfer_model/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1f5745be6d6dc7a877027e7629bb63f3f5fc757 Binary files /dev/null and b/SMPLX/transfer_model/__pycache__/__init__.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/__pycache__/__init__.cpython-39.pyc b/SMPLX/transfer_model/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..363c3c6f397df6861aef7e4133559f1dbfb03eaf Binary files /dev/null and b/SMPLX/transfer_model/__pycache__/__init__.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/__pycache__/merge_output.cpython-38.pyc b/SMPLX/transfer_model/__pycache__/merge_output.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56740813d78fa5ba585bbbd494fe1d6ea929bfd3 Binary files /dev/null and b/SMPLX/transfer_model/__pycache__/merge_output.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/__pycache__/merge_output.cpython-39.pyc b/SMPLX/transfer_model/__pycache__/merge_output.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a5d5c1d831b5e2793dd49bc1b88a3ce1cc87905 Binary files /dev/null and b/SMPLX/transfer_model/__pycache__/merge_output.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/__pycache__/transfer_model.cpython-38.pyc b/SMPLX/transfer_model/__pycache__/transfer_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..662f47f5585036c2b3c3962fb8b3788fa47e811d Binary files /dev/null and b/SMPLX/transfer_model/__pycache__/transfer_model.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/__pycache__/transfer_model.cpython-39.pyc b/SMPLX/transfer_model/__pycache__/transfer_model.cpython-39.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..afd29cf20bd9a092e68e17d56423a5040c4a1c7a Binary files /dev/null and b/SMPLX/transfer_model/__pycache__/transfer_model.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/__pycache__/write_obj.cpython-38.pyc b/SMPLX/transfer_model/__pycache__/write_obj.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36235b9a1faa1b869bab6f708a162a913aa1adbe Binary files /dev/null and b/SMPLX/transfer_model/__pycache__/write_obj.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/__pycache__/write_obj.cpython-39.pyc b/SMPLX/transfer_model/__pycache__/write_obj.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e57b494954cfbda5ded3626a9ac970e7249dbdf Binary files /dev/null and b/SMPLX/transfer_model/__pycache__/write_obj.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/config/__init__.py b/SMPLX/transfer_model/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4711a91455b85e0da688c1a3ae6573fa1f36187d --- /dev/null +++ b/SMPLX/transfer_model/config/__init__.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from .cmd_parser import parse_args diff --git a/SMPLX/transfer_model/config/__pycache__/__init__.cpython-38.pyc b/SMPLX/transfer_model/config/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa2b59edcb6a77f208b3e94b391624c98a5e4e83 Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/__init__.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/config/__pycache__/__init__.cpython-39.pyc b/SMPLX/transfer_model/config/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a996a7a138be493d3ccd26be83e92918d41b08bc Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/__init__.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/config/__pycache__/body_model_defaults.cpython-38.pyc b/SMPLX/transfer_model/config/__pycache__/body_model_defaults.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c75d224d1ba8bf091309013d08dc54632197e97 Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/body_model_defaults.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/config/__pycache__/body_model_defaults.cpython-39.pyc b/SMPLX/transfer_model/config/__pycache__/body_model_defaults.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a2a6342efcc3e3ac50f46291e1357e53f7e287e Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/body_model_defaults.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/config/__pycache__/cmd_parser.cpython-38.pyc b/SMPLX/transfer_model/config/__pycache__/cmd_parser.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e1c5826db3085b617b70bc4925437fbaca9e49f Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/cmd_parser.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/config/__pycache__/cmd_parser.cpython-39.pyc b/SMPLX/transfer_model/config/__pycache__/cmd_parser.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce9ed7d5e7e84e4da47af3952fd9c4acc4ed86a5 Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/cmd_parser.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/config/__pycache__/dataset_defaults.cpython-38.pyc b/SMPLX/transfer_model/config/__pycache__/dataset_defaults.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fa64fb121967daed31a12545a4e8215c93592d5 Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/dataset_defaults.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/config/__pycache__/dataset_defaults.cpython-39.pyc b/SMPLX/transfer_model/config/__pycache__/dataset_defaults.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0cebfb3e8226edb87dbdbbf6e92981452e5d075 Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/dataset_defaults.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/config/__pycache__/defaults.cpython-38.pyc b/SMPLX/transfer_model/config/__pycache__/defaults.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4011689caa613f44b85200c181440aa7cf5ed73 Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/defaults.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/config/__pycache__/defaults.cpython-39.pyc b/SMPLX/transfer_model/config/__pycache__/defaults.cpython-39.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..e2e06e97dfd8b05b9ef2eb1e084dce63e7a3b530 Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/defaults.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/config/__pycache__/loss_defaults.cpython-38.pyc b/SMPLX/transfer_model/config/__pycache__/loss_defaults.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11c9e405dccaec4c7000f3d5c4490d1cf89797f3 Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/loss_defaults.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/config/__pycache__/loss_defaults.cpython-39.pyc b/SMPLX/transfer_model/config/__pycache__/loss_defaults.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cb2a48a80447e0ec4302c48749a10167493aabc Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/loss_defaults.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/config/__pycache__/optim_defaults.cpython-38.pyc b/SMPLX/transfer_model/config/__pycache__/optim_defaults.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1119a3f0636485aeeca71f7ebc05749e0d3e8df1 Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/optim_defaults.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/config/__pycache__/optim_defaults.cpython-39.pyc b/SMPLX/transfer_model/config/__pycache__/optim_defaults.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa6b020d33bf743530d2696a981eadbbd81e8c9e Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/optim_defaults.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/config/__pycache__/utils_cfg.cpython-38.pyc b/SMPLX/transfer_model/config/__pycache__/utils_cfg.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b64f630eec4d7fc982deab05f5f3fcc805e738d0 Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/utils_cfg.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/config/__pycache__/utils_cfg.cpython-39.pyc b/SMPLX/transfer_model/config/__pycache__/utils_cfg.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7917cc389ba3619a6a4f728aa9d88d325cc9821 Binary files /dev/null and b/SMPLX/transfer_model/config/__pycache__/utils_cfg.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/config/body_model_defaults.py b/SMPLX/transfer_model/config/body_model_defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..658149cced69097c70f92d3d1e6f8500187aa978 --- /dev/null +++ b/SMPLX/transfer_model/config/body_model_defaults.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from omegaconf import OmegaConf +from loguru import logger +from dataclasses import dataclass +from .utils_cfg import Variable, Pose + + +@dataclass +class PCA: + num_comps: int = 12 + flat_hand_mean: bool = False + + +@dataclass +class PoseWithPCA(Pose): + pca: PCA = PCA() + + +@dataclass +class Shape(Variable): + num: int = 10 + + +@dataclass +class Expression(Variable): + num: int = 10 + + +@dataclass +class SMPL: + betas: Shape = Shape() + global_rot: Pose = Pose() + body_pose: Pose = Pose() + translation: Variable = Variable() + + +@dataclass +class SMPLH(SMPL): + left_hand_pose: PoseWithPCA = PoseWithPCA() + right_hand_pose: PoseWithPCA = PoseWithPCA() + + +@dataclass +class SMPLX(SMPLH): + expression: Expression = Expression() + jaw_pose: Pose = Pose() + leye_pose: Pose = Pose() + reye_pose: Pose = Pose() + + +@dataclass +class MANO: + betas: Shape = Shape() + wrist_pose: Pose = Pose() + hand_pose: PoseWithPCA = PoseWithPCA() + translation: Variable = Variable() + + +@dataclass +class FLAME: + betas: Shape = Shape() + expression: Expression = Expression() + global_rot: Pose = Pose() + neck_pose: Pose = Pose() + jaw_pose: Pose = Pose() + leye_pose: Pose = Pose() + reye_pose: Pose = Pose() + + +@dataclass +class BodyModelConfig: + model_type: str = 'smplx' + use_compressed: bool = True + folder: str = 'models' + gender: str = 'neutral' + extra_joint_path: str = '' + ext: str = 'npz' + + num_expression_coeffs: int = 10 + + use_face_contour: bool = True + joint_regressor_path: str = '' + + smpl: SMPL = SMPL() + star: SMPL = SMPL() + smplh: SMPLH = SMPLH() + smplx: SMPLX = SMPLX() + mano: MANO = MANO() + flame: FLAME = FLAME() + + +conf = OmegaConf.structured(BodyModelConfig) diff --git a/SMPLX/transfer_model/config/cmd_parser.py b/SMPLX/transfer_model/config/cmd_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..474903e813e238d20e421b916a835d449babb960 --- /dev/null +++ b/SMPLX/transfer_model/config/cmd_parser.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
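As an illustrative aside, a minimal sketch of how the structured body-model config above composes with overrides via OmegaConf; the override values are made up for the example:

```python
# Illustrative sketch: merging overrides into the structured BodyModelConfig.
# The override values here are made up for the example.
from omegaconf import OmegaConf
from SMPLX.transfer_model.config.body_model_defaults import conf as body_model_conf

# Merging against a structured config keeps the dataclass schema, so unknown
# keys or wrong types raise an error instead of passing silently.
override = OmegaConf.create({'gender': 'male', 'smplx': {'betas': {'num': 16}}})
cfg = OmegaConf.merge(body_model_conf, override)
print(cfg.gender, cfg.smplx.betas.num)  # male 16
```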
+# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from __future__ import absolute_import +from __future__ import division + +import sys +import os + +import argparse +from loguru import logger + +from omegaconf import OmegaConf +from .defaults import conf as default_conf + + +def parse_args(argv=None) -> OmegaConf: + arg_formatter = argparse.ArgumentDefaultsHelpFormatter + + description = 'Model transfer script' + parser = argparse.ArgumentParser(formatter_class=arg_formatter, + description=description) + + parser.add_argument('--exp-cfg', type=str, dest='exp_cfg', + help='The configuration of the experiment') + parser.add_argument('--exp-opts', default=[], dest='exp_opts', + nargs='*', + help='Command line arguments') + + cmd_args = parser.parse_args() + + cfg = default_conf.copy() + if cmd_args.exp_cfg: + cfg.merge_with(OmegaConf.load(cmd_args.exp_cfg)) + if cmd_args.exp_opts: + cfg.merge_with(OmegaConf.from_cli(cmd_args.exp_opts)) + + return cfg diff --git a/SMPLX/transfer_model/config/dataset_defaults.py b/SMPLX/transfer_model/config/dataset_defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..ce3e4f36a6e7f7f7c54948c68a18ee97a5ef2d00 --- /dev/null +++ b/SMPLX/transfer_model/config/dataset_defaults.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from omegaconf import OmegaConf +from dataclasses import dataclass + + +@dataclass +class MeshFolder: + data_folder: str = 'data/meshes' + + +@dataclass +class DatasetConfig: + num_workers: int = 0 + name: str = 'mesh-folder' + mesh_folder: MeshFolder = MeshFolder() + + +conf = OmegaConf.structured(DatasetConfig) diff --git a/SMPLX/transfer_model/config/defaults.py b/SMPLX/transfer_model/config/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..1656c21c4b4df3827554e8f45d2aea0fd7327415 --- /dev/null +++ b/SMPLX/transfer_model/config/defaults.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from typing import Tuple, Optional +from copy import deepcopy +# from yacs.config import CfgNode as CN +from dataclasses import dataclass +from omegaconf import OmegaConf + +from .loss_defaults import conf as loss_cfg, LossConfig +from .dataset_defaults import conf as dataset_cfg, DatasetConfig +from .optim_defaults import conf as optim_cfg, OptimConfig +from .body_model_defaults import conf as body_model_cfg, BodyModelConfig + + +@dataclass +class EdgeFitting: + per_part: bool = False + reduction: str = 'mean' + + +@dataclass +class VertexFitting: + per_part: bool = False + reduction: str = 'mean' + type: str = 'l2' + + +@dataclass +class Config: + use_cuda: bool = True + log_file: str = '/tmp/logs' + output_folder: str = 'output' + save_verts: bool = True + save_joints: bool = True + save_mesh: bool = False + save_img_summaries: bool = True + summary_steps: int = 5 + degrees: Tuple[float] = (90,) + float_type: str = 'float' + logger_level: str = 'INFO' + interactive: bool = True + batch_size: Optional[int] = 1 + color_path: str = 'data/smpl_with_colors.ply' + + optim: OptimConfig = optim_cfg + datasets: DatasetConfig = dataset_cfg + losses: LossConfig = loss_cfg + body_model: BodyModelConfig = body_model_cfg + + deformation_transfer_path: str = '' + mask_ids_fname: str = '' + + per_part: bool = True + edge_fitting: EdgeFitting = EdgeFitting() + + +conf = OmegaConf.structured(Config) diff --git a/SMPLX/transfer_model/config/loss_defaults.py b/SMPLX/transfer_model/config/loss_defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..14fa090525eb69cde9079259bc176a1f71af713c --- /dev/null +++ b/SMPLX/transfer_model/config/loss_defaults.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de +# from yacs.config import CfgNode as CN + +from typing import List, Tuple, Union +from omegaconf import OmegaConf +from loguru import logger +from dataclasses import dataclass, make_dataclass + + +@dataclass +class LossTemplate: + type: str = 'l2' + active: bool = False + weight: Tuple[float] = (0.0,) + requires_grad: bool = True + enable: int = 0 + + +@dataclass +class LossConfig: + type: str = 'smplify-x' + + +conf = OmegaConf.structured(LossConfig) diff --git a/SMPLX/transfer_model/config/optim_defaults.py b/SMPLX/transfer_model/config/optim_defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..6dfc6accc586c8f8ec4fbe2315ca4312885b1714 --- /dev/null +++ b/SMPLX/transfer_model/config/optim_defaults.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. 
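A minimal sketch of the layering that `parse_args` performs, shown programmatically against the defaults above; the YAML and dotlist values are placeholders:

```python
# Illustrative sketch of the layering done by parse_args: structured defaults,
# then an experiment YAML (--exp-cfg), then dotlist overrides (--exp-opts).
# The override values are placeholders.
from omegaconf import OmegaConf
from SMPLX.transfer_model.config.defaults import conf as default_conf

cfg = default_conf.copy()
cfg.merge_with(OmegaConf.create({'batch_size': 4, 'output_folder': 'transfer_out'}))
cfg.merge_with(OmegaConf.from_dotlist(['logger_level=DEBUG', 'mask_ids_fname=smplx_mask_ids.npy']))
print(cfg.batch_size, cfg.output_folder, cfg.logger_level)
```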
+# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from typing import Tuple +from omegaconf import OmegaConf +from dataclasses import dataclass + + +@dataclass +class LBFGS: + line_search_fn: str = 'strong_wolfe' + max_iter: int = 50 + + +@dataclass +class SGD: + momentum: float = 0.9 + nesterov: bool = True + + +@dataclass +class ADAM: + betas: Tuple[float, float] = (0.9, 0.999) + eps: float = 1e-08 + amsgrad: bool = False + + +@dataclass +class RMSProp: + alpha: float = 0.99 + + +@dataclass +class TrustRegionNewtonCG: + max_trust_radius: float = 1000 + initial_trust_radius: float = 0.05 + eta: float = 0.15 + gtol: float = 1e-05 + + +@dataclass +class OptimConfig: + type: str = 'trust-ncg' + lr: float = 1.0 + gtol: float = 1e-8 + ftol: float = -1.0 + maxiters: int = 100 + + lbfgs: LBFGS = LBFGS() + sgd: SGD = SGD() + adam: ADAM = ADAM() + trust_ncg: TrustRegionNewtonCG = TrustRegionNewtonCG() + + +conf = OmegaConf.structured(OptimConfig) diff --git a/SMPLX/transfer_model/config/utils_cfg.py b/SMPLX/transfer_model/config/utils_cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..9ea224389b2b9eff299665c384e236844b9e8f80 --- /dev/null +++ b/SMPLX/transfer_model/config/utils_cfg.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from typing import Tuple +from dataclasses import dataclass + + +@dataclass +class Variable: + create: bool = True + requires_grad: bool = True + + +@dataclass +class Pose(Variable): + type: str = 'aa' diff --git a/SMPLX/transfer_model/data/__init__.py b/SMPLX/transfer_model/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c2892d560e3e9aa8c9b68186b537f35d04fe3bec --- /dev/null +++ b/SMPLX/transfer_model/data/__init__.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). 
acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from .build import build_dataloader diff --git a/SMPLX/transfer_model/data/__pycache__/__init__.cpython-38.pyc b/SMPLX/transfer_model/data/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aecd3747cff13c8212d1b617eed21d10e7a1152c Binary files /dev/null and b/SMPLX/transfer_model/data/__pycache__/__init__.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/data/__pycache__/__init__.cpython-39.pyc b/SMPLX/transfer_model/data/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31be8bd4ffe82e38b9e0fef5ebc970b0531ce1b3 Binary files /dev/null and b/SMPLX/transfer_model/data/__pycache__/__init__.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/data/__pycache__/build.cpython-38.pyc b/SMPLX/transfer_model/data/__pycache__/build.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb2f209f1642b01d40e3590a6c6348509a9315ce Binary files /dev/null and b/SMPLX/transfer_model/data/__pycache__/build.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/data/__pycache__/build.cpython-39.pyc b/SMPLX/transfer_model/data/__pycache__/build.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1182ef6cd41b3d05ee70b031314b4ff25c6bd34f Binary files /dev/null and b/SMPLX/transfer_model/data/__pycache__/build.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/data/build.py b/SMPLX/transfer_model/data/build.py new file mode 100644 index 0000000000000000000000000000000000000000..f1efe7ffcda7afcc2ef68fa2d191655eac80019d --- /dev/null +++ b/SMPLX/transfer_model/data/build.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from typing import List, Tuple +import sys + +import torch +import torch.utils.data as dutils +from .datasets import MeshFolder + +from loguru import logger + +def build_dataloader(datasets): + mesh_folder_cfg = datasets["mesh_folder"] + key, *_ = mesh_folder_cfg.keys() + value = mesh_folder_cfg[key] + logger.info(f'{key}: {value}\n') + dataset = MeshFolder(**mesh_folder_cfg) + + batch_size = datasets["batch_size"] + num_workers = 1 + + logger.info( + f'Creating dataloader with B={batch_size}, workers={num_workers}') + dataloader = dutils.DataLoader(dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=False) + + return {'dataloader': dataloader, 'dataset': dataset} + + +# def build_dataloader(exp_cfg): +# dset_name = exp_cfg.datasets.name +# if dset_name == 'mesh-folder': +# mesh_folder_cfg = exp_cfg.datasets.mesh_folder +# key, *_ = mesh_folder_cfg.keys() +# value = mesh_folder_cfg[key] +# logger.info(f'{key}: {value}\n') +# dataset = MeshFolder(**mesh_folder_cfg) +# else: +# raise ValueError(f'Unknown dataset: {dset_name}') + +# import pdb;pdb.set_trace() +# batch_size = exp_cfg.batch_size +# num_workers = exp_cfg.datasets.num_workers + +# logger.info( +# f'Creating dataloader with B={batch_size}, workers={num_workers}') +# dataloader = dutils.DataLoader(dataset, +# batch_size=batch_size, +# num_workers=num_workers, +# shuffle=False) + +# return {'dataloader': dataloader, 'dataset': dataset} diff --git a/SMPLX/transfer_model/data/datasets/__init__.py b/SMPLX/transfer_model/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e82ef01b229281dcae08e4e2cf56e5f6d5cb73 --- /dev/null +++ b/SMPLX/transfer_model/data/datasets/__init__.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
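A minimal usage sketch for `build_dataloader` above, assuming a placeholder folder of `.obj`/`.ply` meshes that share one topology:

```python
# Illustrative usage of build_dataloader; the data folder is a placeholder and
# must contain .obj/.ply meshes that all share the same topology.
from SMPLX.transfer_model.data import build_dataloader

datasets_cfg = {
    'mesh_folder': {'data_folder': 'results/meshes'},  # kwargs for MeshFolder
    'batch_size': 1,
}
data = build_dataloader(datasets_cfg)
for batch in data['dataloader']:
    # Default collation stacks float32 vertices into (B, V, 3), plus faces,
    # indices and the source file paths.
    print(batch['vertices'].shape, batch['paths'])
    break
```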
+# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from .mesh import MeshFolder diff --git a/SMPLX/transfer_model/data/datasets/__pycache__/__init__.cpython-38.pyc b/SMPLX/transfer_model/data/datasets/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1674317097e00c430395e9fb9074a966c1a070f Binary files /dev/null and b/SMPLX/transfer_model/data/datasets/__pycache__/__init__.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/data/datasets/__pycache__/__init__.cpython-39.pyc b/SMPLX/transfer_model/data/datasets/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..235b2e16c511be534126824e25871e558ab73635 Binary files /dev/null and b/SMPLX/transfer_model/data/datasets/__pycache__/__init__.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/data/datasets/__pycache__/mesh.cpython-38.pyc b/SMPLX/transfer_model/data/datasets/__pycache__/mesh.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0f4938c1edeaf2318e979037b0fc671284d2279 Binary files /dev/null and b/SMPLX/transfer_model/data/datasets/__pycache__/mesh.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/data/datasets/__pycache__/mesh.cpython-39.pyc b/SMPLX/transfer_model/data/datasets/__pycache__/mesh.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3eb8145980503f27263096f61194d52ebc74e6e6 Binary files /dev/null and b/SMPLX/transfer_model/data/datasets/__pycache__/mesh.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/data/datasets/mesh.py b/SMPLX/transfer_model/data/datasets/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..e15d8b908cd1e9b477b658eb70cbd0adeb8226a2 --- /dev/null +++ b/SMPLX/transfer_model/data/datasets/mesh.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from typing import Optional, Tuple + +import sys +import os +import os.path as osp + +import numpy as np +from psbody.mesh import Mesh +import trimesh + +import torch +from torch.utils.data import Dataset +from loguru import logger + + +class MeshFolder(Dataset): + def __init__( + self, + data_folder: str, + transforms=None, + exts: Optional[Tuple] = None + ) -> None: + ''' Dataset similar to ImageFolder that reads meshes with the same + topology + ''' + if exts is None: + exts = ['.obj', '.ply'] + + self.data_folder = osp.expandvars(data_folder) + + logger.info( + f'Building mesh folder dataset for folder: {self.data_folder}') + + self.data_paths = np.array([ + osp.join(self.data_folder, fname) + for fname in os.listdir(self.data_folder) + if any(fname.endswith(ext) for ext in exts) + ]) + self.num_items = len(self.data_paths) + + def __len__(self) -> int: + return self.num_items + + def __getitem__(self, index): + mesh_path = self.data_paths[index] + + # Load the mesh + mesh = trimesh.load(mesh_path, process=False) + + return { + 'vertices': np.asarray(mesh.vertices, dtype=np.float32), + 'faces': np.asarray(mesh.faces, dtype=np.int32), + 'indices': index, + 'paths': mesh_path, + } diff --git a/SMPLX/transfer_model/docs/images/smpl_smplx_correspondence.png b/SMPLX/transfer_model/docs/images/smpl_smplx_correspondence.png new file mode 100644 index 0000000000000000000000000000000000000000..305580d53626911ddaa3e5c73175e05fa7fe22fc Binary files /dev/null and b/SMPLX/transfer_model/docs/images/smpl_smplx_correspondence.png differ diff --git a/SMPLX/transfer_model/docs/transfer.md b/SMPLX/transfer_model/docs/transfer.md new file mode 100644 index 0000000000000000000000000000000000000000..91947fe9038c0b793c2fc59ee0c9a99dc3b0f59d --- /dev/null +++ b/SMPLX/transfer_model/docs/transfer.md @@ -0,0 +1,115 @@ +# Converting SMPL to SMPL-X + + + + + + + + + + + + +The SMPL body model [1] is in wide use in computer vision and graphics for both +research and industrial applications. While widely used, SMPL lacks details like +articulated hands and an expressive face. The SMPL-X model [3] addresses this +and includes both the face and hands. + +Many legacy applications and datasets are built on SMPL and people want to +"upgrade" them to SMPL-X. While SMPL-X is based on the SMPL technology, they are +not completely interchangeable. + +Importantly the shape and pose parameters of SMPL and SMPL-X seem tantalizingly +similar. Sadly, you can't just take them from one model and use them with the +other. In particular, the joint locations in SMPL-X differ from those in SMPL, +meaning that the pose (theta) parameters are not interchangeable. + +Here we describe a tool to convert back and forth between the models. This +involves fitting one model to the other to recover the right parameters. + +The first step in this process is to establish a mapping between SMPL and +SMPL-X, since their topologies differ. For this, we assume we have a SMPL-X +template mesh registered to the SMPL template. Now that the two surfaces match, +we compute and store the following quantities: + +* For each SMPL-X vertex find the nearest point on the SMPL mesh and store: + * The index $t_i$ of the triangle where the nearest point is located. + * Store the barycentric coordinates of the nearest point with respect to + the SMPL triangle $\left[a_i, b_i, c_i\right]$. 
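A minimal sketch of how the stored correspondences can be applied to a posed SMPL mesh to obtain vertices in SMPL-X topology, matching the barycentric reconstruction described next; the array names are illustrative:

```python
# Illustrative sketch: rebuild vertices in SMPL-X topology from a posed SMPL
# mesh using the stored per-vertex triangle indices and barycentric weights.
import numpy as np

def smpl_to_smplx_vertices(smpl_vertices, smpl_faces, tri_idx, barycentrics):
    """smpl_vertices: (V_smpl, 3), smpl_faces: (F, 3), tri_idx: (V_smplx,),
    barycentrics: (V_smplx, 3) -> vertices in SMPL-X topology (V_smplx, 3)."""
    corner_ids = smpl_faces[tri_idx]          # (V_smplx, 3) SMPL vertex indices
    corners = smpl_vertices[corner_ids]       # (V_smplx, 3, 3) triangle corners
    return (barycentrics[..., None] * corners).sum(axis=1)
```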
+ + +SMPL-X and SMPL share the same topology up to the neck, therefore the barycentric coordinates of +these points are a permutation of `[1.0, 0.0, 0.0]`. We also store a mask of +valid vertices, to remove points that have no match between the two meshes, +such as the eyeballs or the inner mouth. If we color-code the correspondences +we end up with the following image, where the left mesh is SMPL and the right +one is SMPL-X: + +![Correspondences](./images/smpl_smplx_correspondence.png) + +Now that we have established the correspondences between the models, we can fit +SMPL-X to the SMPL annotations.
+1. The first step is to build a mesh with the SMPL-X topology from the posed + SMPL annotations. + + 1. If $t_i$ is the index of the corresponding SMPL triangle for the i-th SMPL-X + vertex, then let $f_i \in \mathbb{N}^3$ be the 3 indices of the SMPL vertices that + form the triangle. + 2. Let $m_i$ be the binary mask value for the validity of this vertex. + 3. The i-th vertex is computed using the barycentric coordinates $\left[a_i, b_i, c_i\right]$ as: + + $v_i^{SMPL-X} = a_i * v_{f_i^0}^{SMPL} + b_i * v_{f_i^1}^{SMPL} + c_i * v_{f_i^2}^{SMPL}$ +
+ 2. Now that we have a mesh in SMPL-X topology, we need to find the SMPL-X + parameters, i.e. pose $\theta$, shape $\beta$, expression $\psi$ and translation $\gamma$, that best explain it. + We use an iterative optimization scheme to + recover the parameters: + + 1. Optimize over the pose with a 3D edge term. Make sure that we only use + the valid edges, i.e. those whose two end points exist on both + meshes: + + $L_1\left(\theta\right) = \sum_{(i, j) \in \mathcal{E}} m_i m_j \left\lVert(v_i - v_j) - (\hat{v}_i - \hat{v}_j) \right\rVert_2^2$ + + 2. Optimize over the translation vector $\gamma$ to align the two models: + + $L_2\left({\gamma}\right) = \sum_{i} m_i \left\lVert v_i - \hat{v}_i \right\rVert$ + + 3. Optimize over all parameters to get the tightest possible fit: + + $L_3\left(\theta, \beta, \psi, \gamma\right) = \sum_{i} m_i \left\lVert v_i - \hat{v}_i \right\rVert_2^2$ + +
+So now, if you have data in SMPL format, you can convert it to SMPL-X. This +should allow you to use it for training. + +For the inverse mapping, from SMPL-X to +SMPL, we follow a similar process to generate the correspondences and then optimize +over the SMPL parameters that best fit the +transferred mesh. Of course, if you choose to do this, you will lose all +information about the hands and the face, since SMPL cannot model them. + +For SMPL and SMPL+H [2], the process is easier, since they share the same +topology. We can therefore skip the first step, since we already know the +correspondences, compute a SMPL or SMPL+H mesh and estimate the parameters of +the other model. If we wish to transfer SMPL+H annotations, such as the AMASS +motion capture data [4], to SMPL-X, then we can use the correspondences of the +SMPL to SMPL-X mapping. + +## Bibliography + +[1]: Loper, M., Mahmood, N., Romero, J., Pons-Moll, G., Black, M.J.: SMPL: A +skinned multi-person linear model. ACM Transactions on Graphics (TOG) - Proceedings of ACM SIGGRAPH Asia 34(6), 248:1–248:16 (2015) + +[2]: Romero, J., Tzionas, D., Black, M.J.: Embodied hands: Modeling and capturing +hands and bodies together. ACM Transactions on Graphics (TOG) - Proceedings +of ACM SIGGRAPH Asia 36(6), 245:1–245:17 (2017) + +[3]: Pavlakos, G., Choutas, V., Ghorbani, N., Bolkart, T., Osman, A.A.A., Tzionas, +D., Black, M.J.: Expressive body capture: 3D hands, face, and body from a single +image.
In: Proceedings of the IEEE Conference on Computer Vision and Pattern +Recognition (CVPR). pp. 10967–10977 (2019) + +[4]: Mahmood, N., Ghorbani, N., Troje, N.F., Pons-Moll, G., Black, M.J.: Amass: +Archive of motion capture as surface shapes. ICCV (2019) diff --git a/SMPLX/transfer_model/losses/__init__.py b/SMPLX/transfer_model/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d991ba0ddadfe1824d06d1bb336d52f9416d9a63 --- /dev/null +++ b/SMPLX/transfer_model/losses/__init__.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from .losses import * diff --git a/SMPLX/transfer_model/losses/__pycache__/__init__.cpython-38.pyc b/SMPLX/transfer_model/losses/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdea11489f352f3408ba5eb9f61517dd7ab2dc51 Binary files /dev/null and b/SMPLX/transfer_model/losses/__pycache__/__init__.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/losses/__pycache__/__init__.cpython-39.pyc b/SMPLX/transfer_model/losses/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16732db8852e06dc4a056d4ff505b76f2046eebd Binary files /dev/null and b/SMPLX/transfer_model/losses/__pycache__/__init__.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/losses/__pycache__/losses.cpython-38.pyc b/SMPLX/transfer_model/losses/__pycache__/losses.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e65bdc27c09a39237eb1d1bad5c00395f63c51a5 Binary files /dev/null and b/SMPLX/transfer_model/losses/__pycache__/losses.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/losses/__pycache__/losses.cpython-39.pyc b/SMPLX/transfer_model/losses/__pycache__/losses.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0eb777a447b8b31a1a0284d5c8ff9a8ce093f5f6 Binary files /dev/null and b/SMPLX/transfer_model/losses/__pycache__/losses.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/losses/__pycache__/utils.cpython-38.pyc b/SMPLX/transfer_model/losses/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51197ed2481fb0c481aedba884a430ec591bbb42 Binary files /dev/null and b/SMPLX/transfer_model/losses/__pycache__/utils.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/losses/__pycache__/utils.cpython-39.pyc b/SMPLX/transfer_model/losses/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..714542ff628e5222a2eb22d7ecfbc193b228791b Binary files /dev/null and b/SMPLX/transfer_model/losses/__pycache__/utils.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/losses/losses.py b/SMPLX/transfer_model/losses/losses.py new file mode 100644 index 
0000000000000000000000000000000000000000..57d96fd26889fe31cad62a8eec4e0e08f478870e --- /dev/null +++ b/SMPLX/transfer_model/losses/losses.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de +from __future__ import print_function +from __future__ import absolute_import +from __future__ import division + +import sys +import time +from typing import Callable, Iterator, Union, Optional, List + +import os.path as osp +import yaml +from loguru import logger + +import pickle + +import numpy as np + +import torch +import torch.autograd as autograd +import torch.nn as nn +import torch.nn.functional as F + +from .utils import get_reduction_method + +__all__ = [ + 'VertexEdgeLoss', + 'build_loss', +] + + +def build_loss(type='l2', reduction='mean', **kwargs) -> nn.Module: + logger.debug(f'Building loss: {type}') + if type == 'l2': + return WeightedMSELoss(reduction=reduction, **kwargs) + elif type == 'vertex-edge': + return VertexEdgeLoss(reduction=reduction, **kwargs) + elif type == 'l1': + return nn.L1Loss() + else: + raise ValueError(f'Unknown loss type: {type}') + + +class WeightedMSELoss(nn.Module): + def __init__(self, reduction='mean', **kwargs): + super(WeightedMSELoss, self).__init__() + self.reduce_str = reduction + self.reduce = get_reduction_method(reduction) + + def forward(self, input, target, weights=None): + diff = input - target + if weights is None: + return diff.pow(2).sum() / diff.shape[0] + else: + return ( + weights.unsqueeze(dim=-1) * diff.pow(2)).sum() / diff.shape[0] + + +class VertexEdgeLoss(nn.Module): + def __init__(self, norm_type='l2', + gt_edges=None, + gt_edge_path='', + est_edges=None, + est_edge_path='', + robustifier=None, + edge_thresh=0.0, epsilon=1e-8, + reduction='sum', + **kwargs): + super(VertexEdgeLoss, self).__init__() + + assert norm_type in ['l1', 'l2'], 'Norm type must be [l1, l2]' + self.norm_type = norm_type + self.epsilon = epsilon + self.reduction = reduction + assert self.reduction in ['sum', 'mean'] + logger.info(f'Building edge loss with' + f' norm_type={norm_type},' + f' reduction={reduction},' + ) + + gt_edge_path = osp.expandvars(gt_edge_path) + est_edge_path = osp.expandvars(est_edge_path) + assert osp.exists(gt_edge_path) or gt_edges is not None, ( + 'gt_edges must not be None or gt_edge_path must exist' + ) + assert osp.exists(est_edge_path) or est_edges is not None, ( + 'est_edges must not be None or est_edge_path must exist' + ) + if osp.exists(gt_edge_path) and gt_edges is None: + gt_edges = np.load(gt_edge_path) + if osp.exists(est_edge_path) and est_edges is None: + est_edges = np.load(est_edge_path) + + self.register_buffer( + 'gt_connections', torch.tensor(gt_edges, dtype=torch.long)) + self.register_buffer( + 'est_connections', torch.tensor(est_edges, dtype=torch.long)) + + def extra_repr(self): + msg = [ + f'Norm type: {self.norm_type}', + ] + if 
self.has_connections: + msg.append( + f'GT Connections shape: {self.gt_connections.shape}' + ) + msg.append( + f'Est Connections shape: {self.est_connections.shape}' + ) + return '\n'.join(msg) + + def compute_edges(self, points, connections): + edge_points = torch.index_select( + points, 1, connections.view(-1)).reshape(points.shape[0], -1, 2, 3) + return edge_points[:, :, 1] - edge_points[:, :, 0] + + def forward(self, gt_vertices, est_vertices, weights=None): + gt_edges = self.compute_edges( + gt_vertices, connections=self.gt_connections) + est_edges = self.compute_edges( + est_vertices, connections=self.est_connections) + + raw_edge_diff = (gt_edges - est_edges) + + batch_size = gt_vertices.shape[0] + if self.norm_type == 'l2': + edge_diff = raw_edge_diff.pow(2) + elif self.norm_type == 'l1': + edge_diff = raw_edge_diff.abs() + else: + raise NotImplementedError( + f'Loss type not implemented: {self.loss_type}') + if self.reduction == 'sum': + return edge_diff.sum() + elif self.reduction == 'mean': + return edge_diff.sum() / batch_size diff --git a/SMPLX/transfer_model/losses/utils.py b/SMPLX/transfer_model/losses/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..99d9feeaba888130ae22494bf47e450560eac584 --- /dev/null +++ b/SMPLX/transfer_model/losses/utils.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
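A minimal usage sketch for the edge loss above; the edge index arrays and vertex tensors are tiny placeholders (in the transfer pipeline the edges come from the mesh connectivity):

```python
# Illustrative usage of the edge loss; edge index arrays and vertex tensors
# are tiny placeholders.
import numpy as np
import torch
from SMPLX.transfer_model.losses import build_loss

gt_edges = np.array([[0, 1], [1, 2]])        # (E, 2) vertex-index pairs
est_edges = np.array([[0, 1], [1, 2]])
edge_loss = build_loss(type='vertex-edge', norm_type='l2', reduction='mean',
                       gt_edges=gt_edges, est_edges=est_edges)

gt_vertices = torch.rand(1, 3, 3)            # (batch, num_vertices, 3)
est_vertices = torch.rand(1, 3, 3)
print(edge_loss(gt_vertices, est_vertices))  # scalar edge-difference term
```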
+# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +import torch + +def get_reduction_method(reduction='mean'): + if reduction == 'mean': + return torch.mean + elif reduction == 'sum': + return torch.sum + elif reduction == 'none': + return lambda x: x + else: + raise ValueError('Unknown reduction method: {}'.format(reduction)) diff --git a/SMPLX/transfer_model/merge_output.py b/SMPLX/transfer_model/merge_output.py new file mode 100644 index 0000000000000000000000000000000000000000..8d9dcba08df9fcc99b89da1ea37118ff10cb1c74 --- /dev/null +++ b/SMPLX/transfer_model/merge_output.py @@ -0,0 +1,83 @@ +# merges the output of the main transfer_model script + +import torch +from pathlib import Path +import pickle +from scipy.spatial.transform import Rotation as R +import numpy as np +KEYS = [ +"transl", +"betas", +"full_pose", +] + + +def aggregate_rotmats(x): + x = np.concatenate(x, axis=0) + s = x.shape[:-2] + try: + x = R.from_matrix(x.reshape(-1, 3, 3)).as_rotvec() + except: + pass + x = x.reshape(s[0], -1) + return x + +aggregate_function = {k: lambda x: np.concatenate(x, axis=0) for k in KEYS} +aggregate_function["betas"] = lambda x: np.concatenate(x, axis=0).mean(0) + +for k in ["global_orient", "body_pose", "left_hand_pose", "right_hand_pose", "jaw_pose", "full_pose"]: + aggregate_function[k] = aggregate_rotmats + +def merge(output_dir, gender): + output_dir = Path(output_dir) + assert output_dir.exists() + assert output_dir.is_dir() + + # get list of all pkl files in output_dir with fixed length numeral names + pkl_files = [f for f in output_dir.glob("*.pkl") if f.stem != "merged"] + pkl_files = [f for f in sorted(pkl_files, key=lambda x: int(x.stem))] + assert "merged.pkl" not in [f.name for f in pkl_files] + + merged = {} + # iterate over keys and put all values in lists + keys = set(KEYS) + for k in keys: + merged[k] = [] + for pkl_file in pkl_files: + with open(pkl_file, "rb") as f: + data = pickle.load(f) + for k in keys: + if k in data: + merged[k].append(data[k]) + b = np.concatenate(merged["betas"], axis=0) + print("betas:") + for mu, sigma in zip(b.mean(0), b.std(0)): + print(" {:.3f} +/- {:.3f}".format(mu, sigma)) + + # aggregate all values + for k in keys: + merged[k] = aggregate_function[k](merged[k]) + + # add gender + + poses = merged["full_pose"] + trans = merged["transl"] + if gender == "female": + gender = np.zeros([poses.shape[0], 1]) + elif gender == "male": + gender = np.ones([poses.shape[0], 1]) + else: + gender = np.ones([poses.shape[0], 1]) * 2 + + merged = np.concatenate([poses, trans, gender], axis=1) + + return merged + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Merge output of transfer_model script') + parser.add_argument('output_dir', type=str, help='output directory of transfer_model script') + parser.add_argument('--gender', type=str, choices=['male', 'female', 'neutral'], help='gender of actor in motion sequence') + args = parser.parse_args() + merge(args.output_dir, args.gender) diff --git a/SMPLX/transfer_model/optimizers/__init__.py b/SMPLX/transfer_model/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6adf3793bf0c313677506966ef400b29b1da2c44 --- /dev/null +++ b/SMPLX/transfer_model/optimizers/__init__.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. 
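A minimal sketch of unpacking the array returned by `merge` above; the output directory is a placeholder and the column split mirrors the final `np.concatenate`:

```python
# Illustrative only: column layout of the array returned by merge(), following
# the final np.concatenate([poses, trans, gender], axis=1). The directory is a
# placeholder.
from SMPLX.transfer_model.merge_output import merge

merged = merge('results/transfer_output', gender='neutral')
poses, transl, gender_col = merged[:, :-4], merged[:, -4:-1], merged[:, -1:]
num_joints = poses.shape[1] // 3   # flattened axis-angle, 3 values per joint
print(poses.shape, transl.shape, gender_col.shape, num_joints)
```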
+# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from .optim_factory import build_optimizer +from .minimize import minimize diff --git a/SMPLX/transfer_model/optimizers/__pycache__/__init__.cpython-38.pyc b/SMPLX/transfer_model/optimizers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c3c55395f2922567aab5392da95975d558d5d64 Binary files /dev/null and b/SMPLX/transfer_model/optimizers/__pycache__/__init__.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/optimizers/__pycache__/__init__.cpython-39.pyc b/SMPLX/transfer_model/optimizers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5311e260f606aff0ba209ea5c3124eaba916c02b Binary files /dev/null and b/SMPLX/transfer_model/optimizers/__pycache__/__init__.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/optimizers/__pycache__/minimize.cpython-38.pyc b/SMPLX/transfer_model/optimizers/__pycache__/minimize.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6bca393aaa9709de0dfe0534c89a1a1971c6d79 Binary files /dev/null and b/SMPLX/transfer_model/optimizers/__pycache__/minimize.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/optimizers/__pycache__/minimize.cpython-39.pyc b/SMPLX/transfer_model/optimizers/__pycache__/minimize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a24061dcc61fa062ef021cb6e0dcbe35797049ef Binary files /dev/null and b/SMPLX/transfer_model/optimizers/__pycache__/minimize.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/optimizers/__pycache__/optim_factory.cpython-38.pyc b/SMPLX/transfer_model/optimizers/__pycache__/optim_factory.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd43dc15026ef4dc3ff011953027f288222803bd Binary files /dev/null and b/SMPLX/transfer_model/optimizers/__pycache__/optim_factory.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/optimizers/__pycache__/optim_factory.cpython-39.pyc b/SMPLX/transfer_model/optimizers/__pycache__/optim_factory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..667f11ca8d829e1ec5d7460a87972b147f1e13cc Binary files /dev/null and b/SMPLX/transfer_model/optimizers/__pycache__/optim_factory.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/optimizers/minimize.py b/SMPLX/transfer_model/optimizers/minimize.py new file mode 100644 index 0000000000000000000000000000000000000000..a37b1969dd8c7c01b69df59e7122f254d3544fc8 --- /dev/null +++ b/SMPLX/transfer_model/optimizers/minimize.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. 
+# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from typing import List, Union, Callable, Optional, Dict +import torch +from loguru import logger +from tqdm import tqdm + +from SMPLX.transfer_model.utils import ( + from_torch, Tensor, Array, rel_change) + + +def minimize( + optimizer: torch.optim, + closure, + params: List[Tensor], + summary_closure: Optional[Callable[[], Dict[str, float]]] = None, + maxiters=100, + ftol=-1.0, + gtol=1e-9, + interactive=True, + summary_steps=10, + **kwargs +): + ''' Helper function for running an optimization process + Args: + - optimizer: The PyTorch optimizer object + - closure: The function used to calculate the gradients + - params: a list containing the parameters that will be optimized + Keyword arguments: + - maxiters (100): The maximum number of iterations for the + optimizer + - ftol: The tolerance for the relative change in the loss + function. + If it is lower than this value, then the process stops + - gtol: The tolerance for the maximum change in the gradient. + If the maximum absolute values of the all gradient tensors + are less than this, then the process will stop. + ''' + prev_loss = None + for n in tqdm(range(maxiters), desc='Fitting iterations'): + loss = optimizer.step(closure) + + if n > 0 and prev_loss is not None and ftol > 0: + loss_rel_change = rel_change(prev_loss, loss.item()) + + if loss_rel_change <= ftol: + prev_loss = loss.item() + break + + if (all([var.grad.view(-1).abs().max().item() < gtol + for var in params if var.grad is not None]) and gtol > 0): + prev_loss = loss.item() + break + + if interactive and n % summary_steps == 0: + logger.info(f'[{n:05d}] Loss: {loss.item():.4f}') + if summary_closure is not None: + summaries = summary_closure() + for key, val in summaries.items(): + logger.info(f'[{n:05d}] {key}: {val:.4f}') + + prev_loss = loss.item() + + # Save the final step + if interactive: + logger.info(f'[{n + 1:05d}] Loss: {loss.item():.4f}') + if summary_closure is not None: + summaries = summary_closure() + for key, val in summaries.items(): + logger.info(f'[{n + 1:05d}] {key}: {val:.4f}') + + return prev_loss diff --git a/SMPLX/transfer_model/optimizers/optim_factory.py b/SMPLX/transfer_model/optimizers/optim_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..43b161e776df9bd1303d59ee408bb20cdf5721a0 --- /dev/null +++ b/SMPLX/transfer_model/optimizers/optim_factory.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +import sys + +from typing import NewType, List, Dict + +import torch +import torch.optim as optim +from loguru import logger +from torchtrustncg import TrustRegion + +Tensor = NewType('Tensor', torch.Tensor) + + +def build_optimizer(parameters: List[Tensor], + optim_cfg: Dict + ) -> Dict: + ''' Creates the optimizer + ''' + optim_type = optim_cfg.get('type', 'sgd') + logger.info(f'Building: {optim_type.title()}') + + num_params = len(parameters) + parameters = list(filter(lambda x: x.requires_grad, parameters)) + if num_params != len(parameters): + logger.info(f'Some parameters have requires_grad off') + + if optim_type == 'adam': + optimizer = optim.Adam(parameters, **optim_cfg.get('adam', {})) + create_graph = False + elif optim_type == 'lbfgs' or optim_type == 'lbfgsls': + optimizer = optim.LBFGS(parameters, **optim_cfg.get('lbfgs', {})) + create_graph = False + elif optim_type == 'trust_ncg' or optim_type == 'trust-ncg': + optimizer = TrustRegion( + parameters, **optim_cfg.get('trust_ncg', {})) + create_graph = True + elif optim_type == 'rmsprop': + optimizer = optim.RMSprop(parameters, **optim_cfg.get('rmsprop', {})) + create_graph = False + elif optim_type == 'sgd': + optimizer = optim.SGD(parameters, **optim_cfg.get('sgd', {})) + create_graph = False + else: + raise ValueError(f'Optimizer {optim_type} not supported!') + return {'optimizer': optimizer, 'create_graph': create_graph} + + +def build_scheduler(optimizer, sched_type='exp', + lr_lambda=0.1, **kwargs): + if lr_lambda <= 0.0: + return None + + if sched_type == 'exp': + return optim.lr_scheduler.ExponentialLR(optimizer, lr_lambda) + else: + raise ValueError('Unknown learning rate' + + ' scheduler: '.format(sched_type)) diff --git a/SMPLX/transfer_model/requirements.txt b/SMPLX/transfer_model/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..068b6580684d8102333a971b6f7032abc5766b73 --- /dev/null +++ b/SMPLX/transfer_model/requirements.txt @@ -0,0 +1,10 @@ +numpy>=1.16.2 +torch>=1.0.1.post2 +dataclasses>=0.6 +pyrender>=0.1.23 +shapely +trimesh>=2.37.6 +open3d +smplx +omegaconf +loguru diff --git a/SMPLX/transfer_model/transfer_model.py b/SMPLX/transfer_model/transfer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ce88da1f764f4de9a37b07cfd34fe9125e68af87 --- /dev/null +++ b/SMPLX/transfer_model/transfer_model.py @@ -0,0 +1,416 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
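Taken together, `build_optimizer` and `minimize` above can drive a small fitting problem. A minimal sketch on a toy quadratic objective, assuming the `torchtrustncg` dependency imported by `optim_factory` is installed:

```python
# Illustrative sketch: drive `minimize` with `build_optimizer` on a toy
# quadratic objective (assumes torchtrustncg, imported by optim_factory,
# is installed).
import torch
from SMPLX.transfer_model.optimizers import build_optimizer, minimize

x = torch.zeros(3, requires_grad=True)
target = torch.tensor([1.0, 2.0, 3.0])

optim_dict = build_optimizer([x], {'type': 'lbfgs'})
optimizer = optim_dict['optimizer']

def closure(backward=True):
    if backward:
        optimizer.zero_grad()
    loss = (x - target).pow(2).sum()
    if backward:
        loss.backward()
    return loss

minimize(optimizer, closure, params=[x], maxiters=20, ftol=1e-9, interactive=False)
print(x.detach())  # should be close to `target`
```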
+# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from typing import Optional, Dict, Callable +import sys +import numpy as np +import torch +import torch.nn as nn + +from tqdm import tqdm + +from loguru import logger +from SMPLX.transfer_model.utils import get_vertices_per_edge + +from SMPLX.transfer_model.optimizers import build_optimizer, minimize +from SMPLX.transfer_model.utils import ( + Tensor, batch_rodrigues, apply_deformation_transfer) +from SMPLX.transfer_model.losses import build_loss + + +def summary_closure(gt_vertices, var_dict, body_model, mask_ids=None): + param_dict = {} + for key, var in var_dict.items(): + # Decode the axis-angles + if 'pose' in key or 'orient' in key: + param_dict[key] = batch_rodrigues( + var.reshape(-1, 3)).reshape(len(var), -1, 3, 3) + else: + # Simply pass the variable + param_dict[key] = var + body_model_output = body_model( + return_full_pose=True, get_skin=True, **param_dict) + est_vertices = body_model_output.vertices + if mask_ids is not None: + est_vertices = est_vertices[:, mask_ids] + gt_vertices = gt_vertices[:, mask_ids] + + v2v = (est_vertices - gt_vertices).pow(2).sum(dim=-1).sqrt().mean() + return { + 'Vertex-to-Vertex': v2v * 1000} + + +def build_model_forward_closure( + body_model: nn.Module, + var_dict: Dict[str, Tensor], + per_part: bool = True, + part_key: Optional[str] = None, + jidx: Optional[int] = None, + part: Optional[Tensor] = None +) -> Callable: + if per_part: + cond = part is not None and part_key is not None and jidx is not None + assert cond, ( + 'When per-part is True, "part", "part_key", "jidx" must not be' + ' None.' + ) + + def model_forward(): + param_dict = {} + for key, var in var_dict.items(): + if part_key == key: + param_dict[key] = batch_rodrigues( + var.reshape(-1, 3)).reshape(len(var), -1, 3, 3) + param_dict[key][:, jidx] = batch_rodrigues( + part.reshape(-1, 3)).reshape(-1, 3, 3) + else: + # Decode the axis-angles + if 'pose' in key or 'orient' in key: + param_dict[key] = batch_rodrigues( + var.reshape(-1, 3)).reshape(len(var), -1, 3, 3) + else: + # Simply pass the variable + param_dict[key] = var + + return body_model( + return_full_pose=True, get_skin=True, **param_dict) + else: + def model_forward(): + param_dict = {} + for key, var in var_dict.items(): + # Decode the axis-angles + if 'pose' in key or 'orient' in key: + param_dict[key] = batch_rodrigues( + var.reshape(-1, 3)).reshape(len(var), -1, 3, 3) + else: + # Simply pass the variable + param_dict[key] = var + + return body_model(return_full_pose=True, get_skin=True, + **param_dict) + return model_forward + + +def build_edge_closure( + body_model: nn.Module, + var_dict: Dict[str, Tensor], + edge_loss: nn.Module, + optimizer_dict, + gt_vertices: Tensor, + per_part: bool = True, + part_key: Optional[str] = None, + jidx: Optional[int] = None, + part: Optional[Tensor] = None +) -> Callable: + ''' Builds the closure for the edge objective + ''' + optimizer = optimizer_dict['optimizer'] + create_graph = optimizer_dict['create_graph'] + + if per_part: + params_to_opt = [part] + else: + params_to_opt = [p for key, p in var_dict.items() if 'pose' in key] + + model_forward = build_model_forward_closure( + body_model, var_dict, per_part=per_part, part_key=part_key, + jidx=jidx, part=part) + + def closure(backward=True): + if backward: + optimizer.zero_grad() + + body_model_output = model_forward() + + est_vertices = body_model_output.vertices + + loss = edge_loss(est_vertices, gt_vertices) + if backward: + if create_graph: + # Use this 
instead of .backward to avoid GPU memory leaks + grads = torch.autograd.grad( + loss, params_to_opt, create_graph=True) + torch.autograd.backward( + params_to_opt, grads, create_graph=True) + else: + loss.backward() + + return loss + return closure + + +def build_vertex_closure( + body_model: nn.Module, + var_dict: Dict[str, Tensor], + optimizer_dict, + gt_vertices: Tensor, + vertex_loss: nn.Module, + mask_ids=None, + per_part: bool = True, + part_key: Optional[str] = None, + jidx: Optional[int] = None, + part: Optional[Tensor] = None, + params_to_opt: Optional[Tensor] = None, +) -> Callable: + ''' Builds the closure for the vertex objective + ''' + optimizer = optimizer_dict['optimizer'] + create_graph = optimizer_dict['create_graph'] + + model_forward = build_model_forward_closure( + body_model, var_dict, per_part=per_part, part_key=part_key, + jidx=jidx, part=part) + + if params_to_opt is None: + params_to_opt = [p for key, p in var_dict.items()] + + def closure(backward=True): + if backward: + optimizer.zero_grad() + + body_model_output = model_forward() + est_vertices = body_model_output.vertices + + loss = vertex_loss( + est_vertices[:, mask_ids] if mask_ids is not None else + est_vertices, + gt_vertices[:, mask_ids] if mask_ids is not None else gt_vertices) + if backward: + if create_graph: + # Use this instead of .backward to avoid GPU memory leaks + grads = torch.autograd.grad( + loss, params_to_opt, create_graph=True) + torch.autograd.backward( + params_to_opt, grads, create_graph=True) + else: + loss.backward() + + return loss + return closure + + +def get_variables( + batch_size: int, + body_model: nn.Module, + dtype: torch.dtype = torch.float32 +) -> Dict[str, Tensor]: + var_dict = {} + + device = next(body_model.buffers()).device + + if (body_model.name() == 'SMPL' or body_model.name() == 'SMPL+H' or + body_model.name() == 'SMPL-X'): + var_dict.update({ + 'transl': torch.zeros( + [batch_size, 3], device=device, dtype=dtype), + 'global_orient': torch.zeros( + [batch_size, 1, 3], device=device, dtype=dtype), + 'body_pose': torch.zeros( + [batch_size, body_model.NUM_BODY_JOINTS, 3], + device=device, dtype=dtype), + 'betas': torch.zeros([batch_size, body_model.num_betas], + dtype=dtype, device=device), + }) + + if body_model.name() == 'SMPL+H' or body_model.name() == 'SMPL-X': + var_dict.update( + left_hand_pose=torch.zeros( + [batch_size, body_model.NUM_HAND_JOINTS, 3], device=device, + dtype=dtype), + right_hand_pose=torch.zeros( + [batch_size, body_model.NUM_HAND_JOINTS, 3], device=device, + dtype=dtype), + ) + + if body_model.name() == 'SMPL-X': + var_dict.update( + jaw_pose=torch.zeros([batch_size, 1, 3], + device=device, dtype=dtype), + leye_pose=torch.zeros([batch_size, 1, 3], + device=device, dtype=dtype), + reye_pose=torch.zeros([batch_size, 1, 3], + device=device, dtype=dtype), + expression=torch.zeros( + [batch_size, body_model.num_expression_coeffs], + device=device, dtype=dtype), + ) + + # Toggle gradients to True + for key, val in var_dict.items(): + val.requires_grad_(True) + + return var_dict + + +def run_fitting( + # exp_cfg, + batch: Dict[str, Tensor], + body_model: nn.Module, + def_matrix: Tensor, + mask_ids +) -> Dict[str, Tensor]: + ''' Runs fitting + ''' + vertices = batch['vertices'] + faces = batch['faces'] + + batch_size = len(vertices) + dtype, device = vertices.dtype, vertices.device + # summary_steps = exp_cfg.get('summary_steps') + # interactive = exp_cfg.get('interactive') + + summary_steps = 100 + interactive = True + + # Get the parameters from the 
model + var_dict = get_variables(batch_size, body_model) + + # Build the optimizer object for the current batch + # optim_cfg = exp_cfg.get('optim', {}) + + optim_cfg = {'type': 'trust-ncg', 'lr': 1.0, 'gtol': 1e-06, 'ftol': -1.0, 'maxiters': 100, 'lbfgs': {'line_search_fn': 'strong_wolfe', 'max_iter': 50}, 'sgd': {'momentum': 0.9, 'nesterov': True}, 'adam': {'betas': [0.9, 0.999], 'eps': 1e-08, 'amsgrad': False}, 'trust_ncg': {'max_trust_radius': 1000.0, 'initial_trust_radius': 0.05, 'eta': 0.15, 'gtol': 1e-05}} + + def_vertices = apply_deformation_transfer(def_matrix, vertices, faces) + + if mask_ids is None: + f_sel = np.ones_like(body_model.faces[:, 0], dtype=np.bool_) + else: + f_per_v = [[] for _ in range(body_model.get_num_verts())] + [f_per_v[vv].append(iff) for iff, ff in enumerate(body_model.faces) + for vv in ff] + f_sel = list(set(tuple(sum([f_per_v[vv] for vv in mask_ids], [])))) + vpe = get_vertices_per_edge( + body_model.v_template.detach().cpu().numpy(), body_model.faces[f_sel]) + + def log_closure(): + return summary_closure(def_vertices, var_dict, body_model, + mask_ids=mask_ids) + + # edge_fitting_cfg = exp_cfg.get('edge_fitting', {}) + edge_fitting_cfg = {'per_part': False, 'reduction': 'mean'} + + edge_loss = build_loss(type='vertex-edge', gt_edges=vpe, est_edges=vpe, + **edge_fitting_cfg) + edge_loss = edge_loss.to(device=device) + + # vertex_fitting_cfg = exp_cfg.get('vertex_fitting', {}) + vertex_fitting_cfg = {} + + vertex_loss = build_loss(**vertex_fitting_cfg) + vertex_loss = vertex_loss.to(device=device) + + per_part = edge_fitting_cfg.get('per_part', True) + logger.info(f'Per-part: {per_part}') + # Optimize edge-based loss to initialize pose + if per_part: + for key, var in tqdm(var_dict.items(), desc='Parts'): + if 'pose' not in key: + continue + + for jidx in tqdm(range(var.shape[1]), desc='Joints'): + part = torch.zeros( + [batch_size, 3], dtype=dtype, device=device, + requires_grad=True) + # Build the optimizer for the current part + optimizer_dict = build_optimizer([part], optim_cfg) + closure = build_edge_closure( + body_model, var_dict, edge_loss, optimizer_dict, + def_vertices, per_part=per_part, part_key=key, jidx=jidx, + part=part) + + minimize(optimizer_dict['optimizer'], closure, + params=[part], + summary_closure=log_closure, + summary_steps=summary_steps, + interactive=interactive, + **optim_cfg) + with torch.no_grad(): + var[:, jidx] = part + else: + optimizer_dict = build_optimizer(list(var_dict.values()), optim_cfg) + closure = build_edge_closure( + body_model, var_dict, edge_loss, optimizer_dict, + def_vertices, per_part=per_part) + + minimize(optimizer_dict['optimizer'], closure, + params=var_dict.values(), + summary_closure=log_closure, + summary_steps=summary_steps, + interactive=interactive, + **optim_cfg) + + if 'translation' in var_dict: + optimizer_dict = build_optimizer([var_dict['translation']], optim_cfg) + closure = build_vertex_closure( + body_model, var_dict, + optimizer_dict, + def_vertices, + vertex_loss=vertex_loss, + mask_ids=mask_ids, + per_part=False, + params_to_opt=[var_dict['translation']], + ) + # Optimize translation + minimize(optimizer_dict['optimizer'], + closure, + params=[var_dict['translation']], + summary_closure=log_closure, + summary_steps=summary_steps, + interactive=interactive, + **optim_cfg) + + # Optimize all model parameters with vertex-based loss + optimizer_dict = build_optimizer(list(var_dict.values()), optim_cfg) + closure = build_vertex_closure( + body_model, var_dict, + optimizer_dict, + 
def_vertices, + vertex_loss=vertex_loss, + per_part=False, + mask_ids=mask_ids) + minimize(optimizer_dict['optimizer'], closure, + params=list(var_dict.values()), + summary_closure=log_closure, + summary_steps=summary_steps, + interactive=interactive, + **optim_cfg) + + param_dict = {} + for key, var in var_dict.items(): + # Decode the axis-angles + if 'pose' in key or 'orient' in key: + param_dict[key] = batch_rodrigues( + var.reshape(-1, 3)).reshape(len(var), -1, 3, 3) + else: + # Simply pass the variable + param_dict[key] = var + + body_model_output = body_model( + return_full_pose=True, get_skin=True, **param_dict) + + keys = ["vertices", "joints", "betas", "global_orient", "body_pose", "left_hand_pose", "right_hand_pose", "full_pose"] + for key in keys: + var_dict[key] = getattr(body_model_output, key) + + var_dict['faces'] = body_model.faces + + for key in var_dict.keys(): + try: + var_dict[key] = var_dict[key].detach().cpu().numpy() + except: + pass + + return var_dict diff --git a/SMPLX/transfer_model/utils/__init__.py b/SMPLX/transfer_model/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..78c1bfe456318eb1b1b7e8c1cbd3c7c566e58751 --- /dev/null +++ b/SMPLX/transfer_model/utils/__init__.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from .np_utils import to_np, rel_change +from .torch_utils import from_torch +from .timer import Timer, timer_decorator +from .typing import * +from .pose_utils import batch_rodrigues, batch_rot2aa +from .metrics import v2v +from .def_transfer import read_deformation_transfer, apply_deformation_transfer +from .mesh_utils import get_vertices_per_edge +from .o3d_utils import np_mesh_to_o3d diff --git a/SMPLX/transfer_model/utils/__pycache__/__init__.cpython-38.pyc b/SMPLX/transfer_model/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6574593688be9adbfa3fad872ee9622ee42d839 Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/__init__.cpython-39.pyc b/SMPLX/transfer_model/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e06b89a583911642618e57bab0506bb226a10788 Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/def_transfer.cpython-38.pyc b/SMPLX/transfer_model/utils/__pycache__/def_transfer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5269eade942b9abc038e37c97439922d9842e812 Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/def_transfer.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/def_transfer.cpython-39.pyc b/SMPLX/transfer_model/utils/__pycache__/def_transfer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b93e78c5f93d8650ba5fa28d8954eca8e49fb2c4 Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/def_transfer.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/mesh_utils.cpython-38.pyc b/SMPLX/transfer_model/utils/__pycache__/mesh_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06cac25a81f928ec29f90a9e4b93de158db39186 Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/mesh_utils.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/mesh_utils.cpython-39.pyc b/SMPLX/transfer_model/utils/__pycache__/mesh_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2414b483df5d20f079c24d1488a84ca92cdf1817 Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/mesh_utils.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/metrics.cpython-38.pyc b/SMPLX/transfer_model/utils/__pycache__/metrics.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6271001434c239d5bee1c8d93b09c5b34be391b2 Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/metrics.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/metrics.cpython-39.pyc b/SMPLX/transfer_model/utils/__pycache__/metrics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff83a619bf87f4db3448fcff2fa0ebc0aa2af98b Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/metrics.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/np_utils.cpython-38.pyc b/SMPLX/transfer_model/utils/__pycache__/np_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b78627a99a1347627c6dc4e86b97579aa45c902c Binary files 
/dev/null and b/SMPLX/transfer_model/utils/__pycache__/np_utils.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/np_utils.cpython-39.pyc b/SMPLX/transfer_model/utils/__pycache__/np_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e6fc925a0d362c8320e3e0336d59b4380cfb5ff Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/np_utils.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/o3d_utils.cpython-38.pyc b/SMPLX/transfer_model/utils/__pycache__/o3d_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f066f4da9f1fab32bb089e7921cc165b1fe8e8e Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/o3d_utils.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/o3d_utils.cpython-39.pyc b/SMPLX/transfer_model/utils/__pycache__/o3d_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afbabaf1af5da13b856c12ee732a5ba3c982ab6d Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/o3d_utils.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/pose_utils.cpython-38.pyc b/SMPLX/transfer_model/utils/__pycache__/pose_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a06cce55aff5d1f4c62977e411022e603dcc713c Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/pose_utils.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/pose_utils.cpython-39.pyc b/SMPLX/transfer_model/utils/__pycache__/pose_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e654da8b2962cda8e5256683e180f7926ca07cdc Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/pose_utils.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/timer.cpython-38.pyc b/SMPLX/transfer_model/utils/__pycache__/timer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb2e89a72148e15b08152c7606fcee303f947a35 Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/timer.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/timer.cpython-39.pyc b/SMPLX/transfer_model/utils/__pycache__/timer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c18cb5e6274128d8535d78a8119fb4b0e31fe674 Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/timer.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/torch_utils.cpython-38.pyc b/SMPLX/transfer_model/utils/__pycache__/torch_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a90854b86d794288be2d38ecfa82bcfaa90396c3 Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/torch_utils.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/torch_utils.cpython-39.pyc b/SMPLX/transfer_model/utils/__pycache__/torch_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3963b5956d954de8496d2f9032c65ae658330dea Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/torch_utils.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/typing.cpython-38.pyc b/SMPLX/transfer_model/utils/__pycache__/typing.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae0d06d2b449e8d8aad41a9ac069e5d5a3082ed5 Binary files /dev/null and 
b/SMPLX/transfer_model/utils/__pycache__/typing.cpython-38.pyc differ diff --git a/SMPLX/transfer_model/utils/__pycache__/typing.cpython-39.pyc b/SMPLX/transfer_model/utils/__pycache__/typing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cd6bf742882b2b8c096db0702ade404084863ec Binary files /dev/null and b/SMPLX/transfer_model/utils/__pycache__/typing.cpython-39.pyc differ diff --git a/SMPLX/transfer_model/utils/def_transfer.py b/SMPLX/transfer_model/utils/def_transfer.py new file mode 100644 index 0000000000000000000000000000000000000000..64a7462b1535901dda30f3dc35caed6295310507 --- /dev/null +++ b/SMPLX/transfer_model/utils/def_transfer.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +import os +import os.path as osp +import pickle + +import numpy as np +import torch +from loguru import logger + +from SMPLX.transfer_model.utils.typing import Tensor + + +def read_deformation_transfer( + deformation_transfer_path: str, + device=None, + use_normal: bool = False, +) -> Tensor: + ''' Reads a deformation transfer + ''' + if device is None: + device = torch.device('cpu') + assert osp.exists(deformation_transfer_path), ( + 'Deformation transfer path does not exist:' + f' {deformation_transfer_path}') + logger.info( + f'Loading deformation transfer from: {deformation_transfer_path}') + # Read the deformation transfer matrix + with open(deformation_transfer_path, 'rb') as f: + def_transfer_setup = pickle.load(f, encoding='latin1') + if 'mtx' in def_transfer_setup: + def_matrix = def_transfer_setup['mtx'] + if hasattr(def_matrix, 'todense'): + def_matrix = def_matrix.todense() + def_matrix = np.array(def_matrix, dtype=np.float32) + if not use_normal: + num_verts = def_matrix.shape[1] // 2 + def_matrix = def_matrix[:, :num_verts] + elif 'matrix' in def_transfer_setup: + def_matrix = def_transfer_setup['matrix'] + else: + valid_keys = ['mtx', 'matrix'] + raise KeyError(f'Deformation transfer setup must contain {valid_keys}') + + def_matrix = torch.tensor(def_matrix, device=device, dtype=torch.float32) + return def_matrix + + +def apply_deformation_transfer( + def_matrix: Tensor, + vertices: Tensor, + faces: Tensor, + use_normals=False +) -> Tensor: + ''' Applies the deformation transfer on the given meshes + ''' + if use_normals: + raise NotImplementedError + else: + def_vertices = torch.einsum('mn,bni->bmi', [def_matrix, vertices]) + return def_vertices diff --git a/SMPLX/transfer_model/utils/mesh_utils.py b/SMPLX/transfer_model/utils/mesh_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4d681b2f2ea235d01e1df0db64618d6f18683989 --- /dev/null +++ b/SMPLX/transfer_model/utils/mesh_utils.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. 
(MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Code from Chumpy and OpenDR. Placed here to avoid chumpy dependency +# The original code can be found in https://github.com/MPI-IS/mesh +import numpy as np +import scipy.sparse as sp + + +def row(A): + return A.reshape((1, -1)) + + +def col(A): + return A.reshape((-1, 1)) + + +def get_vert_connectivity(mesh_v, mesh_f): + """Returns a sparse matrix (of size #verts x #verts) where each nonzero + element indicates a neighborhood relation. For example, if there is a + nonzero element in position (15,12), that means vertex 15 is connected + by an edge to vertex 12.""" + + vpv = sp.csc_matrix((len(mesh_v), len(mesh_v))) + + # for each column in the faces... + for i in range(3): + IS = mesh_f[:, i] + JS = mesh_f[:, (i + 1) % 3] + data = np.ones(len(IS)) + ij = np.vstack((row(IS.flatten()), row(JS.flatten()))) + mtx = sp.csc_matrix((data, ij), shape=vpv.shape) + vpv = vpv + mtx + mtx.T + + return vpv + + +def get_vertices_per_edge(mesh_v, mesh_f): + """Returns an Ex2 array of adjacencies between vertices, where + each element in the array is a vertex index. Each edge is included + only once. If output of get_faces_per_edge is provided, this is used to + avoid call to get_vert_connectivity()""" + + vc = sp.coo_matrix(get_vert_connectivity(mesh_v, mesh_f)) + result = np.hstack((col(vc.row), col(vc.col))) + result = result[result[:, 0] < result[:, 1]] # for uniqueness + + return result diff --git a/SMPLX/transfer_model/utils/metrics.py b/SMPLX/transfer_model/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..d7f8209423ce342a1137a1c4bfe22709cf0f0357 --- /dev/null +++ b/SMPLX/transfer_model/utils/metrics.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +import numpy as np +import torch + + +def v2v(x, y): + if torch.is_tensor(x): + return (x - y).pow(2).sum(dim=-1).sqrt().mean() + else: + return np.sqrt(np.power(x - y, 2)).sum(axis=-1).mean() diff --git a/SMPLX/transfer_model/utils/np_utils.py b/SMPLX/transfer_model/utils/np_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d0cfa56dfd7cda764b127048683617ad437b48d2 --- /dev/null +++ b/SMPLX/transfer_model/utils/np_utils.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. 
(MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import numpy as np + + +def rel_change(prev_val, curr_val): + return (prev_val - curr_val) / max([np.abs(prev_val), np.abs(curr_val), 1]) + + +def max_grad_change(grad_arr): + return grad_arr.abs().max() + + +def to_np(array, dtype=np.float32): + if hasattr(array, 'todense'): + array = array.todense() + return np.array(array, dtype=dtype) diff --git a/SMPLX/transfer_model/utils/o3d_utils.py b/SMPLX/transfer_model/utils/o3d_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..758f85614fa2083c8c0febc513d35413d66c2c20 --- /dev/null +++ b/SMPLX/transfer_model/utils/o3d_utils.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +import open3d as o3d +import torch + +Vector3d = o3d.utility.Vector3dVector +Vector3i = o3d.utility.Vector3iVector + +Mesh = o3d.geometry.TriangleMesh + + +def np_mesh_to_o3d(vertices, faces): + if torch.is_tensor(vertices): + vertices = vertices.detach().cpu().numpy() + if torch.is_tensor(faces): + faces = faces.detach().cpu().numpy() + mesh = Mesh() + mesh.vertices = Vector3d(vertices) + mesh.triangles = Vector3i(faces) + return mesh diff --git a/SMPLX/transfer_model/utils/pose_utils.py b/SMPLX/transfer_model/utils/pose_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..778167c230eea6407e8e88bcb4441cddef591406 --- /dev/null +++ b/SMPLX/transfer_model/utils/pose_utils.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +import sys +from typing import NewType, List, Dict, Optional +import os +import os.path as osp + +import pickle + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from omegaconf import OmegaConf +from loguru import logger + +from SMPLX.transfer_model.utils.typing import Tensor + + +def rotation_matrix_to_cont_repr(x: Tensor) -> Tensor: + assert len(x.shape) == 3, ( + f'Expects an array of size Bx3x3, but received {x.shape}') + return x[:, :3, :2] + + +def cont_repr_to_rotation_matrix( + x: Tensor +) -> Tensor: + ''' Converts tensor in continous representation to rotation matrices + ''' + batch_size = x.shape[0] + reshaped_input = x.view(-1, 3, 2) + + # Normalize the first vector + b1 = F.normalize(reshaped_input[:, :, 0].clone(), dim=1) + + dot_prod = torch.sum( + b1 * reshaped_input[:, :, 1].clone(), dim=1, keepdim=True) + # Compute the second vector by finding the orthogonal complement to it + b2 = F.normalize(reshaped_input[:, :, 1] - dot_prod * b1, dim=1) + # Finish building the basis by taking the cross product + b3 = torch.cross(b1, b2, dim=1) + rot_mats = torch.stack([b1, b2, b3], dim=-1) + + return rot_mats.view(batch_size, -1, 3, 3) + + +def batch_rodrigues( + rot_vecs: Tensor, + epsilon: float = 1e-8 +) -> Tensor: + ''' Calculates the rotation matrices for a batch of rotation vectors + Parameters + ---------- + rot_vecs: torch.tensor Nx3 + array of N axis-angle vectors + Returns + ------- + R: torch.tensor Nx3x3 + The rotation matrices for the given axis-angle parameters + ''' + assert len(rot_vecs.shape) == 2, ( + f'Expects an array of size Bx3, but received {rot_vecs.shape}') + + batch_size = rot_vecs.shape[0] + device = rot_vecs.device + dtype = rot_vecs.dtype + + angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True, p=2) + rot_dir = rot_vecs / angle + + cos = torch.unsqueeze(torch.cos(angle), dim=1) + sin = torch.unsqueeze(torch.sin(angle), dim=1) + + # Bx1 arrays + rx, ry, rz = torch.split(rot_dir, 1, dim=1) + K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) + + zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((batch_size, 3, 3)) + + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) + return rot_mat + + +def batch_rot2aa( + Rs: Tensor, epsilon: float = 1e-7 +) -> Tensor: + """ + Rs is B x 3 x 3 + void cMathUtil::RotMatToAxisAngle(const tMatrix& mat, tVector& out_axis, + double& out_theta) + { + double c = 0.5 * (mat(0, 0) + mat(1, 1) + mat(2, 2) - 1); + c = cMathUtil::Clamp(c, -1.0, 1.0); + + out_theta = std::acos(c); + + if (std::abs(out_theta) < 0.00001) + { + out_axis = tVector(0, 0, 1, 0); + } + else + { + double m21 = mat(2, 1) - mat(1, 2); + double m02 = mat(0, 2) - mat(2, 0); + double m10 = mat(1, 0) - mat(0, 1); + double denom = std::sqrt(m21 * m21 + m02 * m02 + m10 * m10); + out_axis[0] = m21 / denom; + out_axis[1] = m02 / denom; + out_axis[2] = m10 / denom; + out_axis[3] = 0; + } + } + """ + + cos = 0.5 * (torch.einsum('bii->b', [Rs]) - 1) + cos = torch.clamp(cos, -1 + epsilon, 1 - epsilon) + + theta = torch.acos(cos) + + m21 = Rs[:, 2, 1] - Rs[:, 1, 2] + m02 = Rs[:, 0, 2] - Rs[:, 2, 0] + m10 = Rs[:, 1, 0] - Rs[:, 0, 1] + denom = torch.sqrt(m21 * m21 + m02 * m02 + m10 * m10 + epsilon) + + axis0 = torch.where(torch.abs(theta) < 0.00001, m21, m21 / denom) + axis1 = 
torch.where(torch.abs(theta) < 0.00001, m02, m02 / denom) + axis2 = torch.where(torch.abs(theta) < 0.00001, m10, m10 / denom) + + return theta.unsqueeze(1) * torch.stack([axis0, axis1, axis2], 1) diff --git a/SMPLX/transfer_model/utils/timer.py b/SMPLX/transfer_model/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..027de5b83f67e241ac5d018e6732b9d27f7a2c91 --- /dev/null +++ b/SMPLX/transfer_model/utils/timer.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +import time +import numpy as np +import torch + +from loguru import logger + + +class Timer(object): + def __init__(self, name='', sync=False): + super(Timer, self).__init__() + self.elapsed = [] + self.name = name + self.sync = sync + + def __enter__(self): + if self.sync: + torch.cuda.synchronize() + self.start = time.perf_counter() + + def __exit__(self, type, value, traceback): + if self.sync: + torch.cuda.synchronize() + elapsed = time.perf_counter() - self.start + self.elapsed.append(elapsed) + logger.info(f'[{self.name}]: {np.mean(self.elapsed):.3f}') + + +def timer_decorator(sync=False, name=''): + def wrapper(method): + elapsed = [] + + def timed(*args, **kw): + if sync: + torch.cuda.synchronize() + ts = time.perf_counter() + result = method(*args, **kw) + if sync: + torch.cuda.synchronize() + te = time.perf_counter() + elapsed.append(te - ts) + logger.info(f'[{name}]: {np.mean(elapsed):.3f}') + return result + return timed + return wrapper diff --git a/SMPLX/transfer_model/utils/torch_utils.py b/SMPLX/transfer_model/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa820d4e429ae611c384842980b66d63b07aacc --- /dev/null +++ b/SMPLX/transfer_model/utils/torch_utils.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +import numpy as np +import torch + + +def from_torch(x, dtype=np.float32): + if torch.is_tensor(x): + x = x.detach().cpu().numpy() + return x.astype(dtype) diff --git a/SMPLX/transfer_model/utils/typing.py b/SMPLX/transfer_model/utils/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..c6ba3109d49f1e6c2496772efa4f422e2e06dca5 --- /dev/null +++ b/SMPLX/transfer_model/utils/typing.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de + +from typing import NewType, List, Union +import numpy as np +import torch + +__all__ = [ + 'Tensor', + 'Array', +] + +Tensor = NewType('Tensor', torch.Tensor) +Array = NewType('Array', np.ndarray) diff --git a/SMPLX/transfer_model/view_pkl.py b/SMPLX/transfer_model/view_pkl.py new file mode 100644 index 0000000000000000000000000000000000000000..0cafdef75f4056de3b42b5a3bad3f686264d1690 --- /dev/null +++ b/SMPLX/transfer_model/view_pkl.py @@ -0,0 +1,140 @@ +import os.path as osp +import argparse + +import numpy as np +import torch + +import pyrender +import trimesh + +import smplx + +from tqdm.auto import tqdm, trange + +from pathlib import Path + +def main(model_folder, + motion_file, + model_type='smplx', + ext='npz', + gender='neutral', + plot_joints=False, + num_betas=10, + sample_expression=True, + num_expression_coeffs=10, + use_face_contour=False): + + # open motion file + motion = np.load(motion_file, allow_pickle=True) + _motion = {} + for k,v in motion.items(): + if isinstance(v, np.ndarray): + print(k, motion[k].shape, motion[k].dtype) + if motion[k].dtype in ("/smpl_models/smpl/`, eventually, the `/smpl_models` folder should have the following structure: + ``` + smpl_models + └-- smpl + └-- SMPL_FEMALE.pkl + └-- SMPL_MALE.pkl + └-- SMPL_NEUTRAL.pkl + ``` + +## Demo +### Demo for sequences +python fit_seq.py --files test_motion2.npy + +The results will locate in ./demo/demo_results/ + +## Citation +If you find this project useful for your research, please consider citing: +``` +@article{zuo2021sparsefusion, + title={Sparsefusion: Dynamic human avatar modeling from sparse rgbd images}, + author={Zuo, Xinxin and Wang, Sen and Zheng, Jiangbin and Yu, Weiwei and Gong, Minglun and Yang, Ruigang and Cheng, Li}, + journal={IEEE Transactions on Multimedia}, + volume={23}, + pages={1617--1629}, + year={2021} +} +``` + +## References +We indicate if a function or script is borrowed externally inside each file. Here are some great resources we +benefit: + +- Shape/Pose prior and some functions are borrowed from [VIBE](https://github.com/mkocabas/VIBE). +- SMPL models and layer is from [SMPL-X model](https://github.com/vchoutas/smplx). +- Some functions are borrowed from [HMR-pytorch](https://github.com/MandyMo/pytorch_HMR). 
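The per-frame outputs written by `fit_seq.py` above can be inspected directly once the demo has run. Below is a minimal sketch (not part of the original demo) for loading them back; it assumes the default invocation with `--files test_motion2.npy`, so results land in `./demo/demo_results/test_motion2/` as `%04d.pkl` parameter files (keys `beta`, `pose`, `cam`, saved with joblib) and `%04d.ply` meshes, as in the script above:

```python
# Minimal sketch for inspecting fit_seq.py outputs; the result directory below
# is an assumption matching the default demo run (--files test_motion2.npy).
import glob
import os.path as osp

import joblib
import trimesh

result_dir = "./demo/demo_results/test_motion2"

for pkl_path in sorted(glob.glob(osp.join(result_dir, "*.pkl"))):
    param = joblib.load(pkl_path)
    # Per frame (default batchSize=1): SMPL betas (1, 10), axis-angle pose (1, 72),
    # and camera translation (1, 3), as saved by fit_seq.py.
    print(pkl_path, param["beta"].shape, param["pose"].shape, param["cam"].shape)

    ply_path = pkl_path.replace(".pkl", ".ply")
    if osp.exists(ply_path):
        mesh = trimesh.load(ply_path, process=False)  # fitted SMPL mesh for this frame
        print("  vertices:", mesh.vertices.shape, "faces:", mesh.faces.shape)
```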
diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/environment.yaml b/SMPLX/visualize_joint2smpl/joints2smpl/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..28d0498a9d944cd60bffa6e38c8063d6b6cee7f3 --- /dev/null +++ b/SMPLX/visualize_joint2smpl/joints2smpl/environment.yaml @@ -0,0 +1,30 @@ +name: fit3d +channels: + - conda-forge + - pytorch + - defaults + - pytorch3d + - open3d-admin + - anaconda +dependencies: + - pip=21.1.3 + - numpy=1.20.3 + - numpy-base=1.20.3 + - matplotlib=3.4.2 + - matplotlib-base=3.4.2 + - pandas=1.3.1 + - python=3.7.6 + - pytorch=1.7.1 + - tensorboardx=2.2 + - cudatoolkit=10.2.89 + - torchvision=0.8.2 + - einops=0.3.0 + - pytorch3d=0.4.0 + - tqdm=4.61.2 + - trimesh=3.9.24 + - joblib=1.0.1 + - open3d=0.13.0 + - pip: + - h5py==2.9.0 + - chumpy==0.70 + - smplx==0.1.28 diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/fit_seq.py b/SMPLX/visualize_joint2smpl/joints2smpl/fit_seq.py new file mode 100644 index 0000000000000000000000000000000000000000..5191ff4de688bc49770c3d8234ebeb6e25b55817 --- /dev/null +++ b/SMPLX/visualize_joint2smpl/joints2smpl/fit_seq.py @@ -0,0 +1,132 @@ +from __future__ import print_function, division +import argparse +import torch +import os,sys +from os import walk, listdir +from os.path import isfile, join +import numpy as np +import joblib +import smplx +import trimesh +import h5py +from tqdm import tqdm + +sys.path.append(os.path.join(os.path.dirname(__file__), "src")) +from smplify import SMPLify3D +import config + +# parsing argmument +parser = argparse.ArgumentParser() +parser.add_argument('--batchSize', type=int, default=1, + help='input batch size') +parser.add_argument('--num_smplify_iters', type=int, default=100, + help='num of smplify iters') +parser.add_argument('--cuda', type=bool, default=False, + help='enables cuda') +parser.add_argument('--gpu_ids', type=int, default=0, + help='choose gpu ids') +parser.add_argument('--num_joints', type=int, default=22, + help='joint number') +parser.add_argument('--joint_category', type=str, default="AMASS", + help='use correspondence') +parser.add_argument('--fix_foot', type=str, default="False", + help='fix foot or not') +parser.add_argument('--data_folder', type=str, default="./demo/demo_data/", + help='data in the folder') +parser.add_argument('--save_folder', type=str, default="./demo/demo_results/", + help='results save folder') +parser.add_argument('--files', type=str, default="test_motion.npy", + help='files use') +opt = parser.parse_args() +print(opt) + +# ---load predefined something +device = torch.device("cuda:" + str(opt.gpu_ids) if opt.cuda else "cpu") +print(config.SMPL_MODEL_DIR) +smplmodel = smplx.create(config.SMPL_MODEL_DIR, + model_type="smpl", gender="neutral", ext="pkl", + batch_size=opt.batchSize).to(device) + +# ## --- load the mean pose as original ---- +smpl_mean_file = config.SMPL_MEAN_FILE + +file = h5py.File(smpl_mean_file, 'r') +init_mean_pose = torch.from_numpy(file['pose'][:]).unsqueeze(0).float() +init_mean_shape = torch.from_numpy(file['shape'][:]).unsqueeze(0).float() +cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).to(device) +# +pred_pose = torch.zeros(opt.batchSize, 72).to(device) +pred_betas = torch.zeros(opt.batchSize, 10).to(device) +pred_cam_t = torch.zeros(opt.batchSize, 3).to(device) +keypoints_3d = torch.zeros(opt.batchSize, opt.num_joints, 3).to(device) + +# # #-------------initialize SMPLify +smplify = SMPLify3D(smplxmodel=smplmodel, + batch_size=opt.batchSize, + joints_category=opt.joint_category, + 
num_iters=opt.num_smplify_iters, + device=device) +#print("initialize SMPLify3D done!") + + +purename = os.path.splitext(opt.files)[0] +# --- load data --- +data = np.load(opt.data_folder + "/" + purename + ".npy") # [nframes, njoints, 3] + +dir_save = os.path.join(opt.save_folder, purename) +if not os.path.isdir(dir_save): + os.makedirs(dir_save, exist_ok=True) + +# run the whole seqs +num_seqs = data.shape[0] + +for idx in tqdm(range(num_seqs)): + #print(idx) + + joints3d = data[idx] #*1.2 #scale problem [check first] + keypoints_3d[0, :, :] = torch.Tensor(joints3d).to(device).float() + + if idx == 0: + pred_betas[0, :] = init_mean_shape + pred_pose[0, :] = init_mean_pose + pred_cam_t[0, :] = cam_trans_zero + else: + data_param = joblib.load(dir_save + "/" + "%04d"%(idx-1) + ".pkl") + pred_betas[0, :] = torch.from_numpy(data_param['beta']).unsqueeze(0).float() + pred_pose[0, :] = torch.from_numpy(data_param['pose']).unsqueeze(0).float() + pred_cam_t[0, :] = torch.from_numpy(data_param['cam']).unsqueeze(0).float() + + if opt.joint_category =="AMASS": + confidence_input = torch.ones(opt.num_joints) + # make sure the foot and ankle + if opt.fix_foot == True: + confidence_input[7] = 1.5 + confidence_input[8] = 1.5 + confidence_input[10] = 1.5 + confidence_input[11] = 1.5 + else: + print("Such category not settle down!") + + # ----- from initial to fitting ------- + new_opt_vertices, new_opt_joints, new_opt_pose, new_opt_betas, \ + new_opt_cam_t, new_opt_joint_loss = smplify( + pred_pose.detach(), + pred_betas.detach(), + pred_cam_t.detach(), + keypoints_3d, + conf_3d=confidence_input.to(device), + seq_ind=idx + ) + + # # -- save the results to ply--- + outputp = smplmodel(betas=new_opt_betas, global_orient=new_opt_pose[:, :3], body_pose=new_opt_pose[:, 3:], + transl=new_opt_cam_t, return_verts=True) + mesh_p = trimesh.Trimesh(vertices=outputp.vertices.detach().cpu().numpy().squeeze(), faces=smplmodel.faces, process=False) + mesh_p.export(dir_save + "/" + "%04d"%idx + ".ply") + + # save the pkl + param = {} + param['beta'] = new_opt_betas.detach().cpu().numpy() + param['pose'] = new_opt_pose.detach().cpu().numpy() + param['cam'] = new_opt_cam_t.detach().cpu().numpy() + joblib.dump(param, dir_save + "/" + "%04d"%idx + ".pkl", compress=3) diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/smpl_models/SMPL_downsample_index.pkl b/SMPLX/visualize_joint2smpl/joints2smpl/smpl_models/SMPL_downsample_index.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7bb54c4f1e03340ad58b60485abaed1641d68d47 --- /dev/null +++ b/SMPLX/visualize_joint2smpl/joints2smpl/smpl_models/SMPL_downsample_index.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5b783c1677079397ee4bc26df5c72d73b8bb393bea41fa295b951187443daec +size 3556 diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/smpl_models/gmm_08.pkl b/SMPLX/visualize_joint2smpl/joints2smpl/smpl_models/gmm_08.pkl new file mode 100644 index 0000000000000000000000000000000000000000..c97a1d7ef396581e56ce74a12cc39175680ce028 --- /dev/null +++ b/SMPLX/visualize_joint2smpl/joints2smpl/smpl_models/gmm_08.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e1374908aae055a2afa01a2cd9a169bc6cfec1ceb7aa590e201a47b383060491 +size 839127 diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/smpl_models/neutral_smpl_mean_params.h5 b/SMPLX/visualize_joint2smpl/joints2smpl/smpl_models/neutral_smpl_mean_params.h5 new file mode 100644 index 
0000000000000000000000000000000000000000..b6ecce2a748128cfde09b219ccc74307de50bbae --- /dev/null +++ b/SMPLX/visualize_joint2smpl/joints2smpl/smpl_models/neutral_smpl_mean_params.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac9b474c74daec0253ed084720f662059336e976850f08a4a9a3f76d06613776 +size 4848 diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/smpl_models/smplx_parts_segm.pkl b/SMPLX/visualize_joint2smpl/joints2smpl/smpl_models/smplx_parts_segm.pkl new file mode 100644 index 0000000000000000000000000000000000000000..77ce98631741ba3887d689077baf35422d39299d --- /dev/null +++ b/SMPLX/visualize_joint2smpl/joints2smpl/smpl_models/smplx_parts_segm.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb69c10801205c9cfb5353fdeb1b9cc5ade53d14c265c3339421cdde8b9c91e7 +size 1323168 diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/config.cpython-310.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62d8541eb1e81dd12e524fdadd679fce41b8dfaa Binary files /dev/null and b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/config.cpython-310.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/config.cpython-311.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e84fe38f07e6144078f7893c6a52b78b38350580 Binary files /dev/null and b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/config.cpython-311.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/config.cpython-38.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/config.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..570a24568f36c9a7966375a1131b1c27ba6afe42 Binary files /dev/null and b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/config.cpython-38.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/config.cpython-39.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c4430719a896c3559d0cc28e1326baf5ab9dd52 Binary files /dev/null and b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/config.cpython-39.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/customloss.cpython-310.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/customloss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0373c9878e3d9b8ce93b6ed4001025f2b2487fa5 Binary files /dev/null and b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/customloss.cpython-310.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/customloss.cpython-311.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/customloss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57b20b21d82233b0a003238f2ab72b772583b42b Binary files /dev/null and b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/customloss.cpython-311.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/customloss.cpython-38.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/customloss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6de3ee9d60931e22fa1b65bbcddcacaf994d3415 Binary files /dev/null and 
b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/customloss.cpython-38.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/customloss.cpython-39.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/customloss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b0a9e15fabd3f71a1298f0fd2228e22adff22c5 Binary files /dev/null and b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/customloss.cpython-39.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/prior.cpython-310.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/prior.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2af4a734ea280c0bd9d004cb217d78034d0ead71 Binary files /dev/null and b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/prior.cpython-310.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/prior.cpython-311.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/prior.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2bab808110fe3310c5825a4c9e5f46b5c14e513 Binary files /dev/null and b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/prior.cpython-311.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/prior.cpython-38.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/prior.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..664b2b59b0aeda0eee2078b7756437c8007c3cdf Binary files /dev/null and b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/prior.cpython-38.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/prior.cpython-39.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/prior.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c356e812265ab4f0d4df08bd0276fbf226a3c302 Binary files /dev/null and b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/prior.cpython-39.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/smplify.cpython-310.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/smplify.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eebf83ad60258d3d76e9725d7159586c56f56558 Binary files /dev/null and b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/smplify.cpython-310.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/smplify.cpython-311.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/smplify.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76b8e29964bbcdf6e3df40d98d43460624c5fb4d Binary files /dev/null and b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/smplify.cpython-311.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/smplify.cpython-38.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/smplify.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a89ec7fc2d8f73821b6e79ca4bb01ce76e3241d Binary files /dev/null and b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/smplify.cpython-38.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/smplify.cpython-39.pyc b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/smplify.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7820e2fb7c3765c84731c1d62d5e2ab04faebc44 Binary files /dev/null and 
b/SMPLX/visualize_joint2smpl/joints2smpl/src/__pycache__/smplify.cpython-39.pyc differ diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/config.py b/SMPLX/visualize_joint2smpl/joints2smpl/src/config.py new file mode 100644 index 0000000000000000000000000000000000000000..f877fe81b7f059cce7d8096cefbeed26c88e4298 --- /dev/null +++ b/SMPLX/visualize_joint2smpl/joints2smpl/src/config.py @@ -0,0 +1,37 @@ +import numpy as np +# Map joints Name to SMPL joints idx +JOINT_MAP = { +'MidHip': 0, +'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10, +'RHip': 2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11, +'LShoulder': 16, 'LElbow': 18, 'LWrist': 20, 'LHand': 22, +'RShoulder': 17, 'RElbow': 19, 'RWrist': 21, 'RHand': 23, +'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15, +'LCollar':13, 'Rcollar' :14, +'Nose':24, 'REye':26, 'LEye':26, 'REar':27, 'LEar':28, +'LHeel': 31, 'RHeel': 34, +'OP RShoulder': 17, 'OP LShoulder': 16, +'OP RHip': 2, 'OP LHip': 1, +'OP Neck': 12, +} + +full_smpl_idx = range(24) +key_smpl_idx = [0, 1, 4, 7, 2, 5, 8, 17, 19, 21, 16, 18, 20] + + +AMASS_JOINT_MAP = { +'MidHip': 0, +'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10, +'RHip': 2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11, +'LShoulder': 16, 'LElbow': 18, 'LWrist': 20, +'RShoulder': 17, 'RElbow': 19, 'RWrist': 21, +'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15, +'LCollar':13, 'Rcollar' :14, +} +amass_idx = range(22) +amass_smpl_idx = range(22) + +GMM_MODEL_DIR = "SMPLX/visualize_joint2smpl/joints2smpl/smpl_models/" +SMPL_MEAN_FILE = "SMPLX/visualize_joint2smpl/joints2smpl/smpl_models/neutral_smpl_mean_params.h5" +# for collsion +Part_Seg_DIR = "SMPLX/visualize_joint2smpl/joints2smpl/smpl_models/smplx_parts_segm.pkl" \ No newline at end of file diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/customloss.py b/SMPLX/visualize_joint2smpl/joints2smpl/src/customloss.py new file mode 100644 index 0000000000000000000000000000000000000000..6f89e55461cad68e4a5a329235de789a150ade21 --- /dev/null +++ b/SMPLX/visualize_joint2smpl/joints2smpl/src/customloss.py @@ -0,0 +1,222 @@ +import torch +import torch.nn.functional as F +from SMPLX.visualize_joint2smpl.joints2smpl.src import config + +# Guassian +def gmof(x, sigma): + """ + Geman-McClure error function + """ + x_squared = x ** 2 + sigma_squared = sigma ** 2 + return (sigma_squared * x_squared) / (sigma_squared + x_squared) + +# angle prior +def angle_prior(pose): + """ + Angle prior that penalizes unnatural bending of the knees and elbows + """ + # We subtract 3 because pose does not include the global rotation of the model + return torch.exp( + pose[:, [55 - 3, 58 - 3, 12 - 3, 15 - 3]] * torch.tensor([1., -1., -1, -1.], device=pose.device)) ** 2 + + +def perspective_projection(points, rotation, translation, + focal_length, camera_center): + """ + This function computes the perspective projection of a set of points. + Input: + points (bs, N, 3): 3D points + rotation (bs, 3, 3): Camera rotation + translation (bs, 3): Camera translation + focal_length (bs,) or scalar: Focal length + camera_center (bs, 2): Camera center + """ + batch_size = points.shape[0] + K = torch.zeros([batch_size, 3, 3], device=points.device) + K[:, 0, 0] = focal_length + K[:, 1, 1] = focal_length + K[:, 2, 2] = 1. 
+ K[:, :-1, -1] = camera_center + + # Transform points + points = torch.einsum('bij,bkj->bki', rotation, points) + points = points + translation.unsqueeze(1) + + # Apply perspective distortion + projected_points = points / points[:, :, -1].unsqueeze(-1) + + # Apply camera intrinsics + projected_points = torch.einsum('bij,bkj->bki', K, projected_points) + + return projected_points[:, :, :-1] + + +def body_fitting_loss(body_pose, betas, model_joints, camera_t, camera_center, + joints_2d, joints_conf, pose_prior, + focal_length=5000, sigma=100, pose_prior_weight=4.78, + shape_prior_weight=5, angle_prior_weight=15.2, + output='sum'): + """ + Loss function for body fitting + """ + batch_size = body_pose.shape[0] + rotation = torch.eye(3, device=body_pose.device).unsqueeze(0).expand(batch_size, -1, -1) + + projected_joints = perspective_projection(model_joints, rotation, camera_t, + focal_length, camera_center) + + # Weighted robust reprojection error + reprojection_error = gmof(projected_joints - joints_2d, sigma) + reprojection_loss = (joints_conf ** 2) * reprojection_error.sum(dim=-1) + + # Pose prior loss + pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas) + + # Angle prior for knees and elbows + angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1) + + # Regularizer to prevent betas from taking large values + shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1) + + total_loss = reprojection_loss.sum(dim=-1) + pose_prior_loss + angle_prior_loss + shape_prior_loss + + if output == 'sum': + return total_loss.sum() + elif output == 'reprojection': + return reprojection_loss + + +# --- get camera fitting loss ----- +def camera_fitting_loss(model_joints, camera_t, camera_t_est, camera_center, + joints_2d, joints_conf, + focal_length=5000, depth_loss_weight=100): + """ + Loss function for camera optimization. 
+ """ + # Project model joints + batch_size = model_joints.shape[0] + rotation = torch.eye(3, device=model_joints.device).unsqueeze(0).expand(batch_size, -1, -1) + projected_joints = perspective_projection(model_joints, rotation, camera_t, + focal_length, camera_center) + + # get the indexed four + op_joints = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder'] + op_joints_ind = [config.JOINT_MAP[joint] for joint in op_joints] + gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] + gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + + reprojection_error_op = (joints_2d[:, op_joints_ind] - + projected_joints[:, op_joints_ind]) ** 2 + reprojection_error_gt = (joints_2d[:, gt_joints_ind] - + projected_joints[:, gt_joints_ind]) ** 2 + + # Check if for each example in the batch all 4 OpenPose detections are valid, otherwise use the GT detections + # OpenPose joints are more reliable for this task, so we prefer to use them if possible + is_valid = (joints_conf[:, op_joints_ind].min(dim=-1)[0][:, None, None] > 0).float() + reprojection_loss = (is_valid * reprojection_error_op + (1 - is_valid) * reprojection_error_gt).sum(dim=(1, 2)) + + # Loss that penalizes deviation from depth estimate + depth_loss = (depth_loss_weight ** 2) * (camera_t[:, 2] - camera_t_est[:, 2]) ** 2 + + total_loss = reprojection_loss + depth_loss + return total_loss.sum() + + + + # #####--- body fitiing loss ----- +def body_fitting_loss_3d(body_pose, preserve_pose, + betas, model_joints, camera_translation, + j3d, pose_prior, + joints3d_conf, + sigma=100, pose_prior_weight=4.78*1.5, + shape_prior_weight=5.0, angle_prior_weight=15.2, + joint_loss_weight=500.0, + pose_preserve_weight=0.0, + use_collision=False, + model_vertices=None, model_faces=None, + search_tree=None, pen_distance=None, filter_faces=None, + collision_loss_weight=1000 + ): + """ + Loss function for body fitting + """ + batch_size = body_pose.shape[0] + + #joint3d_loss = (joint_loss_weight ** 2) * gmof((model_joints + camera_translation) - j3d, sigma).sum(dim=-1) + + joint3d_error = gmof((model_joints + camera_translation) - j3d, sigma) + + joint3d_loss_part = (joints3d_conf ** 2) * joint3d_error.sum(dim=-1) + joint3d_loss = ((joint_loss_weight ** 2) * joint3d_loss_part).sum(dim=-1) + + # Pose prior loss + pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas) + # Angle prior for knees and elbows + angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1) + # Regularizer to prevent betas from taking large values + shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1) + + collision_loss = 0.0 + # Calculate the loss due to interpenetration + if use_collision: + triangles = torch.index_select( + model_vertices, 1, + model_faces).view(batch_size, -1, 3, 3) + + with torch.no_grad(): + collision_idxs = search_tree(triangles) + + # Remove unwanted collisions + if filter_faces is not None: + collision_idxs = filter_faces(collision_idxs) + + if collision_idxs.ge(0).sum().item() > 0: + collision_loss = torch.sum(collision_loss_weight * pen_distance(triangles, collision_idxs)) + + pose_preserve_loss = (pose_preserve_weight ** 2) * ((body_pose - preserve_pose) ** 2).sum(dim=-1) + + # print('joint3d_loss', joint3d_loss.shape) + # print('pose_prior_loss', pose_prior_loss.shape) + # print('angle_prior_loss', angle_prior_loss.shape) + # print('shape_prior_loss', shape_prior_loss.shape) + # print('collision_loss', collision_loss) + # print('pose_preserve_loss', pose_preserve_loss.shape) + + 
total_loss = joint3d_loss + pose_prior_loss + angle_prior_loss + shape_prior_loss + collision_loss + pose_preserve_loss + + return total_loss.sum() + + +# #####--- get camera fitting loss ----- +def camera_fitting_loss_3d(model_joints, camera_t, camera_t_est, + j3d, joints_category="orig", depth_loss_weight=100.0): + """ + Loss function for camera optimization. + """ + model_joints = model_joints + camera_t + # # get the indexed four + # op_joints = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder'] + # op_joints_ind = [config.JOINT_MAP[joint] for joint in op_joints] + # + # j3d_error_loss = (j3d[:, op_joints_ind] - + # model_joints[:, op_joints_ind]) ** 2 + + gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] + gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + + if joints_category=="orig": + select_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + elif joints_category=="AMASS": + select_joints_ind = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints] + else: + print("NO SUCH JOINTS CATEGORY!") + + j3d_error_loss = (j3d[:, select_joints_ind] - + model_joints[:, gt_joints_ind]) ** 2 + + # Loss that penalizes deviation from depth estimate + depth_loss = (depth_loss_weight**2) * (camera_t - camera_t_est)**2 + + total_loss = j3d_error_loss + depth_loss + return total_loss.sum() diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/prior.py b/SMPLX/visualize_joint2smpl/joints2smpl/src/prior.py new file mode 100644 index 0000000000000000000000000000000000000000..7f13806dd1f6607507b0c7e5ad463b3fb0026be8 --- /dev/null +++ b/SMPLX/visualize_joint2smpl/joints2smpl/src/prior.py @@ -0,0 +1,230 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import sys +import os + +import time +import pickle + +import numpy as np + +import torch +import torch.nn as nn + +DEFAULT_DTYPE = torch.float32 + + +def create_prior(prior_type, **kwargs): + if prior_type == 'gmm': + prior = MaxMixturePrior(**kwargs) + elif prior_type == 'l2': + return L2Prior(**kwargs) + elif prior_type == 'angle': + return SMPLifyAnglePrior(**kwargs) + elif prior_type == 'none' or prior_type is None: + # Don't use any pose prior + def no_prior(*args, **kwargs): + return 0.0 + prior = no_prior + else: + raise ValueError('Prior {}'.format(prior_type) + ' is not implemented') + return prior + + +class SMPLifyAnglePrior(nn.Module): + def __init__(self, dtype=torch.float32, **kwargs): + super(SMPLifyAnglePrior, self).__init__() + + # Indices for the roration angle of + # 55: left elbow, 90deg bend at -np.pi/2 + # 58: right elbow, 90deg bend at np.pi/2 + # 12: left knee, 90deg bend at np.pi/2 + # 15: right knee, 90deg bend at np.pi/2 + angle_prior_idxs = np.array([55, 58, 12, 15], dtype=np.int64) + angle_prior_idxs = torch.tensor(angle_prior_idxs, dtype=torch.long) + self.register_buffer('angle_prior_idxs', angle_prior_idxs) + + angle_prior_signs = np.array([1, -1, -1, -1], + dtype=np.float32 if dtype == torch.float32 + else np.float64) + angle_prior_signs = torch.tensor(angle_prior_signs, + dtype=dtype) + self.register_buffer('angle_prior_signs', angle_prior_signs) + + def forward(self, pose, with_global_pose=False): + ''' Returns the angle prior loss for the given pose + + Args: + pose: (Bx[23 + 1] * 3) torch tensor with the axis-angle + representation of the rotations of the joints of the SMPL model. + Kwargs: + with_global_pose: Whether the pose vector also contains the global + orientation of the SMPL model. If not then the indices must be + corrected. + Returns: + A sze (B) tensor containing the angle prior loss for each element + in the batch. 
+ ''' + angle_prior_idxs = self.angle_prior_idxs - (not with_global_pose) * 3 + return torch.exp(pose[:, angle_prior_idxs] * + self.angle_prior_signs).pow(2) + + +class L2Prior(nn.Module): + def __init__(self, dtype=DEFAULT_DTYPE, reduction='sum', **kwargs): + super(L2Prior, self).__init__() + + def forward(self, module_input, *args): + return torch.sum(module_input.pow(2)) + + +class MaxMixturePrior(nn.Module): + + def __init__(self, prior_folder='prior', + num_gaussians=6, dtype=DEFAULT_DTYPE, epsilon=1e-16, + use_merged=True, + **kwargs): + super(MaxMixturePrior, self).__init__() + + if dtype == DEFAULT_DTYPE: + np_dtype = np.float32 + elif dtype == torch.float64: + np_dtype = np.float64 + else: + print('Unknown float type {}, exiting!'.format(dtype)) + sys.exit(-1) + + self.num_gaussians = num_gaussians + self.epsilon = epsilon + self.use_merged = use_merged + gmm_fn = 'gmm_{:02d}.pkl'.format(num_gaussians) + + full_gmm_fn = os.path.join(prior_folder, gmm_fn) + if not os.path.exists(full_gmm_fn): + print('The path to the mixture prior "{}"'.format(full_gmm_fn) + + ' does not exist, exiting!') + sys.exit(-1) + + with open(full_gmm_fn, 'rb') as f: + gmm = pickle.load(f, encoding='latin1') + + if type(gmm) == dict: + means = gmm['means'].astype(np_dtype) + covs = gmm['covars'].astype(np_dtype) + weights = gmm['weights'].astype(np_dtype) + elif 'sklearn.mixture.gmm.GMM' in str(type(gmm)): + means = gmm.means_.astype(np_dtype) + covs = gmm.covars_.astype(np_dtype) + weights = gmm.weights_.astype(np_dtype) + else: + print('Unknown type for the prior: {}, exiting!'.format(type(gmm))) + sys.exit(-1) + + self.register_buffer('means', torch.tensor(means, dtype=dtype)) + + self.register_buffer('covs', torch.tensor(covs, dtype=dtype)) + + precisions = [np.linalg.inv(cov) for cov in covs] + precisions = np.stack(precisions).astype(np_dtype) + + self.register_buffer('precisions', + torch.tensor(precisions, dtype=dtype)) + + # The constant term: + sqrdets = np.array([(np.sqrt(np.linalg.det(c))) + for c in gmm['covars']]) + const = (2 * np.pi)**(69 / 2.) 
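+ # 69 is the dimensionality of the SMPL body pose (23 joints x 3 axis-angle values), so
+ # const is the Gaussian normalization factor (2*pi)^(d/2) used to turn the mixture
+ # weights into the negative-log-likelihood weights registered below.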
+ + nll_weights = np.asarray(gmm['weights'] / (const * + (sqrdets / sqrdets.min()))) + nll_weights = torch.tensor(nll_weights, dtype=dtype).unsqueeze(dim=0) + self.register_buffer('nll_weights', nll_weights) + + weights = torch.tensor(gmm['weights'], dtype=dtype).unsqueeze(dim=0) + self.register_buffer('weights', weights) + + self.register_buffer('pi_term', + torch.log(torch.tensor(2 * np.pi, dtype=dtype))) + + cov_dets = [np.log(np.linalg.det(cov.astype(np_dtype)) + epsilon) + for cov in covs] + self.register_buffer('cov_dets', + torch.tensor(cov_dets, dtype=dtype)) + + # The dimensionality of the random variable + self.random_var_dim = self.means.shape[1] + + def get_mean(self): + ''' Returns the mean of the mixture ''' + mean_pose = torch.matmul(self.weights, self.means) + return mean_pose + + def merged_log_likelihood(self, pose, betas): + diff_from_mean = pose.unsqueeze(dim=1) - self.means + + prec_diff_prod = torch.einsum('mij,bmj->bmi', + [self.precisions, diff_from_mean]) + diff_prec_quadratic = (prec_diff_prod * diff_from_mean).sum(dim=-1) + + curr_loglikelihood = 0.5 * diff_prec_quadratic - \ + torch.log(self.nll_weights) + # curr_loglikelihood = 0.5 * (self.cov_dets.unsqueeze(dim=0) + + # self.random_var_dim * self.pi_term + + # diff_prec_quadratic + # ) - torch.log(self.weights) + + min_likelihood, _ = torch.min(curr_loglikelihood, dim=1) + return min_likelihood + + def log_likelihood(self, pose, betas, *args, **kwargs): + ''' Create graph operation for negative log-likelihood calculation + ''' + likelihoods = [] + + for idx in range(self.num_gaussians): + mean = self.means[idx] + prec = self.precisions[idx] + cov = self.covs[idx] + diff_from_mean = pose - mean + + curr_loglikelihood = torch.einsum('bj,ji->bi', + [diff_from_mean, prec]) + curr_loglikelihood = torch.einsum('bi,bi->b', + [curr_loglikelihood, + diff_from_mean]) + cov_term = torch.log(torch.det(cov) + self.epsilon) + curr_loglikelihood += 0.5 * (cov_term + + self.random_var_dim * + self.pi_term) + likelihoods.append(curr_loglikelihood) + + log_likelihoods = torch.stack(likelihoods, dim=1) + min_idx = torch.argmin(log_likelihoods, dim=1) + weight_component = self.nll_weights[:, min_idx] + weight_component = -torch.log(weight_component) + + return weight_component + log_likelihoods[:, min_idx] + + def forward(self, pose, betas): + if self.use_merged: + return self.merged_log_likelihood(pose, betas) + else: + return self.log_likelihood(pose, betas) \ No newline at end of file diff --git a/SMPLX/visualize_joint2smpl/joints2smpl/src/smplify.py b/SMPLX/visualize_joint2smpl/joints2smpl/src/smplify.py new file mode 100644 index 0000000000000000000000000000000000000000..7624a5e433d836bf58f6b7453a3b05a5f83bee05 --- /dev/null +++ b/SMPLX/visualize_joint2smpl/joints2smpl/src/smplify.py @@ -0,0 +1,270 @@ +import torch +import os, sys +import pickle + +sys.path.append(os.path.dirname(__file__)) +from customloss import (camera_fitting_loss_3d, + body_fitting_loss_3d, + ) +from prior import MaxMixturePrior +from SMPLX.visualize_joint2smpl.joints2smpl.src import config + + +@torch.no_grad() +def guess_init_3d(model_joints, + j3d, + joints_category="orig"): + """Initialize the camera translation via triangle similarity, by using the torso joints . 
+ :param model_joints: SMPL model with pre joints + :param j3d: 25x3 array of Kinect Joints + :returns: 3D vector corresponding to the estimated camera translation + """ + # get the indexed four + gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] + gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + + if joints_category=="orig": + joints_ind_category = [config.JOINT_MAP[joint] for joint in gt_joints] + elif joints_category=="AMASS": + joints_ind_category = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints] + else: + print("NO SUCH JOINTS CATEGORY!") + + sum_init_t = (j3d[:, joints_ind_category] - model_joints[:, gt_joints_ind]).sum(dim=1) + init_t = sum_init_t / 4.0 + return init_t + + +# SMPLIfy 3D +class SMPLify3D(): + """Implementation of SMPLify, use 3D joints.""" + + def __init__(self, + smplxmodel, + step_size=1e-2, + batch_size=1, + num_iters=100, + use_collision=False, + use_lbfgs=True, + joints_category="orig", + device=torch.device('cuda:0'), + ): + + # Store options + self.batch_size = batch_size + self.device = device + self.step_size = step_size + + self.num_iters = num_iters + # --- choose optimizer + self.use_lbfgs = use_lbfgs + # GMM pose prior + self.pose_prior = MaxMixturePrior(prior_folder=config.GMM_MODEL_DIR, + num_gaussians=8, + dtype=torch.float32).to(device) + # collision part + self.use_collision = use_collision + if self.use_collision: + self.part_segm_fn = config.Part_Seg_DIR + + # reLoad SMPL-X model + self.smpl = smplxmodel + + self.model_faces = smplxmodel.faces_tensor.view(-1) + + # select joint joint_category + self.joints_category = joints_category + + if joints_category=="orig": + self.smpl_index = config.full_smpl_idx + self.corr_index = config.full_smpl_idx + elif joints_category=="AMASS": + self.smpl_index = config.amass_smpl_idx + self.corr_index = config.amass_idx + else: + self.smpl_index = None + self.corr_index = None + print("NO SUCH JOINTS CATEGORY!") + + # ---- get the man function here ------ + def __call__(self, init_pose, init_betas, init_cam_t, j3d, conf_3d=1.0, seq_ind=0): + """Perform body fitting. 
+ Input: + init_pose: SMPL pose estimate + init_betas: SMPL betas estimate + init_cam_t: Camera translation estimate + j3d: joints 3d aka keypoints + conf_3d: confidence for 3d joints + seq_ind: index of the sequence + Returns: + vertices: Vertices of optimized shape + joints: 3D joints of optimized shape + pose: SMPL pose parameters of optimized shape + betas: SMPL beta parameters of optimized shape + camera_translation: Camera translation + """ + + # # # add the mesh inter-section to avoid + search_tree = None + pen_distance = None + filter_faces = None + + if self.use_collision: + from mesh_intersection.bvh_search_tree import BVH + import mesh_intersection.loss as collisions_loss + from mesh_intersection.filter_faces import FilterFaces + + search_tree = BVH(max_collisions=8) + + pen_distance = collisions_loss.DistanceFieldPenetrationLoss( + sigma=0.5, point2plane=False, vectorized=True, penalize_outside=True) + + if self.part_segm_fn: + # Read the part segmentation + part_segm_fn = os.path.expandvars(self.part_segm_fn) + with open(part_segm_fn, 'rb') as faces_parents_file: + face_segm_data = pickle.load(faces_parents_file, encoding='latin1') + faces_segm = face_segm_data['segm'] + faces_parents = face_segm_data['parents'] + # Create the module used to filter invalid collision pairs + filter_faces = FilterFaces( + faces_segm=faces_segm, faces_parents=faces_parents, + ign_part_pairs=None).to(device=self.device) + + + # Split SMPL pose to body pose and global orientation + body_pose = init_pose[:, 3:].detach().clone() + global_orient = init_pose[:, :3].detach().clone() + betas = init_betas.detach().clone() + + # use guess 3d to get the initial + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + init_cam_t = guess_init_3d(model_joints, j3d, self.joints_category).unsqueeze(1).detach() + camera_translation = init_cam_t.clone() + + preserve_pose = init_pose[:, 3:].detach().clone() + # -------------Step 1: Optimize camera translation and body orientation-------- + # Optimize only camera translation and body orientation + body_pose.requires_grad = False + betas.requires_grad = False + global_orient.requires_grad = True + camera_translation.requires_grad = True + + camera_opt_params = [global_orient, camera_translation] + + if self.use_lbfgs: + camera_optimizer = torch.optim.LBFGS(camera_opt_params, max_iter=self.num_iters, + lr=self.step_size, line_search_fn='strong_wolfe') + for i in range(10): + def closure(): + camera_optimizer.zero_grad() + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + loss = camera_fitting_loss_3d(model_joints, camera_translation, + init_cam_t, j3d, self.joints_category) + loss.backward() + return loss + + camera_optimizer.step(closure) + else: + camera_optimizer = torch.optim.Adam(camera_opt_params, lr=self.step_size, betas=(0.9, 0.999)) + + for i in range(20): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + loss = camera_fitting_loss_3d(model_joints[:, self.smpl_index], camera_translation, + init_cam_t, j3d[:, self.corr_index], self.joints_category) + camera_optimizer.zero_grad() + loss.backward() + camera_optimizer.step() + + # Fix camera translation after optimizing camera + # --------Step 2: Optimize body joints -------------------------- + # Optimize only the body pose and global orientation of the body + 
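# Step 2 below refines body pose, global orientation and camera translation jointly;
+ # the shape betas are optimized only for the first frame of a sequence (seq_ind == 0)
+ # and are kept fixed for the remaining frames.
+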
body_pose.requires_grad = True + global_orient.requires_grad = True + camera_translation.requires_grad = True + + # --- if we use the sequence, fix the shape + if seq_ind == 0: + betas.requires_grad = True + body_opt_params = [body_pose, betas, global_orient, camera_translation] + else: + betas.requires_grad = False + body_opt_params = [body_pose, global_orient, camera_translation] + + if self.use_lbfgs: + body_optimizer = torch.optim.LBFGS(body_opt_params, max_iter=self.num_iters, + lr=self.step_size, line_search_fn='strong_wolfe') + for i in range(self.num_iters): + def closure(): + body_optimizer.zero_grad() + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + pose_preserve_weight=5.0, + use_collision=self.use_collision, + model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces) + loss.backward() + return loss + + body_optimizer.step(closure) + else: + body_optimizer = torch.optim.Adam(body_opt_params, lr=self.step_size, betas=(0.9, 0.999)) + + for i in range(self.num_iters): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + use_collision=self.use_collision, + model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces) + body_optimizer.zero_grad() + loss.backward() + body_optimizer.step() + + # Get final loss value + with torch.no_grad(): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas, return_full_pose=True) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + final_loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + use_collision=self.use_collision, model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces) + + vertices = smpl_output.vertices.detach() + joints = smpl_output.joints.detach() + pose = torch.cat([global_orient, body_pose], dim=-1).detach() + betas = betas.detach() + + return vertices, joints, pose, betas, camera_translation, final_loss diff --git a/SMPLX/visualize_joint2smpl/simplify_loc2rot.py b/SMPLX/visualize_joint2smpl/simplify_loc2rot.py new file mode 100644 index 0000000000000000000000000000000000000000..360898b616cb0ee712fa7435013ff63fbffeffe3 --- /dev/null +++ b/SMPLX/visualize_joint2smpl/simplify_loc2rot.py @@ -0,0 +1,124 @@ +import numpy as np +import os +import torch +from SMPLX import smplx +import h5py +from SMPLX.visualize_joint2smpl.joints2smpl.src.smplify import SMPLify3D +from tqdm import tqdm +import argparse + + +class joints2smpl: + + def __init__(self, num_frames, device, model_path=None, json_dict=None): + self.smpl_dir = 
model_path + self.device = device + # self.device = torch.device("cpu") + self.batch_size = num_frames + self.num_joints = 22 # for HumanML3D + self.joint_category = "AMASS" + self.num_smplify_iters = 100 + self.fix_foot = False + smplmodel = smplx.create(self.smpl_dir, model_type="smpl", gender="neutral", ext="pkl", + batch_size=self.batch_size).to(self.device) + + # ## --- load the mean pose as original ---- + smpl_mean_file = os.path.join(json_dict["joints2smpl"], "neutral_smpl_mean_params.h5") + + file = h5py.File(smpl_mean_file, 'r') + self.init_mean_pose = torch.from_numpy(file['pose'][:]).unsqueeze(0).repeat(self.batch_size, 1).float().to(self.device) + self.init_mean_shape = torch.from_numpy(file['shape'][:]).unsqueeze(0).repeat(self.batch_size, 1).float().to(self.device) + self.cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(self.device) + # + + # # #-------------initialize SMPLify + self.smplify = SMPLify3D(smplxmodel=smplmodel, + batch_size=self.batch_size, + joints_category=self.joint_category, + num_iters=self.num_smplify_iters, + device=self.device) + + + def npy2smpl(self, npy_path): + out_path = npy_path.replace('.npy', '_rot.npy') + motions = np.load(npy_path, allow_pickle=True)[None][0] + # print_batch('', motions) + n_samples = motions['motion'].shape[0] + all_thetas = [] + for sample_i in tqdm(range(n_samples)): + thetas, _ = self.joint2smpl(motions['motion'][sample_i].transpose(2, 0, 1)) # [nframes, njoints, 3] + all_thetas.append(thetas.cpu().numpy()) + motions['motion'] = np.concatenate(all_thetas, axis=0) + print('motions', motions['motion'].shape) + + print(f'Saving [{out_path}]') + np.save(out_path, motions) + exit() + + + def joint2smpl(self, input_joints, init_params=None): + if len(input_joints.shape) == 2: + input_joints = input_joints.reshape(input_joints.shape[0], -1, 3) + + pred_pose = torch.zeros(self.batch_size, 72).to(self.device) + pred_betas = torch.zeros(self.batch_size, 10).to(self.device) + pred_cam_t = torch.zeros(self.batch_size, 3).to(self.device) + keypoints_3d = torch.zeros(self.batch_size, self.num_joints, 3).to(self.device) + + # joints3d = input_joints[idx] # *1.2 #scale problem [check first] + keypoints_3d = torch.Tensor(input_joints).to(self.device).float() + + root_loc = torch.tensor(keypoints_3d[:, 0:1]) #### N * 1 * 3 + root_loc = root_loc - root_loc[[0], :, :] ### N * 1 * 3 + root_loc = root_loc.squeeze(1).detach().cpu().numpy() + # if idx == 0: + if init_params is None: + pred_betas = self.init_mean_shape + pred_pose = self.init_mean_pose + pred_cam_t = self.cam_trans_zero + else: + pred_betas = init_params['betas'] + pred_pose = init_params['pose'] + pred_cam_t = init_params['cam'] + + if self.joint_category == "AMASS": + confidence_input = torch.ones(self.num_joints) + # make sure the foot and ankle + if self.fix_foot == True: + confidence_input[7] = 1.5 + confidence_input[8] = 1.5 + confidence_input[10] = 1.5 + confidence_input[11] = 1.5 + else: + print("Such category not settle down!") + + new_opt_vertices, new_opt_joints, new_opt_pose, new_opt_betas, \ + new_opt_cam_t, new_opt_joint_loss = self.smplify( + pred_pose.detach(), + pred_betas.detach(), + pred_cam_t.detach(), + keypoints_3d, + conf_3d=confidence_input.to(self.device), + # seq_ind=idx + ) + + thetas = new_opt_pose.reshape(self.batch_size, 24 * 3) + vecs = thetas.detach().cpu().numpy() + return vecs, root_loc + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--input_path", type=str, required=True, help='Blender 
file or dir with blender files') + parser.add_argument("--cuda", type=bool, default=True, help='') + parser.add_argument("--device", type=int, default=0, help='') + params = parser.parse_args() + + simplify = joints2smpl(device_id=params.device, cuda=params.cuda) + + if os.path.isfile(params.input_path) and params.input_path.endswith('.npy'): + simplify.npy2smpl(params.input_path) + elif os.path.isdir(params.input_path): + files = [os.path.join(params.input_path, f) for f in os.listdir(params.input_path) if f.endswith('.npy')] + for f in files: + simplify.npy2smpl(f) \ No newline at end of file diff --git a/TADA/__pycache__/anime.cpython-39.pyc b/TADA/__pycache__/anime.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a15021a592fa3b96d42bc90345b066646b3647d Binary files /dev/null and b/TADA/__pycache__/anime.cpython-39.pyc differ diff --git a/TADA/anime.py b/TADA/anime.py new file mode 100644 index 0000000000000000000000000000000000000000..956bf479dd4d6f4c994bb80fa2ba9af616d22cda --- /dev/null +++ b/TADA/anime.py @@ -0,0 +1,159 @@ +import os +import json +import pickle as pkl +import random +import argparse +import cv2 +import torch +from TADA import smplx +import imageio +import numpy as np +from tqdm import tqdm +from PIL import Image +from TADA.lib.common.remesh import subdivide_inorder +from TADA.lib.common.utils import SMPLXSeg +from TADA.lib.common.lbs import warp_points +from TADA.lib.common.obj import compute_normal +import trimesh +import pyrender +from shapely import geometry +import moviepy.editor as mpy +os.environ['PYOPENGL_PLATFORM'] = "egl" + +def build_new_mesh(v, f, vt, ft): + # build a correspondences dictionary from the original mesh indices to the (possibly multiple) texture map indices + f_flat = f.flatten() + ft_flat = ft.flatten() + correspondences = {} + + # traverse and find the corresponding indices in f and ft + for i in range(len(f_flat)): + if f_flat[i] not in correspondences: + correspondences[f_flat[i]] = [ft_flat[i]] + else: + if ft_flat[i] not in correspondences[f_flat[i]]: + correspondences[f_flat[i]].append(ft_flat[i]) + + # build a mesh using the texture map vertices + new_v = np.zeros((v.shape[0], vt.shape[0], 3)) + for old_index, new_indices in correspondences.items(): + for new_index in new_indices: + new_v[:, new_index] = v[:, old_index] + + # define new faces using the texture map faces + f_new = ft + return new_v, f_new + +class Animation: + def __init__(self, ckpt_path, workspace_dir, device="cuda"): + self.device = device + self.SMPLXSeg = SMPLXSeg(workspace_dir) + # load data + init_data = np.load(os.path.join(workspace_dir, "init_body/data.npz")) + self.dense_faces = torch.as_tensor(init_data['dense_faces'], device=self.device) + self.dense_lbs_weights = torch.as_tensor(init_data['dense_lbs_weights'], device=self.device) + self.unique = init_data['unique'] + self.vt = init_data['vt'] + self.ft = init_data['ft'] + + model_params = dict( + model_path=os.path.join(workspace_dir, "smplx/SMPLX_NEUTRAL_2020.npz"), + model_type='smplx', + create_global_orient=True, + create_body_pose=True, + create_betas=True, + create_left_hand_pose=True, + create_right_hand_pose=True, + create_jaw_pose=True, + create_leye_pose=True, + create_reye_pose=True, + create_expression=True, + create_transl=False, + use_pca=False, + flat_hand_mean=False, + num_betas=300, + num_expression_coeffs=100, + num_pca_comps=12, + dtype=torch.float32, + batch_size=1, + ) + self.body_model = smplx.create(**model_params).to(device='cuda') + 
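# A minimal usage sketch for the joints2smpl fitter defined above, assuming the repo root
# is on PYTHONPATH; the .npy path, model_path and json_dict values below are placeholders.
import numpy as np
import torch
from SMPLX.visualize_joint2smpl.simplify_loc2rot import joints2smpl

joints = np.load("sample_motion.npy")  # hypothetical array of shape [nframes, 22, 3]
fitter = joints2smpl(num_frames=joints.shape[0],
                     device=torch.device("cuda:0"),
                     model_path="body_models",  # assumed dir holding the neutral SMPL .pkl
                     json_dict={"joints2smpl": "SMPLX/visualize_joint2smpl/joints2smpl/smpl_models"})
thetas, root_loc = fitter.joint2smpl(joints)  # thetas: [nframes, 72] axis-angle, root_loc: [nframes, 3]
+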
self.smplx_face = self.body_model.faces.astype(np.int32) + + ckpt_file = os.path.join(workspace_dir, "MESH", ckpt_path, "params.pt") + albedo_path = os.path.join(workspace_dir, "MESH", ckpt_path, "mesh_albedo.png") + self.load_ckpt_data(ckpt_file, albedo_path) + + + def load_ckpt_data(self, ckpt_file, albedo_path): + model_data = torch.load(ckpt_file) + self.expression = model_data["expression"] if "expression" in model_data else None + self.jaw_pose = model_data["jaw_pose"] if "jaw_pose" in model_data else None + + self.betas = model_data['betas'] + self.v_offsets = model_data['v_offsets'] + self.v_offsets[self.SMPLXSeg.eyeball_ids] = 0. + self.v_offsets[self.SMPLXSeg.hands_ids] = 0. + + # tex to trimesh texture + vt = self.vt.copy() + vt[:, 1] = 1 - vt[:, 1] + albedo = Image.open(albedo_path) + + self.raw_albedo = torch.from_numpy(np.array(albedo)) + self.raw_albedo = self.raw_albedo / 255.0 + self.raw_albedo = self.raw_albedo.permute(2, 0, 1) + + self.trimesh_visual = trimesh.visual.TextureVisuals( + uv=vt, + image=albedo, + material=trimesh.visual.texture.SimpleMaterial( + image=albedo, + diffuse=[255, 255, 255, 255], + ambient=[255, 255, 255, 255], + specular=[0, 0, 0, 255], + glossiness=0) + ) + + def forward_mdm(self, motion): + try: + mdm_body_pose = motion["poses"] + translate = torch.from_numpy(motion["trans"]) + except: + translate = torch.from_numpy(motion[:, -3:]) + mdm_body_pose = motion[:, :-3] + mdm_body_pose = mdm_body_pose.reshape(mdm_body_pose.shape[0], -1, 3) + + translate = translate.to(self.device) + scan_v_posed = [] + for i, (pose, t) in tqdm(enumerate(zip(mdm_body_pose, translate))): + body_pose = torch.as_tensor(pose[None, 1:22, :], device=self.device) + global_orient = torch.as_tensor(pose[None, :1, :], device=self.device) + output = self.body_model( + betas=self.betas, + global_orient=global_orient, + jaw_pose=self.jaw_pose, + body_pose=body_pose, + expression=self.expression, + return_verts=True + ) + + v_cano = output.v_posed[0] + # re-mesh + v_cano_dense = subdivide_inorder(v_cano, self.smplx_face[self.SMPLXSeg.remesh_mask], self.unique).squeeze(0) + # add offsets + vn = compute_normal(v_cano_dense, self.dense_faces)[0] + v_cano_dense += self.v_offsets * vn + # do LBS + v_posed_dense = warp_points(v_cano_dense, self.dense_lbs_weights, output.joints_transform[:, :55]) + # translate + v_posed_dense += t - translate[0] + + scan_v_posed.append(v_posed_dense) + + scan_v_posed = torch.cat(scan_v_posed).detach().cpu().numpy() + new_scan_v_posed, new_face = build_new_mesh(scan_v_posed, self.dense_faces, self.vt, self.ft) + new_scan_v_posed = new_scan_v_posed.astype(np.float32) + + return new_scan_v_posed, new_face + diff --git a/TADA/lib/__pycache__/__init__.cpython-310.pyc b/TADA/lib/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d9de62c4c6aaeaf7bc50466149951a8a222f676 Binary files /dev/null and b/TADA/lib/__pycache__/__init__.cpython-310.pyc differ diff --git a/TADA/lib/__pycache__/__init__.cpython-38.pyc b/TADA/lib/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..315dcf2e32cfd74b198cab520006809395c34630 Binary files /dev/null and b/TADA/lib/__pycache__/__init__.cpython-38.pyc differ diff --git a/TADA/lib/__pycache__/__init__.cpython-39.pyc b/TADA/lib/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33a4cadc8bf856b6f4e70b555050b5c13d2d92d8 Binary files /dev/null and 
b/TADA/lib/__pycache__/__init__.cpython-39.pyc differ diff --git a/TADA/lib/__pycache__/dlmesh.cpython-310.pyc b/TADA/lib/__pycache__/dlmesh.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1c98a96cd6acb4438c30c8379faa601a7055e57 Binary files /dev/null and b/TADA/lib/__pycache__/dlmesh.cpython-310.pyc differ diff --git a/TADA/lib/__pycache__/dlmesh.cpython-38.pyc b/TADA/lib/__pycache__/dlmesh.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b0fed97c261fac69381653bfc0564ce47fde712 Binary files /dev/null and b/TADA/lib/__pycache__/dlmesh.cpython-38.pyc differ diff --git a/TADA/lib/__pycache__/dlmesh.cpython-39.pyc b/TADA/lib/__pycache__/dlmesh.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dca6221ebc7e1841918025c090a752ed4d11b5b0 Binary files /dev/null and b/TADA/lib/__pycache__/dlmesh.cpython-39.pyc differ diff --git a/TADA/lib/__pycache__/dpt.cpython-39.pyc b/TADA/lib/__pycache__/dpt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb4c6fddf23da3cdac0d172b9ab49c58d66d457c Binary files /dev/null and b/TADA/lib/__pycache__/dpt.cpython-39.pyc differ diff --git a/TADA/lib/__pycache__/encoding.cpython-310.pyc b/TADA/lib/__pycache__/encoding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f8d26cc705dbeec3fbe10e433e9078e86185310 Binary files /dev/null and b/TADA/lib/__pycache__/encoding.cpython-310.pyc differ diff --git a/TADA/lib/__pycache__/encoding.cpython-38.pyc b/TADA/lib/__pycache__/encoding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a98b6f740cd29d4b73d859ebeb426543108df947 Binary files /dev/null and b/TADA/lib/__pycache__/encoding.cpython-38.pyc differ diff --git a/TADA/lib/__pycache__/encoding.cpython-39.pyc b/TADA/lib/__pycache__/encoding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b3156fe1b3175f20cd11dc0b9d0bf337c890d24 Binary files /dev/null and b/TADA/lib/__pycache__/encoding.cpython-39.pyc differ diff --git a/TADA/lib/__pycache__/provider.cpython-310.pyc b/TADA/lib/__pycache__/provider.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..025779722cc7ae1a77b2966bb1b261c98152d09c Binary files /dev/null and b/TADA/lib/__pycache__/provider.cpython-310.pyc differ diff --git a/TADA/lib/__pycache__/provider.cpython-38.pyc b/TADA/lib/__pycache__/provider.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a956c1b31376e586b5384b590e5d2b7f2dda4a91 Binary files /dev/null and b/TADA/lib/__pycache__/provider.cpython-38.pyc differ diff --git a/TADA/lib/__pycache__/provider.cpython-39.pyc b/TADA/lib/__pycache__/provider.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a61c4a9a7dab87b1c33c415ca70aa8d549c75f47 Binary files /dev/null and b/TADA/lib/__pycache__/provider.cpython-39.pyc differ diff --git a/TADA/lib/__pycache__/trainer.cpython-310.pyc b/TADA/lib/__pycache__/trainer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8aa9d88dd1a2d6e475f2a4943b1dde0cd9e79990 Binary files /dev/null and b/TADA/lib/__pycache__/trainer.cpython-310.pyc differ diff --git a/TADA/lib/__pycache__/trainer.cpython-38.pyc b/TADA/lib/__pycache__/trainer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..793bec01e67794a00cc23933d2cb8d82901b1380 Binary files /dev/null and 
b/TADA/lib/__pycache__/trainer.cpython-38.pyc differ diff --git a/TADA/lib/__pycache__/trainer.cpython-39.pyc b/TADA/lib/__pycache__/trainer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f00f67b43ec0cb726e2a57caeb08fc39ee4257a6 Binary files /dev/null and b/TADA/lib/__pycache__/trainer.cpython-39.pyc differ diff --git a/TADA/lib/common/__pycache__/lbs.cpython-310.pyc b/TADA/lib/common/__pycache__/lbs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84701fb54e754e051ffb2e2e76b14c3d5af44e41 Binary files /dev/null and b/TADA/lib/common/__pycache__/lbs.cpython-310.pyc differ diff --git a/TADA/lib/common/__pycache__/lbs.cpython-38.pyc b/TADA/lib/common/__pycache__/lbs.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ff2a5b14faa493197e154197e6634da5b887b4d Binary files /dev/null and b/TADA/lib/common/__pycache__/lbs.cpython-38.pyc differ diff --git a/TADA/lib/common/__pycache__/lbs.cpython-39.pyc b/TADA/lib/common/__pycache__/lbs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbe42616dc01b48c861f3bd0d47c82f6272a81b9 Binary files /dev/null and b/TADA/lib/common/__pycache__/lbs.cpython-39.pyc differ diff --git a/TADA/lib/common/__pycache__/obj.cpython-310.pyc b/TADA/lib/common/__pycache__/obj.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4391f28df540ba4726cb1043249118d7d755133 Binary files /dev/null and b/TADA/lib/common/__pycache__/obj.cpython-310.pyc differ diff --git a/TADA/lib/common/__pycache__/obj.cpython-38.pyc b/TADA/lib/common/__pycache__/obj.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28be1fc34ee83d62bb32ef4f2d5aafaebe94ceb7 Binary files /dev/null and b/TADA/lib/common/__pycache__/obj.cpython-38.pyc differ diff --git a/TADA/lib/common/__pycache__/obj.cpython-39.pyc b/TADA/lib/common/__pycache__/obj.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fa2666c5056c2daa00e1c7cbac968f28014556a Binary files /dev/null and b/TADA/lib/common/__pycache__/obj.cpython-39.pyc differ diff --git a/TADA/lib/common/__pycache__/optimizer.cpython-310.pyc b/TADA/lib/common/__pycache__/optimizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..569a36b25dbd3a1ec48abbac1c98ec9a17747d09 Binary files /dev/null and b/TADA/lib/common/__pycache__/optimizer.cpython-310.pyc differ diff --git a/TADA/lib/common/__pycache__/optimizer.cpython-38.pyc b/TADA/lib/common/__pycache__/optimizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93c731b1776ecdd82a3f952d5f6a701f298a99c0 Binary files /dev/null and b/TADA/lib/common/__pycache__/optimizer.cpython-38.pyc differ diff --git a/TADA/lib/common/__pycache__/optimizer.cpython-39.pyc b/TADA/lib/common/__pycache__/optimizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6121d24941b8fd13d4b31d291ad9b7141a03493 Binary files /dev/null and b/TADA/lib/common/__pycache__/optimizer.cpython-39.pyc differ diff --git a/TADA/lib/common/__pycache__/remesh.cpython-310.pyc b/TADA/lib/common/__pycache__/remesh.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90c6606c4d39b39470113909a6171de6b4f7eb22 Binary files /dev/null and b/TADA/lib/common/__pycache__/remesh.cpython-310.pyc differ diff --git a/TADA/lib/common/__pycache__/remesh.cpython-38.pyc 
b/TADA/lib/common/__pycache__/remesh.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df727427d1b473386d1221e63b36a5466a8792c0 Binary files /dev/null and b/TADA/lib/common/__pycache__/remesh.cpython-38.pyc differ diff --git a/TADA/lib/common/__pycache__/remesh.cpython-39.pyc b/TADA/lib/common/__pycache__/remesh.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc2c9ac766c699ffcfd3cfa9141cd0b3f8aa5a74 Binary files /dev/null and b/TADA/lib/common/__pycache__/remesh.cpython-39.pyc differ diff --git a/TADA/lib/common/__pycache__/renderer.cpython-310.pyc b/TADA/lib/common/__pycache__/renderer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ec8e247ff320b6dc37359b50a2b3c270c890ef0 Binary files /dev/null and b/TADA/lib/common/__pycache__/renderer.cpython-310.pyc differ diff --git a/TADA/lib/common/__pycache__/renderer.cpython-38.pyc b/TADA/lib/common/__pycache__/renderer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cb95c70480e375d3817f9e8d859219bc3ccfefa Binary files /dev/null and b/TADA/lib/common/__pycache__/renderer.cpython-38.pyc differ diff --git a/TADA/lib/common/__pycache__/renderer.cpython-39.pyc b/TADA/lib/common/__pycache__/renderer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f155b363f7735da245cc33a15c0cb2b828d0b34f Binary files /dev/null and b/TADA/lib/common/__pycache__/renderer.cpython-39.pyc differ diff --git a/TADA/lib/common/__pycache__/utils.cpython-310.pyc b/TADA/lib/common/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b3bc92a29c757ad5aec0ff3d6a629b994eb469e Binary files /dev/null and b/TADA/lib/common/__pycache__/utils.cpython-310.pyc differ diff --git a/TADA/lib/common/__pycache__/utils.cpython-38.pyc b/TADA/lib/common/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c1ec8ccfe1ed9c070986721202974f100d5b174 Binary files /dev/null and b/TADA/lib/common/__pycache__/utils.cpython-38.pyc differ diff --git a/TADA/lib/common/__pycache__/utils.cpython-39.pyc b/TADA/lib/common/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdd068a53efa7d06ed9829a7e3179a24d14e3edf Binary files /dev/null and b/TADA/lib/common/__pycache__/utils.cpython-39.pyc differ diff --git a/TADA/lib/common/__pycache__/visual.cpython-310.pyc b/TADA/lib/common/__pycache__/visual.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..564cd2379afdf833feb7e70d315e19ccbb49a60a Binary files /dev/null and b/TADA/lib/common/__pycache__/visual.cpython-310.pyc differ diff --git a/TADA/lib/common/__pycache__/visual.cpython-38.pyc b/TADA/lib/common/__pycache__/visual.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4f28846ba1494e9a7c6931b577895f334f57aa2 Binary files /dev/null and b/TADA/lib/common/__pycache__/visual.cpython-38.pyc differ diff --git a/TADA/lib/common/__pycache__/visual.cpython-39.pyc b/TADA/lib/common/__pycache__/visual.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c53f3b4f3749d8b37fed62a2c55bc0e935344d27 Binary files /dev/null and b/TADA/lib/common/__pycache__/visual.cpython-39.pyc differ diff --git a/TADA/lib/common/lbs.py b/TADA/lib/common/lbs.py new file mode 100644 index 0000000000000000000000000000000000000000..fb0e59b305ded03a9adda7daf334263da5bab8c2 --- 
/dev/null +++ b/TADA/lib/common/lbs.py @@ -0,0 +1,64 @@ +import torch + +def linear_blend_skinning(points, weight, joint_transform, return_vT=False, inverse=False): + """ + Args: + points: FloatTensor [batch, N, 3] + weight: FloatTensor [batch, N, K] + joint_transform: FloatTensor [batch, K, 4, 4] + return_vT: return vertex transform matrix if true + inverse: bool inverse LBS if true + Return: + points_deformed: FloatTensor [batch, N, 3] + """ + if not weight.shape[0] == joint_transform.shape[0]: + raise AssertionError('batch should be same,', weight.shape, joint_transform.shape) + + if not torch.is_tensor(points): + points = torch.as_tensor(points).float() + if not torch.is_tensor(weight): + weight = torch.as_tensor(weight).float() + if not torch.is_tensor(joint_transform): + joint_transform = torch.as_tensor(joint_transform).float() + + batch = joint_transform.size(0) + vT = torch.bmm(weight, joint_transform.contiguous().view(batch, -1, 16)).view(batch, -1, 4, 4) + if inverse: + vT = torch.inverse(vT.view(-1, 4, 4)).view(batch, -1, 4, 4) + + R, T = vT[:, :, :3, :3], vT[:, :, :3, 3] + deformed_points = torch.matmul(R, points.unsqueeze(-1)).squeeze(-1) + T + + if return_vT: + return deformed_points, vT + return deformed_points + + + +def warp_points(points, skin_weights, joint_transform, inverse=False): + """ + Warp a canonical point cloud to multiple posed spaces and project to image space + Args: + points: [N, 3] Tensor of 3D points + skin_weights: [N, J] corresponding skinning weights of points + joint_transform: [B, J, 4, 4] joint transform matrix of a batch of poses + Returns: + posed_points [B, N, 3] warpped points in posed space + """ + + if not torch.is_tensor(points): + points = torch.as_tensor(points).float() + if not torch.is_tensor(joint_transform): + joint_transform = torch.as_tensor(joint_transform).float() + if not torch.is_tensor(skin_weights): + skin_weights = torch.as_tensor(skin_weights).float() + + batch = joint_transform.shape[0] + if points.dim() == 2: + points = points.expand(batch, -1, -1) + # warping + points_posed, vT = linear_blend_skinning(points, + skin_weights.expand(batch, -1, -1), + joint_transform, return_vT=True, inverse=inverse) + + return points_posed diff --git a/TADA/lib/common/mesh_utils.py b/TADA/lib/common/mesh_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..87fd2f86d6e6faac257fd894889d2da26d77ce0d --- /dev/null +++ b/TADA/lib/common/mesh_utils.py @@ -0,0 +1,559 @@ +import torch +import kaolin as kal +import numpy as np +from pathlib import Path +import numpy as np +import pymeshlab as pml + +if torch.cuda.is_available(): + device = torch.device("cuda:0") + torch.cuda.set_device(device) +else: + device = torch.device("cpu") + + +def get_camera_from_view(elev, azim, r=3.0): + x = r * torch.cos(azim) * torch.sin(elev) + y = r * torch.sin(azim) * torch.sin(elev) + z = r * torch.cos(elev) + # print(elev,azim,x,y,z) + + pos = torch.tensor([x, y, z]).unsqueeze(0) + look_at = -pos + direction = torch.tensor([0.0, 1.0, 0.0]).unsqueeze(0) + + camera_proj = kal.render.camera.generate_transformation_matrix(pos, look_at, direction) + return camera_proj + + +def get_camera_from_view2(elev, azim, r=3.0): + x = r * torch.cos(elev) * torch.cos(azim) + y = r * torch.sin(elev) + z = r * torch.cos(elev) * torch.sin(azim) + # print(elev,azim,x,y,z) + + pos = torch.tensor([x, y, z]).unsqueeze(0) + look_at = -pos + direction = torch.tensor([0.0, 1.0, 0.0]).unsqueeze(0) + + camera_proj = kal.render.camera.generate_transformation_matrix(pos, 
look_at, direction) + return camera_proj + + +def get_homogenous_coordinates(V): + N, D = V.shape + bottom = torch.ones(N, device=device).unsqueeze(1) + return torch.cat([V, bottom], dim=1) + + +def apply_affine(verts, A): + verts = verts.to(device) + verts = get_homogenous_coordinates(verts) + A = torch.cat([A, torch.tensor([0.0, 0.0, 0.0, 1.0], device=device).unsqueeze(0)], dim=0) + transformed_verts = A @ verts.T + transformed_verts = transformed_verts[:-1] + return transformed_verts.T + + +def standardize_mesh(mesh): + verts = mesh.vertices + center = verts.mean(dim=0) + verts -= center + scale = torch.std(torch.norm(verts, p=2, dim=1)) + verts /= scale + mesh.vertices = verts + return mesh + + +def normalize_mesh(mesh): + verts = mesh.vertices + + # Compute center of bounding box + # center = torch.mean(torch.column_stack([torch.max(verts, dim=0)[0], torch.min(verts, dim=0)[0]])) + center = verts.mean(dim=0) + verts = verts - center + scale = torch.max(torch.norm(verts, p=2, dim=1)) + verts = verts / scale + mesh.vertices = verts + return mesh + + +def get_texture_map_from_color(mesh, color, H=224, W=224): + num_faces = mesh.faces.shape[0] + texture_map = torch.zeros(1, H, W, 3).to(device) + texture_map[:, :, :] = color + return texture_map.permute(0, 3, 1, 2) + + +def get_face_attributes_from_color(mesh, color): + num_faces = mesh.faces.shape[0] + face_attributes = torch.zeros(1, num_faces, 3, 3).to(device) + face_attributes[:, :, :] = color + return face_attributes + + +def sample_bary(faces, vertices): + num_faces = faces.shape[0] + num_vertices = vertices.shape[0] + + # get random barycentric for each face TODO: improve sampling + A = torch.randn(num_faces) + B = torch.randn(num_faces) * (1 - A) + C = 1 - (A + B) + bary = torch.vstack([A, B, C]).to(device) + + # compute xyz of new vertices and new uvs (if mesh has them) + new_vertices = torch.zeros(num_faces, 3).to(device) + new_uvs = torch.zeros(num_faces, 2).to(device) + face_verts = kal.ops.mesh.index_vertices_by_faces(vertices.unsqueeze(0), faces) + for f in range(num_faces): + new_vertices[f] = bary[:, f] @ face_verts[:, f] + new_vertices = torch.cat([vertices, new_vertices]) + return new_vertices + + +def _add_vertices(mesh): + faces = torch.as_tensor(mesh.faces) + vertices = torch.as_tensor(mesh.vertices) + num_faces = faces.shape[0] + num_vertices = vertices.shape[0] + + # get random barycentric for each face TODO: improve sampling + A = torch.randn(num_faces) + B = torch.randn(num_faces) * (1 - A) + C = 1 - (A + B) + bary = torch.vstack([A, B, C]).to(device) + + # compute xyz of new vertices and new uvs (if mesh has them) + new_vertices = torch.zeros(num_faces, 3).to(device) + new_uvs = torch.zeros(num_faces, 2).to(device) + face_verts = kal.ops.mesh.index_vertices_by_faces(vertices.unsqueeze(0), faces) + face_uvs = mesh.face_uvs + for f in range(num_faces): + new_vertices[f] = bary[:, f] @ face_verts[:, f] + if face_uvs is not None: + new_uvs[f] = bary[:, f] @ face_uvs[:, f] + + # update face and face_uvs of mesh + new_vertices = torch.cat([vertices, new_vertices]) + new_faces = [] + new_face_uvs = [] + new_vertex_normals = [] + for i in range(num_faces): + old_face = faces[i] + a, b, c = old_face[0], old_face[1], old_face[2] + d = num_vertices + i + new_faces.append(torch.tensor([a, b, d]).to(device)) + new_faces.append(torch.tensor([a, d, c]).to(device)) + new_faces.append(torch.tensor([d, b, c]).to(device)) + if face_uvs is not None: + old_face_uvs = face_uvs[0, i] + a, b, c = old_face_uvs[0], old_face_uvs[1], 
old_face_uvs[2] + d = new_uvs[i] + new_face_uvs.append(torch.vstack([a, b, d])) + new_face_uvs.append(torch.vstack([a, d, c])) + new_face_uvs.append(torch.vstack([d, b, c])) + if mesh.face_normals is not None: + new_vertex_normals.append(mesh.face_normals[i]) + else: + e1 = vertices[b] - vertices[a] + e2 = vertices[c] - vertices[a] + norm = torch.cross(e1, e2) + norm /= torch.norm(norm) + + # Double check sign against existing vertex normals + if torch.dot(norm, mesh.vertex_normals[a]) < 0: + norm = -norm + + new_vertex_normals.append(norm) + + vertex_normals = torch.cat([mesh.vertex_normals, torch.stack(new_vertex_normals)]) + + if face_uvs is not None: + new_face_uvs = torch.vstack(new_face_uvs).unsqueeze(0).view(1, 3 * num_faces, 3, 2) + new_faces = torch.vstack(new_faces) + + return new_vertices, new_faces, vertex_normals, new_face_uvs + + + +def get_rgb_per_vertex(vertices, faces, face_rgbs): + num_vertex = vertices.shape[0] + num_faces = faces.shape[0] + vertex_color = torch.zeros(num_vertex, 3) + + for v in range(num_vertex): + for f in range(num_faces): + face = num_faces[f] + if v in face: + vertex_color[v] = face_rgbs[f] + return face_rgbs + + +def get_barycentric(p, faces): + # faces num_points x 3 x 3 + # p num_points x 3 + # source: https://gamedev.stackexchange.com/questions/23743/whats-the-most-efficient-way-to-find-barycentric-coordinates + + a, b, c = faces[:, 0], faces[:, 1], faces[:, 2] + + v0, v1, v2 = b - a, c - a, p - a + d00 = torch.sum(v0 * v0, dim=1) + d01 = torch.sum(v0 * v1, dim=1) + d11 = torch.sum(v1 * v1, dim=1) + d20 = torch.sum(v2 * v0, dim=1) + d21 = torch.sum(v2 * v1, dim=1) + denom = d00 * d11 - d01 * d01 + v = (d11 * d20 - d01 * d21) / denom + w = (d00 * d21 - d01 * d20) / denom + u = 1 - (w + v) + + return torch.vstack([u, v, w]).T + + +def get_uv_assignment(num_faces): + M = int(np.ceil(np.sqrt(num_faces))) + uv_map = torch.zeros(1, num_faces, 3, 2).to(device) + px, py = 0, 0 + count = 0 + for i in range(M): + px = 0 + for j in range(M): + uv_map[:, count] = torch.tensor([[px, py], + [px + 1, py], + [px + 1, py + 1]]) + px += 2 + count += 1 + if count >= num_faces: + hw = torch.max(uv_map.view(-1, 2), dim=0)[0] + uv_map = (uv_map - hw / 2.0) / (hw / 2) + return uv_map + py += 2 + + +def get_texture_visual(res, nt, mesh): + faces_vt = kal.ops.mesh.index_vertices_by_faces(mesh.vertices.unsqueeze(0), mesh.faces).squeeze(0) + + # as to not include encpoint, gen res+1 points and take first res + uv = torch.cartesian_prod(torch.linspace(-1, 1, res + 1)[:-1], torch.linspace(-1, 1, res + 1))[:-1].to(device) + image = torch.zeros(res, res, 3).to(device) + # image[:,:,:] = torch.tensor([0.0,1.0,0.0]).to(device) + image = image.permute(2, 0, 1) + num_faces = mesh.faces.shape[0] + uv_map = get_uv_assignment(num_faces).squeeze(0) + + zero = torch.tensor([0.0, 0.0, 0.0]).to(device) + one = torch.tensor([1.0, 1.0, 1.0]).to(device) + + for face in range(num_faces): + bary = get_barycentric(uv, uv_map[face].repeat(len(uv), 1, 1)) + + maskA = torch.logical_and(bary[:, 0] >= 0.0, bary[:, 0] <= 1.0) + maskB = torch.logical_and(bary[:, 1] >= 0.0, bary[:, 1] <= 1.0) + maskC = torch.logical_and(bary[:, 2] >= 0.0, bary[:, 2] <= 1.0) + + mask = torch.logical_and(maskA, maskB) + mask = torch.logical_and(maskC, mask) + + inside_triangle = bary[mask] + inside_triangle_uv = inside_triangle @ uv_map[face] + inside_triangle_xyz = inside_triangle @ faces_vt[face] + inside_triangle_rgb = nt(inside_triangle_xyz) + + pixels = (inside_triangle_uv + 1.0) / 2.0 + pixels = pixels * res + 
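# uv coordinates in [-1, 1] have been mapped to continuous pixel coordinates in [0, res);
+ # the floor/int64 cast below turns them into integer texel indices for the image write.
+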
pixels = torch.floor(pixels).type(torch.int64) + + image[:, pixels[:, 0], pixels[:, 1]] = inside_triangle_rgb.T + + return image + + +# Get rotation matrix about vector through origin +def getRotMat(axis, theta): + """ + axis: np.array, normalized vector + theta: radians + """ + import math + + axis = axis / np.linalg.norm(axis) + cprod = np.array([[0, -axis[2], axis[1]], + [axis[2], 0, -axis[0]], + [-axis[1], axis[0], 0]]) + rot = math.cos(theta) * np.identity(3) + math.sin(theta) * cprod + \ + (1 - math.cos(theta)) * np.outer(axis, axis) + return rot + + +# Map vertices and subset of faces to 0-indexed vertices, keeping only relevant vertices +def trimMesh(vertices, faces): + unique_v = np.sort(np.unique(faces.flatten())) + v_val = np.arange(len(unique_v)) + v_map = dict(zip(unique_v, v_val)) + new_faces = np.array([v_map[i] for i in faces.flatten()]).reshape(faces.shape[0], faces.shape[1]) + new_v = vertices[unique_v] + + return new_v, new_faces + + +# ================== VISUALIZATION ======================= +# Back out camera parameters from view transform matrix +def extract_from_gl_viewmat(gl_mat): + gl_mat = gl_mat.reshape(4, 4) + s = gl_mat[0, :3] + u = gl_mat[1, :3] + f = -1 * gl_mat[2, :3] + coord = gl_mat[:3, 3] # first 3 entries of the last column + camera_location = np.array([-s, -u, f]).T @ coord + target = camera_location + f * 10 # any scale + return camera_location, target + + +def psScreenshot(vertices, faces, axis, angles, save_path, name="mesh", frame_folder="frames", scalars=None, + colors=None, + defined_on="faces", highlight_faces=None, highlight_color=[1, 0, 0], highlight_radius=None, + cmap=None, sminmax=None, cpos=None, clook=None, save_video=False, save_base=False, + ground_plane="tile_reflection", debug=False, edge_color=[0, 0, 0], edge_width=1, material=None): + import polyscope as ps + + ps.init() + # Set camera to look at same fixed position in centroid of original mesh + # center = np.mean(vertices, axis = 0) + # pos = center + np.array([0, 0, 3]) + # ps.look_at(pos, center) + ps.set_ground_plane_mode(ground_plane) + + frame_path = f"{save_path}/{frame_folder}" + if save_base == True: + ps_mesh = ps.register_surface_mesh("mesh", vertices, faces, enabled=True, + edge_color=edge_color, edge_width=edge_width, material=material) + ps.screenshot(f"{frame_path}/{name}.png") + ps.remove_all_structures() + Path(frame_path).mkdir(parents=True, exist_ok=True) + # Convert 2D to 3D by appending Z-axis + if vertices.shape[1] == 2: + vertices = np.concatenate((vertices, np.zeros((len(vertices), 1))), axis=1) + + for i in range(len(angles)): + rot = getRotMat(axis, angles[i]) + rot_verts = np.transpose(rot @ np.transpose(vertices)) + + ps_mesh = ps.register_surface_mesh("mesh", rot_verts, faces, enabled=True, + edge_color=edge_color, edge_width=edge_width, material=material) + if scalars is not None: + ps_mesh.add_scalar_quantity(f"scalar", scalars, defined_on=defined_on, + cmap=cmap, enabled=True, vminmax=sminmax) + if colors is not None: + ps_mesh.add_color_quantity(f"color", colors, defined_on=defined_on, + enabled=True) + if highlight_faces is not None: + # Create curve to highlight faces + curve_v, new_f = trimMesh(rot_verts, faces[highlight_faces, :]) + curve_edges = [] + for face in new_f: + curve_edges.extend( + [[face[0], face[1]], [face[1], face[2]], [face[2], face[0]]]) + curve_edges = np.array(curve_edges) + ps_curve = ps.register_curve_network("curve", curve_v, curve_edges, color=highlight_color, + radius=highlight_radius) + + if cpos is None or clook is None: 
+ ps.reset_camera_to_home_view() + else: + ps.look_at(cpos, clook) + + if debug == True: + ps.show() + ps.screenshot(f"{frame_path}/{name}_{i}.png") + ps.remove_all_structures() + if save_video == True: + import glob + from PIL import Image + fp_in = f"{frame_path}/{name}_*.png" + fp_out = f"{save_path}/{name}.gif" + img, *imgs = [Image.open(f) for f in sorted(glob.glob(fp_in))] + img.save(fp=fp_out, format='GIF', append_images=imgs, + save_all=True, duration=200, loop=0) + + +# ================== POSITIONAL ENCODERS ============================= +class FourierFeatureTransform(torch.nn.Module): + """ + An implementation of Gaussian Fourier feature mapping. + "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains": + https://arxiv.org/abs/2006.10739 + https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html + Given an input of size [batches, num_input_channels, width, height], + returns a tensor of size [batches, mapping_size*2, width, height]. + """ + + def __init__(self, num_input_channels, mapping_size=256, scale=10, exclude=0): + super().__init__() + + self._num_input_channels = num_input_channels + self._mapping_size = mapping_size + self.exclude = exclude + B = torch.randn((num_input_channels, mapping_size)) * scale + B_sort = sorted(B, key=lambda x: torch.norm(x, p=2)) + self._B = torch.stack(B_sort) # for sape + + def forward(self, x): + # assert x.dim() == 4, 'Expected 4D input (got {}D input)'.format(x.dim()) + + batches, channels = x.shape + + assert channels == self._num_input_channels, \ + "Expected input to have {} channels (got {} channels)".format(self._num_input_channels, channels) + + # Make shape compatible for matmul with _B. + # From [B, C, W, H] to [(B*W*H), C]. + # x = x.permute(0, 2, 3, 1).reshape(batches * width * height, channels) + + res = x @ self._B.to(x.device) + + # From [(B*W*H), C] to [B, W, H, C] + # x = x.view(batches, width, height, self._mapping_size) + # From [B, W, H, C] to [B, C, W, H] + # x = x.permute(0, 3, 1, 2) + + res = 2 * np.pi * res + return torch.cat([x, torch.sin(res), torch.cos(res)], dim=1) + + +def poisson_mesh_reconstruction(points, normals=None): + # points/normals: [N, 3] np.ndarray + + import open3d as o3d + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + + # outlier removal + pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10) + + # normals + if normals is None: + pcd.estimate_normals() + else: + pcd.normals = o3d.utility.Vector3dVector(normals[ind]) + + # visualize + o3d.visualization.draw_geometries([pcd], point_show_normal=False) + + mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=9) + # vertices_to_remove = densities < np.quantile(densities, 0.1) + # mesh.remove_vertices_by_mask(vertices_to_remove) + + # visualize + o3d.visualization.draw_geometries([mesh]) + + vertices = np.asarray(mesh.vertices) + triangles = np.asarray(mesh.triangles) + + print(f'[INFO] poisson mesh reconstruction: {points.shape} --> V: {vertices.shape} / F:{triangles.shape}') + + return vertices, triangles + + +def decimate_mesh(verts, faces, target, backend='pymeshlab', remesh=False, optimalplacement=True): + # optimalplacement: default is True, but for flat mesh must turn False to prevent spike artifect. 
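    # Usage sketch (hedged): both backends reduce the mesh to roughly `target`
    # faces. The default 'pymeshlab' backend only needs pymeshlab (presumably
    # imported as `pml` near the top of this file); 'pyfqmr' needs the optional
    # pyfqmr package.
    #   verts, faces = decimate_mesh(verts, faces, target=50000)
    #   verts, faces = decimate_mesh(verts, faces, target=50000, backend='pyfqmr')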
+ + _ori_vert_shape = verts.shape + _ori_face_shape = faces.shape + + if backend == 'pyfqmr': + import pyfqmr + solver = pyfqmr.Simplify() + solver.setMesh(verts, faces) + solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False) + verts, faces, normals = solver.getMesh() + else: + + m = pml.Mesh(verts, faces) + ms = pml.MeshSet() + ms.add_mesh(m, 'mesh') # will copy! + + # filters + # ms.meshing_decimation_clustering(threshold=pml.Percentage(1)) + ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int(target), optimalplacement=optimalplacement) + + if remesh: + # ms.apply_coord_taubin_smoothing() + ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1)) + + # extract mesh + m = ms.current_mesh() + verts = m.vertex_matrix() + faces = m.face_matrix() + + print(f'[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}') + + return verts, faces + + +def clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, remesh=False): + # verts: [N, 3] + # faces: [N, 3] + + _ori_vert_shape = verts.shape + _ori_face_shape = faces.shape + + m = pml.Mesh(verts, faces) + ms = pml.MeshSet() + ms.add_mesh(m, 'mesh') # will copy! + + # filters + ms.meshing_remove_unreferenced_vertices() # verts not refed by any faces + + if v_pct > 0: + ms.meshing_merge_close_vertices(threshold=pml.Percentage(v_pct)) # 1/10000 of bounding box diagonal + + ms.meshing_remove_duplicate_faces() # faces defined by the same verts + ms.meshing_remove_null_faces() # faces with area == 0 + + if min_d > 0: + ms.meshing_remove_connected_component_by_diameter(mincomponentdiag=pml.Percentage(min_d)) + + if min_f > 0: + ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f) + + if repair: + # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True) + ms.meshing_repair_non_manifold_edges(method=0) + ms.meshing_repair_non_manifold_vertices(vertdispratio=0) + + if remesh: + # ms.apply_coord_taubin_smoothing() + ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1)) + + # extract mesh + m = ms.current_mesh() + verts = m.vertex_matrix() + faces = m.face_matrix() + + print(f'[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}') + + return verts, faces + + +def laplace_regularizer_const(v_pos, t_pos_idx): + term = torch.zeros_like(v_pos) + norm = torch.zeros_like(v_pos[..., 0:1]) + + v0 = v_pos[t_pos_idx[:, 0], :] + v1 = v_pos[t_pos_idx[:, 1], :] + v2 = v_pos[t_pos_idx[:, 2], :] + + term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1, 3), (v1 - v0) + (v2 - v0)) + term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1, 3), (v0 - v1) + (v2 - v1)) + term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1, 3), (v0 - v2) + (v1 - v2)) + + two = torch.ones_like(v0) * 2.0 + norm.scatter_add_(0, t_pos_idx[:, 0:1], two) + norm.scatter_add_(0, t_pos_idx[:, 1:2], two) + norm.scatter_add_(0, t_pos_idx[:, 2:3], two) + + term = term / torch.clamp(norm, min=1.0) + + return torch.mean(term ** 2) diff --git a/TADA/lib/common/obj.py b/TADA/lib/common/obj.py new file mode 100644 index 0000000000000000000000000000000000000000..b87dd980de59a36e51cd16fc90e2be09872a4bb3 --- /dev/null +++ b/TADA/lib/common/obj.py @@ -0,0 +1,398 @@ +import os +import cv2 +import torch +import numpy as np +import pymeshlab +import trimesh +from .utils import dot, safe_normalize + + +def length(x, eps=1e-20): + return torch.sqrt(torch.clamp(dot(x, x), min=eps)) + + +def compute_normal(vertices, faces): + if 
not isinstance(vertices, torch.Tensor): + vertices = torch.as_tensor(vertices).float() + if not isinstance(faces, torch.Tensor): + faces = torch.as_tensor(faces).long() + + i0, i1, i2 = faces[:, 0].long(), faces[:, 1].long(), faces[:, 2].long() + + v0, v1, v2 = vertices[i0, :], vertices[i1, :], vertices[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + vn = torch.zeros_like(vertices) + vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + vn = torch.where(dot(vn, vn) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device)) + vn = safe_normalize(vn) + + face_normals = safe_normalize(face_normals) + return vn, faces + + +class Mesh(): + def __init__(self, v=None, f=None, vn=None, fn=None, vt=None, ft=None, albedo=None, device=None, base=None, + init_empty_tex=False, albedo_res=1024): + self.v = v + self.vn = vn + self.vt = vt + self.f = f + self.fn = fn + self.ft = ft + self.v_tng = None + self.f_tng = None + # only support a single albedo + if init_empty_tex: + self.albedo = torch.zeros((albedo_res, albedo_res, 3), dtype=torch.float32, device=device) + + else: + self.albedo = albedo + self.device = device + + if isinstance(base, Mesh): + for name in ['v', 'vn', 'vt', 'f', 'fn', 'ft', 'albedo', 'v_tng', 'f_tng']: + if getattr(self, name) is None: + setattr(self, name, getattr(base, name)) + + # load from obj file + @classmethod + def load_obj(cls, path, albedo_path=None, device=None, init_empty_tex=False, albedo_res=1024, + uv_path=None, normalize=False): + + assert os.path.splitext(path)[-1] == '.obj' + + mesh = cls() + + # device + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + mesh.device = device + + # try to find texture from mtl file + if albedo_path is None: + mtl_path = path.replace('.obj', '.mtl') + if os.path.exists(mtl_path): + with open(mtl_path, 'r') as f: + lines = f.readlines() + for line in lines: + split_line = line.split() + # empty line + if len(split_line) == 0: continue + prefix = split_line[0] + # NOTE: simply use the first map_Kd as albedo! 
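                        # ("map_Kd" is the diffuse-texture entry of a Wavefront .mtl
                        #  file, e.g. "map_Kd mesh_albedo.png"; any later map_Kd
                        #  entries in the same .mtl are ignored, since we break on
                        #  the first match.)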
+ if 'map_Kd' in prefix: + albedo_path = os.path.join(os.path.dirname(path), split_line[1]) + print(f'[load_obj] use albedo from: {albedo_path}') + break + + if init_empty_tex or albedo_path is None or not os.path.exists(albedo_path): + # init an empty texture + print(f'[load_obj] init empty albedo!') + albedo = np.ones((albedo_res, albedo_res, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) # default color + else: + albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED) + albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB) + albedo = cv2.resize(albedo, (albedo_res, albedo_res)) + albedo = albedo.astype(np.float32) / 255 + + mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device) + + # load obj + with open(path, 'r') as f: + lines = f.readlines() + + def parse_f_v(fv): + # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided) + # supported forms: + # f v1 v2 v3 + # f v1/vt1 v2/vt2 v3/vt3 + # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3 + # f v1//vn1 v2//vn2 v3//vn3 + xs = [int(x) - 1 if x != '' else -1 for x in fv.split('/')] + xs.extend([-1] * (3 - len(xs))) + return xs[0], xs[1], xs[2] + + # NOTE: we ignore usemtl, and assume the mesh ONLY uses one material (first in mtl) + vertices, texcoords, normals = [], [], [] + faces, tfaces, nfaces = [], [], [] + for line in lines: + split_line = line.split() + # empty line + if len(split_line) == 0: continue + # v/vn/vt + prefix = split_line[0].lower() + if prefix == 'v': + vertices.append([float(v) for v in split_line[1:]]) + elif prefix == 'vn': + normals.append([float(v) for v in split_line[1:]]) + elif prefix == 'vt': + val = [float(v) for v in split_line[1:]] + texcoords.append([val[0], 1.0 - val[1]]) + elif prefix == 'f': + vs = split_line[1:] + nv = len(vs) + v0, t0, n0 = parse_f_v(vs[0]) + for i in range(nv - 2): # triangulate (assume vertices are ordered) + v1, t1, n1 = parse_f_v(vs[i + 1]) + v2, t2, n2 = parse_f_v(vs[i + 2]) + faces.append([v0, v1, v2]) + tfaces.append([t0, t1, t2]) + nfaces.append([n0, n1, n2]) + + mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device) + mesh.vt = torch.tensor(texcoords, dtype=torch.float32, device=device) if len(texcoords) > 0 else None + mesh.vn = torch.tensor(normals, dtype=torch.float32, device=device) if len(normals) > 0 else None + + mesh.f = torch.tensor(faces, dtype=torch.int32, device=device) + mesh.ft = torch.tensor(tfaces, dtype=torch.int32, device=device) if texcoords is not None else None + mesh.fn = torch.tensor(nfaces, dtype=torch.int32, device=device) if normals is not None else None + + # auto-normalize + # Skip this + if normalize: + mesh.auto_size() + + print(f'[load_obj] v: {mesh.v.shape}, f: {mesh.f.shape}') + + # auto-fix normal + if mesh.vn is None: + mesh.auto_normal() + + print(f'[load_obj] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}') + + # auto-fix texture + if mesh.vt is None: + mesh.auto_uv(cache_path=uv_path) + + print(f'[load_obj] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}') + + return mesh + + @classmethod + def load_albedo(cls, albedo_path): + albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED) + albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB) + albedo = albedo.astype(np.float32) / 255 + return albedo + + # aabb + def aabb(self): + return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values + + # unit size + @torch.no_grad() + def auto_size(self): # to [-0.5, 0.5] + vmin, vmax = self.aabb() + scale = 1 / torch.max(vmax - vmin).item() + self.v = self.v - (vmax + vmin) / 2 # Center mesh on origin + self.v = self.v * scale + + 
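    # Usage sketch (hedged; the .obj path below is hypothetical):
    #   mesh = Mesh.load_obj('path/to/avatar.obj', device=torch.device('cuda'))
    #   mesh.auto_size()           # recenter and rescale into roughly [-0.5, 0.5]
    #   vmin, vmax = mesh.aabb()   # per-axis bounds after rescaling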
def auto_normal(self): + self.vn, self.fn = compute_normal(self.v, self.f) + self.fn = self.f + + @torch.no_grad() + def auto_uv(self, cache_path="", v=None, f=None): + # try to load cache + if cache_path is not None and os.path.exists(cache_path): + data = np.load(cache_path) + vt_np, ft_np = data['vt'], data['ft'] + else: + import xatlas + if v is not None and f is not None: + v_np = v.cpu().numpy() + f_np = f.int().cpu().numpy() + else: + v_np = self.v.cpu().numpy() + f_np = self.f.int().cpu().numpy() + atlas = xatlas.Atlas() + atlas.add_mesh(v_np, f_np) + chart_options = xatlas.ChartOptions() + chart_options.max_iterations = 4 + atlas.generate(chart_options=chart_options) + vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2] + + # save to cache + # np.savez(cache_path, vt=vt_np, ft=ft_np) + + vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device) + ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device) + + self.vt = vt + self.ft = ft + return vt, ft + + def compute_tangents(self): + vn_idx = [None] * 3 + pos = [None] * 3 + tex = [None] * 3 + for i in range(0, 3): + pos[i] = self.v[self.f[:, i]] + tex[i] = self.vt[self.ft[:, i]] + vn_idx[i] = self.fn[:, i] + + tangents = torch.zeros_like(self.vn) + tansum = torch.zeros_like(self.vn) + + # Compute tangent space for each triangle + uve1 = tex[1] - tex[0] + uve2 = tex[2] - tex[0] + pe1 = pos[1] - pos[0] + pe2 = pos[2] - pos[0] + + nom = (pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]) + denom = (uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]) + + # Avoid division by zero for degenerated texture coordinates + tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)) + + # Update all 3 vertices + for i in range(0, 3): + idx = vn_idx[i][:, None].repeat(1, 3) + tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang + tansum.scatter_add_(0, idx, torch.ones_like(tang)) # tansum[n_i] = tansum[n_i] + 1 + tangents = tangents / tansum + + # Normalize and make sure tangent is perpendicular to normal + tangents = safe_normalize(tangents) + tangents = safe_normalize(tangents - dot(tangents, self.vn) * self.vn) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(tangents)) + self.v_tng = tangents + self.f_tng = self.fn + + # write to obj file + def write(self, path): + + mtl_path = path.replace('.obj', '.mtl') + albedo_path = path.replace('.obj', '_albedo.png') + + v_np = self.v.cpu().numpy() + vt_np = self.vt.cpu().numpy() if self.vt is not None else None + vn_np = self.vn.cpu().numpy() if self.vn is not None else None + f_np = self.f.cpu().numpy() + ft_np = self.ft.cpu().numpy() if self.ft is not None else None + fn_np = self.fn.cpu().numpy() if self.fn is not None else None + + with open(path, "w") as fp: + fp.write(f'mtllib {os.path.basename(mtl_path)} \n') + + for v in v_np: + fp.write(f'v {v[0]} {v[1]} {v[2]} \n') + + for v in vt_np: + fp.write(f'vt {v[0]} {1 - v[1]} \n') + + for v in vn_np: + fp.write(f'vn {v[0]} {v[1]} {v[2]} \n') + + fp.write(f'usemtl defaultMat \n') + for i in range(len(f_np)): + fp.write( + f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \ + {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \ + {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n') + + with open(mtl_path, "w") as fp: + fp.write(f'newmtl defaultMat \n') + 
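            # Minimal Wavefront material: white ambient/diffuse, no specular,
            # with map_Kd pointing at the albedo texture exported below.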
fp.write(f'Ka 1 1 1 \n') + fp.write(f'Kd 1 1 1 \n') + fp.write(f'Ks 0 0 0 \n') + fp.write(f'Tr 1 \n') + fp.write(f'illum 1 \n') + fp.write(f'Ns 0 \n') + fp.write(f'map_Kd {os.path.basename(albedo_path)} \n') + + albedo = self.albedo.detach().cpu().numpy() + albedo = (albedo * 255).astype(np.uint8) + cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)) + + def set_albedo(self, albedo): + self.albedo = torch.sigmoid(albedo) + + def set_uv(self, vt, ft): + self.vt = vt + self.ft = ft + + def auto_uv_face_att(self): + import kaolin as kal + self.uv_face_att = kal.ops.mesh.index_vertices_by_faces( + self.vt.unsqueeze(0), + self.ft.long()) + + +def save_obj_mesh(mesh_path, verts, faces=None): + file = open(mesh_path, 'w') + + for v in verts: + file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2])) + + if faces is not None: + for f in faces: + f_plus = f + 1 + file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2])) + file.close() + + +def keep_largest(mesh): + mesh_lst = mesh.split(only_watertight=False) + keep_mesh = mesh_lst[0] + for mesh in mesh_lst: + if mesh.vertices.shape[0] > keep_mesh.vertices.shape[0]: + keep_mesh = mesh + return keep_mesh + + +def poisson(mesh, obj_path): + mesh.export(obj_path) + ms = pymeshlab.MeshSet() + ms.load_new_mesh(obj_path) + ms.set_verbosity(False) + ms.generate_surface_reconstruction_screened_poisson(depth=10, preclean=True) + ms.set_current_mesh(1) + ms.save_current_mesh(obj_path) + + new_meshes = trimesh.load(obj_path) + new_mesh_lst = new_meshes.split(only_watertight=False) + comp_num = [new_mesh.vertices.shape[0] for new_mesh in new_mesh_lst] + + return new_mesh_lst[comp_num.index(max(comp_num))] + + +def mesh_clean(mesh, save_path=None): + """ clean mesh """ + cc = mesh.split(only_watertight=False) + out_mesh = cc[0] + bbox = out_mesh.bounds + height = bbox[1, 0] - bbox[0, 0] + for c in cc: + bbox = c.bounds + if height < bbox[1, 0] - bbox[0, 0]: + height = bbox[1, 0] - bbox[0, 0] + out_mesh = c + if save_path is not None: + out_mesh.export(save_path) + return out_mesh + + +def normalize_vert(vertices, return_cs=False): + if isinstance(vertices, np.ndarray): + vmax, vmin = vertices.max(0), vertices.min(0) + center = (vmax + vmin) * 0.5 + scale = 1. / np.max(vmax - vmin) + else: # torch.tensor + vmax, vmin = vertices.max(0)[0], vertices.min(0)[0] + center = (vmax + vmin) * 0.5 + scale = 1. / torch.max(vmax - vmin) + if return_cs: + return (vertices - center) * scale, center, scale + return (vertices - center) * scale diff --git a/TADA/lib/common/optimizer.py b/TADA/lib/common/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..f5bb64fc642fe525ca1c98f3c2dfa666c3143851 --- /dev/null +++ b/TADA/lib/common/optimizer.py @@ -0,0 +1,325 @@ +# Copyright 2022 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
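# Usage sketch (hedged; the hyperparameter values here are illustrative, not
# prescribed by this repository): Adan is meant as a drop-in replacement for
# Adam-style optimizers, e.g.
#   optimizer = Adan(model.parameters(), lr=1e-3, weight_decay=2e-2,
#                    max_grad_norm=5.0, foreach=True)
#   loss.backward(); optimizer.step(); optimizer.zero_grad()
# The defaults below follow the Adan paper (arXiv:2208.06677).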
+ +import math +from typing import List + +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer + + +class Adan(Optimizer): + """ + Implements a pytorch variant of Adan + Adan was proposed in + Adan: Adaptive Nesterov Momentum Algorithm for + Faster Optimizing Deep Models[J].arXiv preprint arXiv:2208.06677, 2022. + https://arxiv.org/abs/2208.06677 + Arguments: + params (iterable): iterable of parameters to optimize or + dicts defining parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float, flot], optional): coefficients used for + first- and second-order moments. (default: (0.98, 0.92, 0.99)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): decoupled weight decay + (L2 penalty) (default: 0) + max_grad_norm (float, optional): value used to clip + global grad norm (default: 0.0 no clip) + no_prox (bool): how to perform the decoupled weight decay + (default: False) + foreach (bool): if True would use torch._foreach implementation. + It's faster but uses slightly more memory. (default: True) + """ + def __init__(self, + params, + lr=1e-3, + betas=(0.98, 0.92, 0.99), + eps=1e-8, + weight_decay=0.0, + max_grad_norm=0.0, + no_prox=False, + foreach: bool = True): + if not 0.0 <= max_grad_norm: + raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm)) + if not 0.0 <= lr: + raise ValueError('Invalid learning rate: {}'.format(lr)) + if not 0.0 <= eps: + raise ValueError('Invalid epsilon value: {}'.format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError('Invalid beta parameter at index 0: {}'.format( + betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError('Invalid beta parameter at index 1: {}'.format( + betas[1])) + if not 0.0 <= betas[2] < 1.0: + raise ValueError('Invalid beta parameter at index 2: {}'.format( + betas[2])) + defaults = dict(lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + max_grad_norm=max_grad_norm, + no_prox=no_prox, + foreach=foreach) + super().__init__(params, defaults) + + def __setstate__(self, state): + super(Adan, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('no_prox', False) + + @torch.no_grad() + def restart_opt(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + if p.requires_grad: + state = self.state[p] + # State initialization + + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + # Exponential moving average of gradient difference + state['exp_avg_diff'] = torch.zeros_like(p) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step.""" + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if self.defaults['max_grad_norm'] > 0: + device = self.param_groups[0]['params'][0].device + global_grad_norm = torch.zeros(1, device=device) + + max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], + device=device) + for group in self.param_groups: + + for p in group['params']: + if p.grad is not None: + grad = p.grad + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + + clip_global_grad_norm = torch.clamp( + max_grad_norm / (global_grad_norm + group['eps']), + max=1.0).item() + else: + clip_global_grad_norm = 1.0 + + for group in 
self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + exp_avg_diffs = [] + neg_pre_grads = [] + + beta1, beta2, beta3 = group['betas'] + # assume same step across group now to simplify things + # per parameter step can be easily support + # by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + bias_correction1 = 1.0 - beta1**group['step'] + bias_correction2 = 1.0 - beta2**group['step'] + bias_correction3 = 1.0 - beta3**group['step'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + grads.append(p.grad) + + state = self.state[p] + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + state['exp_avg_diff'] = torch.zeros_like(p) + + if 'neg_pre_grad' not in state or group['step'] == 1: + state['neg_pre_grad'] = p.grad.clone().mul_( + -clip_global_grad_norm) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + exp_avg_diffs.append(state['exp_avg_diff']) + neg_pre_grads.append(state['neg_pre_grad']) + + kwargs = dict( + params=params_with_grad, + grads=grads, + exp_avgs=exp_avgs, + exp_avg_sqs=exp_avg_sqs, + exp_avg_diffs=exp_avg_diffs, + neg_pre_grads=neg_pre_grads, + beta1=beta1, + beta2=beta2, + beta3=beta3, + bias_correction1=bias_correction1, + bias_correction2=bias_correction2, + bias_correction3_sqrt=math.sqrt(bias_correction3), + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + no_prox=group['no_prox'], + clip_global_grad_norm=clip_global_grad_norm, + ) + + if group['foreach']: + _multi_tensor_adan(**kwargs) + else: + _single_tensor_adan(**kwargs) + + return loss + + +def _single_tensor_adan( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + exp_avg_diffs: List[Tensor], + neg_pre_grads: List[Tensor], + *, + beta1: float, + beta2: float, + beta3: float, + bias_correction1: float, + bias_correction2: float, + bias_correction3_sqrt: float, + lr: float, + weight_decay: float, + eps: float, + no_prox: bool, + clip_global_grad_norm: Tensor, +): + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + exp_avg_diff = exp_avg_diffs[i] + neg_grad_or_diff = neg_pre_grads[i] + + grad.mul_(clip_global_grad_norm) + + # for memory saving, we use `neg_grad_or_diff` + # to get some temp variable in a inplace way + neg_grad_or_diff.add_(grad) + + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t + exp_avg_diff.mul_(beta2).add_(neg_grad_or_diff, + alpha=1 - beta2) # diff_t + + neg_grad_or_diff.mul_(beta2).add_(grad) + exp_avg_sq.mul_(beta3).addcmul_(neg_grad_or_diff, + neg_grad_or_diff, + value=1 - beta3) # n_t + + denom = ((exp_avg_sq).sqrt() / bias_correction3_sqrt).add_(eps) + step_size_diff = lr * beta2 / bias_correction2 + step_size = lr / bias_correction1 + + if no_prox: + param.mul_(1 - lr * weight_decay) + param.addcdiv_(exp_avg, denom, value=-step_size) + param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff) + else: + param.addcdiv_(exp_avg, denom, value=-step_size) + param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff) + param.div_(1 + lr * weight_decay) + + neg_grad_or_diff.zero_().add_(grad, alpha=-1.0) + + +def _multi_tensor_adan( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + exp_avg_diffs: List[Tensor], + neg_pre_grads: List[Tensor], + *, + beta1: float, + 
beta2: float, + beta3: float, + bias_correction1: float, + bias_correction2: float, + bias_correction3_sqrt: float, + lr: float, + weight_decay: float, + eps: float, + no_prox: bool, + clip_global_grad_norm: Tensor, +): + if len(params) == 0: + return + + torch._foreach_mul_(grads, clip_global_grad_norm) + + # for memory saving, we use `neg_pre_grads` + # to get some temp variable in a inplace way + torch._foreach_add_(neg_pre_grads, grads) + + torch._foreach_mul_(exp_avgs, beta1) + torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) # m_t + + torch._foreach_mul_(exp_avg_diffs, beta2) + torch._foreach_add_(exp_avg_diffs, neg_pre_grads, + alpha=1 - beta2) # diff_t + + torch._foreach_mul_(neg_pre_grads, beta2) + torch._foreach_add_(neg_pre_grads, grads) + torch._foreach_mul_(exp_avg_sqs, beta3) + torch._foreach_addcmul_(exp_avg_sqs, + neg_pre_grads, + neg_pre_grads, + value=1 - beta3) # n_t + + denom = torch._foreach_sqrt(exp_avg_sqs) + torch._foreach_div_(denom, bias_correction3_sqrt) + torch._foreach_add_(denom, eps) + + step_size_diff = lr * beta2 / bias_correction2 + step_size = lr / bias_correction1 + + if no_prox: + torch._foreach_mul_(params, 1 - lr * weight_decay) + torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size) + torch._foreach_addcdiv_(params, + exp_avg_diffs, + denom, + value=-step_size_diff) + else: + torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size) + torch._foreach_addcdiv_(params, + exp_avg_diffs, + denom, + value=-step_size_diff) + torch._foreach_div_(params, 1 + lr * weight_decay) + torch._foreach_zero_(neg_pre_grads) + torch._foreach_add_(neg_pre_grads, grads, alpha=-1.0) \ No newline at end of file diff --git a/TADA/lib/common/plot.py b/TADA/lib/common/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..09ad20deab6b4c413eea189c8c3061f443a54ad0 --- /dev/null +++ b/TADA/lib/common/plot.py @@ -0,0 +1,44 @@ +import mediapipe as mp +from mediapipe import solutions +from mediapipe.framework.formats import landmark_pb2 +import numpy as np + + +def draw_mediapipe_landmarks_on_image(rgb_image: object, face_landmarks_list: object) -> object: + annotated_image = np.copy(rgb_image) + + # Loop through the detected faces to visualize. + for idx in range(len(face_landmarks_list)): + face_landmarks = face_landmarks_list[idx] + + # Draw the face landmarks. 
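        # (Each element of face_landmarks_list is expected to be an iterable of
        #  landmarks with normalized .x/.y/.z attributes, as produced by MediaPipe
        #  face-landmark detection; they are repacked into a NormalizedLandmarkList
        #  protobuf so the drawing utilities below can consume them.)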
+ face_landmarks_proto = landmark_pb2.NormalizedLandmarkList() + face_landmarks_proto.landmark.extend([ + landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z) for landmark in face_landmarks + ]) + + # print(len(mp.solutions.face_mesh.FACEMESH_TESSELATION)) + # exit() + + solutions.drawing_utils.draw_landmarks( + image=annotated_image, + landmark_list=face_landmarks_proto, + connections=mp.solutions.face_mesh.FACEMESH_TESSELATION, + # connections=FACEMESH_NOSE, + landmark_drawing_spec=None, + connection_drawing_spec=mp.solutions.drawing_styles.get_default_face_mesh_tesselation_style() + ) + solutions.drawing_utils.draw_landmarks( + image=annotated_image, + landmark_list=face_landmarks_proto, + connections=mp.solutions.face_mesh.FACEMESH_CONTOURS, + landmark_drawing_spec=None, + connection_drawing_spec=mp.solutions.drawing_styles.get_default_face_mesh_contours_style()) + solutions.drawing_utils.draw_landmarks( + image=annotated_image, + landmark_list=face_landmarks_proto, + connections=mp.solutions.face_mesh.FACEMESH_IRISES, + landmark_drawing_spec=None, + connection_drawing_spec=mp.solutions.drawing_styles.get_default_face_mesh_iris_connections_style()) + + return annotated_image diff --git a/TADA/lib/common/pytorch3d_renderer.py b/TADA/lib/common/pytorch3d_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..3716dc6d9a9499562c1cb8b0bc320bb8a9cd9fb0 --- /dev/null +++ b/TADA/lib/common/pytorch3d_renderer.py @@ -0,0 +1,282 @@ +from pytorch3d.renderer import ( + BlendParams, blending, look_at_view_transform, FoVOrthographicCameras, + PointLights, RasterizationSettings, PointsRasterizationSettings, + PointsRenderer, AlphaCompositor, PointsRasterizer, MeshRenderer, + MeshRasterizer, SoftPhongShader, SoftSilhouetteShader, TexturesVertex) +from pytorch3d.renderer.mesh import TexturesVertex +from pytorch3d.structures import Meshes, Pointclouds + +import torch +import numpy as np +import math +import cv2 + + +class cleanShader(torch.nn.Module): + def __init__(self, device="cpu", cameras=None, blend_params=None): + super().__init__() + self.cameras = cameras + self.blend_params = blend_params if blend_params is not None else BlendParams( + ) + + def forward(self, fragments, meshes, **kwargs): + cameras = kwargs.get("cameras", self.cameras) + if cameras is None: + msg = "Cameras must be specified either at initialization \ + or in the forward pass of TexturedSoftPhongShader" + + raise ValueError(msg) + + # get renderer output + blend_params = kwargs.get("blend_params", self.blend_params) + texels = meshes.sample_textures(fragments) + images = blending.softmax_rgb_blend(texels, + fragments, + blend_params, + znear=-256, + zfar=256) + + return images + + +class Render: + def __init__(self, size=512, device=torch.device("cuda:0")): + self.device = device + self.mesh_y_center = 100.0 + self.dis = 100.0 + self.scale = 1.0 + self.size = size + self.cam_pos = [(0, 100, 100)] + + self.mesh = None + self.pcd = None + self.renderer = None + self.meshRas = None + + def get_camera(self, cam_id): + # at + + R, T = look_at_view_transform(eye=[self.cam_pos[cam_id]], + at=((0, self.mesh_y_center, 0),), + up=((0, 1, 0),)) + + camera = FoVOrthographicCameras(device=self.device, + R=R, + T=T, + znear=100.0, + zfar=-100.0, + max_y=100.0, + min_y=-100.0, + max_x=100.0, + min_x=-100.0, + scale_xyz=(self.scale * np.ones(3),)) + return camera + + def init_renderer(self, camera, type='clean_mesh', bg='gray'): + + if 'mesh' in type: + # rasterizer + self.raster_settings_mesh = 
RasterizationSettings( + image_size=self.size, + blur_radius=np.log(1.0 / 1e-4) * 1e-7, + faces_per_pixel=30, + ) + self.meshRas = MeshRasterizer(cameras=camera, + raster_settings=self.raster_settings_mesh) + + if bg == 'black': + blendparam = BlendParams(1e-4, 1e-4, (0.0, 0.0, 0.0)) + elif bg == 'white': + blendparam = BlendParams(1e-4, 1e-8, (1.0, 1.0, 1.0)) + elif bg == 'gray': + blendparam = BlendParams(1e-4, 1e-8, (0.5, 0.5, 0.5)) + + if type == 'ori_mesh': + lights = PointLights(device=self.device, + ambient_color=((0.8, 0.8, 0.8),), + diffuse_color=((0.2, 0.2, 0.2),), + specular_color=((0.0, 0.0, 0.0),), + location=[[0.0, 200.0, 200.0]]) + self.renderer = MeshRenderer( + rasterizer=self.meshRas, + shader=SoftPhongShader( + device=self.device, + cameras=camera, + lights=lights, + blend_params=blendparam)) + + if type == 'silhouette': + self.raster_settings_silhouette = RasterizationSettings( + image_size=self.size, + blur_radius=np.log(1. / 1e-4 - 1.) * 5e-5, + faces_per_pixel=50, + cull_backfaces=True, + ) + + self.silhouetteRas = MeshRasterizer( + cameras=camera, raster_settings=self.raster_settings_silhouette) + self.renderer = MeshRenderer(rasterizer=self.silhouetteRas, + shader=SoftSilhouetteShader()) + + if type == 'pointcloud': + self.raster_settings_pcd = PointsRasterizationSettings( + image_size=self.size, + radius=0.006, + points_per_pixel=10) + + self.pcdRas = PointsRasterizer(cameras=camera, + raster_settings=self.raster_settings_pcd) + self.renderer = PointsRenderer( + rasterizer=self.pcdRas, + compositor=AlphaCompositor(background_color=(0, 0, 0))) + + if type == 'clean_mesh': + self.renderer = MeshRenderer( + rasterizer=self.meshRas, + shader=cleanShader( + device=self.device, + cameras=camera, + blend_params=blendparam)) + + def set_camera(self, verts, normalize=False): + self.scale = 100 + self.mesh_y_center = 0 + if normalize: + y_max = verts.max(dim=1)[0][0, 1].item() + y_min = verts.min(dim=1)[0][0, 1].item() + self.scale *= 0.95 / ((y_max - y_min) * 0.5 + 1e-10) + self.mesh_y_center = (y_max + y_min) * 0.5 + + self.cam_pos = [(0, self.mesh_y_center, self.dis), + (self.dis, self.mesh_y_center, 0), + (0, self.mesh_y_center, -self.dis), + (-self.dis, self.mesh_y_center, 0)] + + def load_mesh(self, verts, faces, verts_rgb=None, normalize=False, use_normal=False): + """load mesh into the pytorch3d renderer + + Args: + verts ([N,3]): verts + faces ([N,3]): faces + verts_rgb ([N,3]): rgb + normalize: bool + """ + + if not torch.is_tensor(verts): + verts = torch.tensor(verts) + if not torch.is_tensor(faces): + faces = torch.tensor(faces) + + if verts.ndimension() == 2: + verts = verts.unsqueeze(0).float() + if faces.ndimension() == 2: + faces = faces.unsqueeze(0).long() + + verts = verts.to(self.device) + faces = faces.to(self.device) + self.set_camera(verts, normalize) + self.mesh = Meshes(verts, faces).to(self.device) + + if verts_rgb is not None: + if not torch.is_tensor(verts_rgb): + verts_rgb = torch.as_tensor(verts_rgb) + if verts_rgb.ndimension() == 2: + verts_rgb = verts_rgb.unsqueeze(0).float() + verts_rgb = verts_rgb.to(self.device) + elif use_normal: + verts_rgb = self.mesh.verts_normals_padded() + verts_rgb = (verts_rgb + 1.0) * 0.5 + else: + verts_rgb = self.mesh.verts_normals_padded()[..., 2:3].expand(-1, -1, 3) + verts_rgb = (verts_rgb + 1.0) * 0.5 + textures = TexturesVertex(verts_features=verts_rgb) + self.mesh.textures = textures + return self.mesh + + def load_pcd(self, verts, verts_rgb, normalize=False): + """load pointcloud into the pytorch3d 
renderer + + Args: + verts ([B, N,3]): verts + verts_rgb ([B, N,3]): verts colors + normalize bool: render point cloud in center + """ + assert verts.shape == verts_rgb.shape and len(verts.shape) == 3 + # data format convert + if not torch.is_tensor(verts): + verts = torch.as_tensor(verts) + if not torch.is_tensor(verts_rgb): + verts_rgb = torch.as_tensor(verts_rgb) + + verts = verts.float().to(self.device) + verts_rgb = verts_rgb.float().to(self.device) + + # camera setting + self.set_camera(verts, normalize) + pcd = Pointclouds(points=verts, features=verts_rgb).to(self.device) + return pcd + + def get_image(self, cam_ids=[0, 2], type='clean_mesh', bg='gray'): + images = [] + for cam_id in range(len(self.cam_pos)): + if cam_id in cam_ids: + self.init_renderer(self.get_camera(cam_id), type, bg) + rendered_img = self.renderer(self.mesh)[0, :, :, :3] + if cam_id == 2 and len(cam_ids) == 2: + rendered_img = torch.flip(rendered_img, dims=[1]) + images.append(rendered_img) + images = torch.cat(images, 1) + return images.detach().cpu().numpy() + + def get_clean_image(self, cam_ids=[0, 2], type='clean_mesh', bg='gray'): + images = [] + for cam_id in range(len(self.cam_pos)): + if cam_id in cam_ids: + self.init_renderer(self.get_camera(cam_id), type, bg) + rendered_img = self.renderer(self.mesh)[0:1, :, :, :3] + if cam_id == 2 and len(cam_ids) == 2: + rendered_img = torch.flip(rendered_img, dims=[2]) + images.append(rendered_img) + return images + + def get_silhouette_image(self, cam_ids=[0, 2]): + images = [] + for cam_id in range(len(self.cam_pos)): + if cam_id in cam_ids: + self.init_renderer(self.get_camera(cam_id), 'silhouette') + rendered_img = self.renderer(self.mesh)[0:1, :, :, 3] + if cam_id == 2 and len(cam_ids) == 2: + rendered_img = torch.flip(rendered_img, dims=[2]) + images.append(rendered_img) + + return images + + def get_image_pcd(self, pcd, cam_ids=[0, 1, 2, 3]): + images = torch.zeros((self.size, self.size * len(cam_ids), 3)).to(self.device) + for i, cam_id in enumerate(cam_ids): + self.init_renderer(self.get_camera(cam_id), 'pointcloud') + images[:, self.size * i:self.size * (i + 1), :] = self.renderer(pcd)[0, :, :, :3] + + return images.cpu().numpy() + + def get_rendered_video(self, save_path, num_angle=100, s=0): + self.cam_pos = [] + + interval = 360. / num_angle + for i in range(num_angle): + # for angle in range(90, 90+360, ): + angle = (s + i * interval) % 360 + self.cam_pos.append( + (self.dis * math.cos(np.pi / 180 * angle), self.mesh_y_center, + self.dis * math.sin(np.pi / 180 * angle))) + + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + video = cv2.VideoWriter(save_path, fourcc, 30, (self.size, self.size)) + + for cam_id in range(len(self.cam_pos)): + self.init_renderer(self.get_camera(cam_id), 'clean_mesh', 'gray') + rendered_img = (self.renderer(self.mesh)[0, :, :, :3] * 255.0).detach().cpu().numpy().astype(np.uint8) + + video.write(rendered_img) + + video.release() diff --git a/TADA/lib/common/remesh.py b/TADA/lib/common/remesh.py new file mode 100644 index 0000000000000000000000000000000000000000..978216d59f071c0799c3378d679920e40b0fcd6c --- /dev/null +++ b/TADA/lib/common/remesh.py @@ -0,0 +1,88 @@ +import trimesh +import torch +import numpy as np + +def subdivide(vertices, faces, attributes=None, face_index=None): + """ + Subdivide a mesh into smaller triangles. + + Note that if `face_index` is passed, only those faces will + be subdivided and their neighbors won't be modified making + the mesh no longer "watertight." 
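    Example (sketch): with `face_index=None`, a mesh with F faces comes back
    with 4 * F faces, and `unique` records which edge-midpoint vertices were
    created so that the identical split can be replayed on a deformed copy of
    the same mesh via `subdivide_inorder`.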
+ + Parameters + ---------- + vertices : (n, 3) float + Vertices in space + faces : (n, 3) int + Indexes of vertices which make up triangular faces + attributes: (n, d) float + vertices attributes + face_index : faces to subdivide. + if None: all faces of mesh will be subdivided + if (n,) int array of indices: only specified faces + + Returns + ---------- + new_vertices : (n, 3) float + Vertices in space + new_faces : (n, 3) int + Remeshed faces + """ + if face_index is None: + face_index = np.arange(len(faces)) + else: + face_index = np.asanyarray(face_index) + + # the (c,3) int set of vertex indices + faces = faces[face_index] + # the (c, 3, 3) float set of points in the triangles + triangles = vertices[faces] + # the 3 midpoints of each triangle edge + # stacked to a (3 * c, 3) float + mid = np.vstack([triangles[:, g, :].mean(axis=1) for g in [[0, 1], [1, 2], [2, 0]]]) + + # for adjacent faces we are going to be generating + # the same midpoint twice so merge them here + mid_idx = (np.arange(len(face_index) * 3)).reshape((3, -1)).T + unique, inverse = trimesh.grouping.unique_rows(mid) + mid = mid[unique] + mid_idx = inverse[mid_idx] + len(vertices) + + # the new faces with correct winding + f = np.column_stack([faces[:, 0], + mid_idx[:, 0], + mid_idx[:, 2], + mid_idx[:, 0], + faces[:, 1], + mid_idx[:, 1], + mid_idx[:, 2], + mid_idx[:, 1], + faces[:, 2], + mid_idx[:, 0], + mid_idx[:, 1], + mid_idx[:, 2]]).reshape((-1, 3)) + # add the 3 new faces per old face + new_faces = np.vstack((faces, f[len(face_index):])) + # replace the old face with a smaller face + new_faces[face_index] = f[:len(face_index)] + + new_vertices = np.vstack((vertices, mid)) + + if attributes is not None: + tri_att = attributes[faces] + mid_att = np.vstack([tri_att[:, g, :].mean(axis=1) for g in [[0, 1], [1, 2], [2, 0]]]) + mid_att = mid_att[unique] + new_attributes = np.vstack((attributes, mid_att)) + return new_vertices, new_faces, new_attributes, unique + + return new_vertices, new_faces, unique + + +def subdivide_inorder(vertices, faces, unique): + triangles = vertices[faces] + mid = torch.vstack([triangles[:, g, :].mean(1) for g in [[0, 1], [1, 2], [2, 0]]]) + + mid = mid[unique] + new_vertices = torch.vstack((vertices, mid)) + return new_vertices \ No newline at end of file diff --git a/TADA/lib/common/renderer.py b/TADA/lib/common/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..c6e8bf4b4cc2d8c3aeacb949658b355522833e7d --- /dev/null +++ b/TADA/lib/common/renderer.py @@ -0,0 +1,132 @@ +import random + +import torch +import torch.nn.functional as F +import nvdiffrast.torch as dr +from . 
import utils +from lib.common.obj import compute_normal + + +class Renderer(torch.nn.Module): + def __init__(self): + super().__init__() + # self.glctx = dr.RasterizeCudaContext() + # self.glctx = dr.RasterizeGLContext() + try: + self.glctx = dr.RasterizeCudaContext() + except: + self.glctx = dr.RasterizeGLContext() + + def forward(self, mesh, mvp, + h=512, + w=512, + light_d=None, + ambient_ratio=1., + shading='albedo', + spp=1, + mlp_texture=None, + is_train=False): + """ + Args: + spp: + return_normal: + transform_nml: + mesh: Mesh object + mvp: [batch, 4, 4] + h: int + w: int + light_d: + ambient_ratio: float + shading: str shading type albedo, normal, + ssp: int + Returns: + color: [batch, h, w, 3] + alpha: [batch, h, w, 1] + depth: [batch, h, w, 1] + + """ + B = mvp.shape[0] + v_clip = torch.bmm(F.pad(mesh.v, pad=(0, 1), mode='constant', value=1.0).unsqueeze(0).expand(B, -1, -1), + torch.transpose(mvp, 1, 2)).float() # [B, N, 4] + + res = (int(h * spp), int(w * spp)) if spp > 1 else (h, w) + rast, rast_db = dr.rasterize(self.glctx, v_clip, mesh.f, res) + + ################################################################################ + # Interpolate attributes + ################################################################################ + + # Interpolate world space position + alpha, _ = dr.interpolate(torch.ones_like(v_clip[..., :1]), rast, mesh.f) # [B, H, W, 1] + depth = rast[..., [2]] # [B, H, W] + + if is_train: + vn, _ = compute_normal(v_clip[0, :, :3], mesh.f) + normal, _ = dr.interpolate(vn[None, ...].float(), rast, mesh.f) + else: + normal, _ = dr.interpolate(mesh.vn[None, ...].float(), rast, mesh.f) + + # Texture coordinate + if not shading == 'normal': + if mlp_texture is not None: + albedo = self.get_mlp_texture(mesh, mlp_texture, rast, rast_db) + else: + albedo = self.get_2d_texture(mesh, rast, rast_db) + + if shading == 'normal': + color = (normal + 1) / 2. + elif shading == 'albedo': + color = albedo + else: # lambertian + lambertian = ambient_ratio + (1 - ambient_ratio) * (normal @ light_d.view(-1, 1)).float().clamp(min=0) + color = albedo * lambertian.repeat(1, 1, 1, 3) + + normal = (normal + 1) / 2. + + normal = dr.antialias(normal, rast, v_clip, mesh.f).clamp(0, 1) # [H, W, 3] + color = dr.antialias(color, rast, v_clip, mesh.f).clamp(0, 1) # [H, W, 3] + alpha = dr.antialias(alpha, rast, v_clip, mesh.f).clamp(0, 1) # [H, W, 3] + + # inverse super-sampling + if spp > 1: + color = utils.scale_img_nhwc(color, (h, w)) + alpha = utils.scale_img_nhwc(alpha, (h, w)) + normal = utils.scale_img_nhwc(normal, (h, w)) + + return color, normal, alpha + + def get_mlp_texture(self, mesh, mlp_texture, rast, rast_db, res=2048): + # uv = mesh.vt[None, ...] * 2.0 - 1.0 + uv = mesh.vt[None, ...] 
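        # The UV coordinates are rasterized below into a (res x res) texture-space
        # grid; the MLP texture is then queried at every covered texel to bake an
        # albedo image, which is finally sampled per screen pixel with dr.texture.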
+ + # pad to four component coordinate + uv4 = torch.cat((uv, torch.zeros_like(uv[..., 0:1]), torch.ones_like(uv[..., 0:1])), dim=-1) + + # rasterize + _rast, _ = dr.rasterize(self.glctx, uv4, mesh.f.int(), (res, res)) + print("_rast ", _rast.shape) + # Interpolate world space position + # gb_pos, _ = dr.interpolate(mesh.v[None, ...], _rast, mesh.f.int()) + + # Sample out textures from MLP + tex = mlp_texture.sample(_rast[..., :-1].view(-1, 3)).view(*_rast.shape[:-1], 3) + + texc, texc_db = dr.interpolate(mesh.vt[None, ...], rast, mesh.ft, rast_db=rast_db, diff_attrs='all') + print(tex.shape) + + albedo = dr.texture( + tex, texc, uv_da=texc_db, filter_mode='linear-mipmap-linear') # [B, H, W, 3] + # albedo = torch.where(rast[..., 3:] > 0, albedo, torch.tensor(0).to(albedo.device)) # remove background + + # print(tex.shape, albedo.shape) + # exit() + return albedo + + @staticmethod + def get_2d_texture(mesh, rast, rast_db): + texc, texc_db = dr.interpolate(mesh.vt[None, ...], rast, mesh.ft, rast_db=rast_db, diff_attrs='all') + + albedo = dr.texture( + mesh.albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode='linear-mipmap-linear') # [B, H, W, 3] + albedo = torch.where(rast[..., 3:] > 0, albedo, torch.tensor(0).to(albedo.device)) # remove background + return albedo diff --git a/TADA/lib/common/rotation_conversions.py b/TADA/lib/common/rotation_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..210ae1f0878b3ab223ec3d51d4053751dceb47ff --- /dev/null +++ b/TADA/lib/common/rotation_conversions.py @@ -0,0 +1,552 @@ +# This code is based on https://github.com/Mathux/ACTOR.git +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# Check PYTORCH3D_LICENCE before use + +import functools +from typing import Optional + +import torch +import torch.nn.functional as F + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a, b): + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. 
This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + +def _axis_angle_rotation(axis: str, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + if axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + if axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles, convention: str): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) + return functools.reduce(torch.matmul, matrices) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +): + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. 
+ other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in dataset as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str): + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + + +def matrix_to_euler_angles(matrix, convention: str): + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). + """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). 
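        Example (sketch):
            R = random_rotations(16)                  # (16, 3, 3)
            I = torch.matmul(R, R.transpose(-1, -2))  # ~ identity, since each R is orthogonal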
+ """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + requires_grad: Whether the resulting tensor should have the gradient + flag set + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device, requires_grad)[0] + + +def standardize_quaternion(quaternions): + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a, b): + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a, b): + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion): + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + + +def quaternion_apply(quaternion, point): + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, f{point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. 
+ + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix): + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
+ Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) diff --git a/TADA/lib/common/utils.py b/TADA/lib/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9600e574f423636b8d204470ce2b431f2d465878 --- /dev/null +++ b/TADA/lib/common/utils.py @@ -0,0 +1,330 @@ +import os +import random +import json +import pickle as pkl +import cv2 +import numpy as np +import imageio +import torch +from packaging import version as pver + +from yacs.config import CfgNode as CN + + +def load_config(path, default_path=None): + cfg = CN(new_allowed=True) + if default_path is not None: + cfg.merge_from_file(default_path) + cfg.merge_from_file(path) + + return cfg + + +def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.sum(x * y, -1, keepdim=True) + + +def custom_meshgrid(*args): + # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid + if pver.parse(torch.__version__) < pver.parse('1.10'): + return torch.meshgrid(*args) + else: + return torch.meshgrid(*args, indexing='ij') + + +def plot_grid_images(images, row, col, save_path=None): + """ + Args: + images: np.array [B, H, W, 3] + row: + col: + save_path: + + Returns: + + """ + assert row * col == images.shape[0] + images = np.vstack([np.hstack(images[r * col:(r + 1) * col]) for r in range(row)]) + if save_path: + cv2.imwrite(save_path, images * 255) + return images + + +def safe_normalize(x, eps=1e-20): + return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps)) + + +def seed_everything(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = True + + +def torch_vis_2d(x, renormalize=False): + # x: [3, H, W], [H, W, 3] or [1, H, W] or [H, W] + import matplotlib.pyplot as plt + import numpy as np + import torch + + if isinstance(x, torch.Tensor): + if len(x.shape) == 3 and x.shape[0] == 3: + x = x.permute(1, 2, 0).squeeze() + x = x.detach().cpu().numpy() + + print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}') + + x = x.astype(np.float32) + + # renormalize + if renormalize: + x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8) + + plt.imshow(x) + plt.show() + + +@torch.cuda.amp.autocast(enabled=False) +def get_rays(poses, intrinsics, H, W, N=-1, error_map=None): + ''' get rays + Args: + poses: [B, 4, 4], cam2world + intrinsics: [4] + H, W, N: int + error_map: [B, 128 * 128], sample probability based on training error + 
Returns: + rays_o, rays_d: [B, N, 3] + inds: [B, N] + ''' + + device = poses.device + B = poses.shape[0] + fx, fy, cx, cy = intrinsics + + i, j = custom_meshgrid(torch.linspace(0, W - 1, W, device=device), torch.linspace(0, H - 1, H, device=device)) + i = i.t().reshape([1, H * W]).expand([B, H * W]) + 0.5 + j = j.t().reshape([1, H * W]).expand([B, H * W]) + 0.5 + + results = {} + + if N > 0: + N = min(N, H * W) + + if error_map is None: + inds = torch.randint(0, H * W, size=[N], device=device) # may duplicate + inds = inds.expand([B, N]) + else: + + # weighted sample on a low-reso grid + inds_coarse = torch.multinomial(error_map.to(device), N, replacement=False) # [B, N], but in [0, 128*128) + + # map to the original resolution with random perturb. + inds_x, inds_y = inds_coarse // 128, inds_coarse % 128 # `//` will throw a warning in torch 1.10... anyway. + sx, sy = H / 128, W / 128 + inds_x = (inds_x * sx + torch.rand(B, N, device=device) * sx).long().clamp(max=H - 1) + inds_y = (inds_y * sy + torch.rand(B, N, device=device) * sy).long().clamp(max=W - 1) + inds = inds_x * W + inds_y + + results['inds_coarse'] = inds_coarse # need this when updating error_map + + i = torch.gather(i, -1, inds) + j = torch.gather(j, -1, inds) + + results['inds'] = inds + + else: + inds = torch.arange(H * W, device=device).expand([B, H * W]) + + zs = - torch.ones_like(i) + xs = - (i - cx) / fx * zs + ys = (j - cy) / fy * zs + directions = torch.stack((xs, ys, zs), dim=-1) + # directions = safe_normalize(directions) + rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3) + + rays_o = poses[..., :3, 3] # [B, 3] + rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3] + + results['rays_o'] = rays_o + results['rays_d'] = rays_d + + return rays_o, rays_d + + +def scale_img_nhwc(x, size, mag='bilinear', min='bilinear'): + assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[ + 1]), "Trying to magnify image in one dimension and minify in the other" + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger + y = torch.nn.functional.interpolate(y, size, mode=min) + else: # Magnification + if mag == 'bilinear' or mag == 'bicubic': + y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True) + else: + y = torch.nn.functional.interpolate(y, size, mode=mag) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + + +def scale_img_hwc(x, size, mag='bilinear', min='bilinear'): + return scale_img_nhwc(x[None, ...], size, mag, min)[0] + + +def scale_img_nhw(x, size, mag='bilinear', min='bilinear'): + return scale_img_nhwc(x[..., None], size, mag, min)[..., 0] + + +def scale_img_hw(x, size, mag='bilinear', min='bilinear'): + return scale_img_nhwc(x[None, ..., None], size, mag, min)[0, ..., 0] + + +def trunc_rev_sigmoid(x, eps=1e-6): + x = x.clamp(eps, 1 - eps) + return torch.log(x / (1 - x)) + + +def save_image(fn, x: np.ndarray): + try: + if os.path.splitext(fn)[1] == ".png": + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8), + compress_level=3) # Low compression for faster saving + else: + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8)) + except: + print("WARNING: FAILED to save image %s" % fn) + + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + return torch.tensor([[1 / (y * aspect), 
0, 0, 0], + [0, 1 / -y, 0, 0], + [0, 0, -(f + n) / (f - n), -(2 * f * n) / (f - n)], + [0, 0, -1, 0]], dtype=torch.float32, device=device) + + +def translate(x, y, z, device=None): + return torch.tensor([[1, 0, 0, x], + [0, 1, 0, y], + [0, 0, 1, z], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + + +def rotate_x(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[1, 0, 0, 0], + [0, c, s, 0], + [0, -s, c, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + + +def rotate_y(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[c, 0, s, 0], + [0, 1, 0, 0], + [-s, 0, c, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + + +@torch.no_grad() +def random_rotation_translation(t, device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.random.uniform(-t, t, size=[3]) + return torch.tensor(m, dtype=torch.float32, device=device) + + +def make_rotate(rx, ry, rz): + sinX = np.sin(rx) + sinY = np.sin(ry) + sinZ = np.sin(rz) + + cosX = np.cos(rx) + cosY = np.cos(ry) + cosZ = np.cos(rz) + + Rx = np.zeros((3, 3)) + Rx[0, 0] = 1.0 + Rx[1, 1] = cosX + Rx[1, 2] = -sinX + Rx[2, 1] = sinX + Rx[2, 2] = cosX + + Ry = np.zeros((3, 3)) + Ry[0, 0] = cosY + Ry[0, 2] = sinY + Ry[1, 1] = 1.0 + Ry[2, 0] = -sinY + Ry[2, 2] = cosY + + Rz = np.zeros((3, 3)) + Rz[0, 0] = cosZ + Rz[0, 1] = -sinZ + Rz[1, 0] = sinZ + Rz[1, 1] = cosZ + Rz[2, 2] = 1.0 + + R = np.matmul(np.matmul(Rz, Ry), Rx) + return R + + +class SMPLXSeg: + def __init__(self, base_dir): + smplx_dir = os.path.join(base_dir, "smplx") + smplx_segs = json.load(open(f"{smplx_dir}/smplx_vert_segementation.json")) + flame_segs = pkl.load(open(f"{smplx_dir}/FLAME_masks.pkl", "rb"), encoding='latin1') + smplx_face = np.load(f"{smplx_dir}/smplx_faces.npy") + + smplx_flame_vid = np.load(f"{smplx_dir}/FLAME_SMPLX_vertex_ids.npy", allow_pickle=True) + + self.eyeball_ids = smplx_segs["leftEye"] + smplx_segs["rightEye"] + self.hands_ids = smplx_segs["leftHand"] + smplx_segs["rightHand"] + \ + smplx_segs["leftHandIndex1"] + smplx_segs["rightHandIndex1"] + self.neck_ids = smplx_segs["neck"] + self.head_ids = smplx_segs["head"] + + self.front_face_ids = list(smplx_flame_vid[flame_segs["face"]]) + self.ears_ids = list(smplx_flame_vid[flame_segs["left_ear"]]) + list(smplx_flame_vid[flame_segs["right_ear"]]) + self.forehead_ids = list(smplx_flame_vid[flame_segs["forehead"]]) + self.lips_ids = list(smplx_flame_vid[flame_segs["lips"]]) + self.nose_ids = list(smplx_flame_vid[flame_segs["nose"]]) + self.eyes_ids = list(smplx_flame_vid[flame_segs["right_eye_region"]]) + list( + smplx_flame_vid[flame_segs["left_eye_region"]]) + + # re-mesh mask + remesh_ids = list(set(self.front_face_ids) - set(self.forehead_ids)) + self.ears_ids + self.eyeball_ids + self.hands_ids + remesh_mask = ~np.isin(np.arange(10475), remesh_ids) + self.remesh_mask = remesh_mask[smplx_face].all(axis=1) + + +def create_checkerboard(h, w, c, grid_size): + num_grid_row = h // grid_size + num_grid_col = w // grid_size + grid_ones = np.ones((grid_size, grid_size, c)) + grid_zeros = np.zeros((grid_size, grid_size, c)) + + checkerboard = np.vstack([ + np.hstack([grid_ones if (c + r) % 2 == 1 else grid_zeros for c in range(num_grid_col)]) + for r in range(num_grid_row) + ]) + + # pad + cx, cy, _ = checkerboard.shape + out = np.ones((h, w, c)) + dx = (h - cx) // 2 + dy = (w - cy) // 2 
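+    # cx/cy is the size of the tiled board (a whole number of grid_size cells),
+    # so dx/dy centre it on the h x w canvas; any leftover border that does not
+    # fit a full cell keeps the background value of ones.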
+ out[dx:dx + cx, dy:dy + cy] = checkerboard + return out + + +if __name__ == '__main__': + out = create_checkerboard(512, 512, 3, 64) + import cv2 + + cv2.imwrite("ck.png", out * 255) diff --git a/TADA/lib/common/visual.py b/TADA/lib/common/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..3c6c4687ec73a9f2743871031b59907630e4185a --- /dev/null +++ b/TADA/lib/common/visual.py @@ -0,0 +1,57 @@ +import cv2 + +import mediapipe as mp +from mediapipe import solutions +from mediapipe.framework.formats import landmark_pb2 +import numpy as np + + +def draw_mediapipe_landmarks(rgb_image: object, face_landmarks_list: object) -> object: + annotated_image = np.copy(rgb_image) + + # Loop through the detected faces to visualize. + for idx in range(len(face_landmarks_list)): + face_landmarks = face_landmarks_list[idx] + + # Draw the face landmarks. + face_landmarks_proto = landmark_pb2.NormalizedLandmarkList() + face_landmarks_proto.landmark.extend([ + landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z) for landmark in face_landmarks + ]) + + # print(len(mp.solutions.face_mesh.FACEMESH_TESSELATION)) + # exit() + + solutions.drawing_utils.draw_landmarks( + image=annotated_image, + landmark_list=face_landmarks_proto, + connections=mp.solutions.face_mesh.FACEMESH_TESSELATION, + # connections=FACEMESH_NOSE, + landmark_drawing_spec=None, + connection_drawing_spec=mp.solutions.drawing_styles.get_default_face_mesh_tesselation_style() + ) + solutions.drawing_utils.draw_landmarks( + image=annotated_image, + landmark_list=face_landmarks_proto, + connections=mp.solutions.face_mesh.FACEMESH_CONTOURS, + landmark_drawing_spec=None, + connection_drawing_spec=mp.solutions.drawing_styles.get_default_face_mesh_contours_style()) + solutions.drawing_utils.draw_landmarks( + image=annotated_image, + landmark_list=face_landmarks_proto, + connections=mp.solutions.face_mesh.FACEMESH_IRISES, + landmark_drawing_spec=None, + connection_drawing_spec=mp.solutions.drawing_styles.get_default_face_mesh_iris_connections_style()) + + return annotated_image + + +def draw_landmarks(canvas, landmarks, eps=1e-4, fill=(0, 0, 0), thickness=-1): + h, w, c = canvas.shape + for lmk in landmarks: + x, y = lmk + x = int(x * w) + y = int(y * h) + if eps < x <= w and eps < y <= h: + cv2.circle(canvas, (x, y), 3, fill, thickness=thickness) + return canvas diff --git a/TADA/smplx/__init__.py b/TADA/smplx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..886949df670691d1ef5995737cafa285224826c4 --- /dev/null +++ b/TADA/smplx/__init__.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from .body_models import ( + create, + SMPL, + SMPLH, + SMPLX, + MANO, + FLAME, + build_layer, + SMPLLayer, + SMPLHLayer, + SMPLXLayer, + MANOLayer, + FLAMELayer, +) diff --git a/TADA/smplx/__pycache__/__init__.cpython-39.pyc b/TADA/smplx/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9990ece4b550bfc445832146fa5365d095cd1b50 Binary files /dev/null and b/TADA/smplx/__pycache__/__init__.cpython-39.pyc differ diff --git a/TADA/smplx/__pycache__/body_models.cpython-39.pyc b/TADA/smplx/__pycache__/body_models.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9140e575792f0e8ab384a175d2fd0c4e36e03de1 Binary files /dev/null and b/TADA/smplx/__pycache__/body_models.cpython-39.pyc differ diff --git a/TADA/smplx/__pycache__/lbs.cpython-39.pyc b/TADA/smplx/__pycache__/lbs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57a3f846a67271e4b0a01f7b06462c3a4ddc30ce Binary files /dev/null and b/TADA/smplx/__pycache__/lbs.cpython-39.pyc differ diff --git a/TADA/smplx/__pycache__/utils.cpython-39.pyc b/TADA/smplx/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0df64635d800aa25c1e4a8fdb2fe0c10d2242fe7 Binary files /dev/null and b/TADA/smplx/__pycache__/utils.cpython-39.pyc differ diff --git a/TADA/smplx/__pycache__/vertex_ids.cpython-39.pyc b/TADA/smplx/__pycache__/vertex_ids.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e79e3ce2f16d4c58745587e3b748a08cac8dcf12 Binary files /dev/null and b/TADA/smplx/__pycache__/vertex_ids.cpython-39.pyc differ diff --git a/TADA/smplx/__pycache__/vertex_joint_selector.cpython-39.pyc b/TADA/smplx/__pycache__/vertex_joint_selector.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52c4a2d2441f91453c743ee5ca757ee1e2eb51bf Binary files /dev/null and b/TADA/smplx/__pycache__/vertex_joint_selector.cpython-39.pyc differ diff --git a/TADA/smplx/body_models.py b/TADA/smplx/body_models.py new file mode 100644 index 0000000000000000000000000000000000000000..474fd8ac756fec9f38f95fb6b1c9f87842fe73a3 --- /dev/null +++ b/TADA/smplx/body_models.py @@ -0,0 +1,2415 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from typing import Optional, Dict, Union +import os +import os.path as osp + +import pickle + +import numpy as np + +import torch +import torch.nn as nn + +from .lbs import ( + lbs, vertices2landmarks, find_dynamic_lmk_idx_and_bcoords, blend_shapes) + +from .vertex_ids import vertex_ids as VERTEX_IDS +from .utils import ( + Struct, to_np, to_tensor, Tensor, Array, + SMPLOutput, + SMPLHOutput, + SMPLXOutput, + MANOOutput, + FLAMEOutput, + find_joint_kin_chain) +from .vertex_joint_selector import VertexJointSelector + + +class SMPL(nn.Module): + + NUM_JOINTS = 23 + NUM_BODY_JOINTS = 23 + SHAPE_SPACE_DIM = 300 + + def __init__( + self, model_path: str, + kid_template_path: str = '', + data_struct: Optional[Struct] = None, + create_betas: bool = True, + betas: Optional[Tensor] = None, + num_betas: int = 10, + create_global_orient: bool = True, + global_orient: Optional[Tensor] = None, + create_body_pose: bool = True, + body_pose: Optional[Tensor] = None, + create_transl: bool = True, + transl: Optional[Tensor] = None, + dtype=torch.float32, + batch_size: int = 1, + joint_mapper=None, + gender: str = 'neutral', + age: str = 'adult', + vertex_ids: Dict[str, int] = None, + v_template: Optional[Union[Tensor, Array]] = None, + **kwargs + ) -> None: + ''' SMPL model constructor + + Parameters + ---------- + model_path: str + The path to the folder or to the file where the model + parameters are stored + data_struct: Strct + A struct object. If given, then the parameters of the model are + read from the object. Otherwise, the model tries to read the + parameters from the given `model_path`. (default = None) + create_global_orient: bool, optional + Flag for creating a member variable for the global orientation + of the body. (default = True) + global_orient: torch.tensor, optional, Bx3 + The default value for the global orientation variable. + (default = None) + create_body_pose: bool, optional + Flag for creating a member variable for the pose of the body. + (default = True) + body_pose: torch.tensor, optional, Bx(Body Joints * 3) + The default value for the body pose variable. + (default = None) + num_betas: int, optional + Number of shape components to use + (default = 10). + create_betas: bool, optional + Flag for creating a member variable for the shape space + (default = True). + betas: torch.tensor, optional, Bx10 + The default value for the shape member variable. + (default = None) + create_transl: bool, optional + Flag for creating a member variable for the translation + of the body. (default = True) + transl: torch.tensor, optional, Bx3 + The default value for the transl variable. + (default = None) + dtype: torch.dtype, optional + The data type for the created variables + batch_size: int, optional + The batch size used for creating the member variables + joint_mapper: object, optional + An object that re-maps the joints. Useful if one wants to + re-order the SMPL joints to some other convention (e.g. 
MSCOCO) + (default = None) + gender: str, optional + Which gender to load + vertex_ids: dict, optional + A dictionary containing the indices of the extra vertices that + will be selected + ''' + + self.gender = gender + self.age = age + + if data_struct is None: + if osp.isdir(model_path): + model_fn = 'SMPL_{}.{ext}'.format(gender.upper(), ext='pkl') + smpl_path = os.path.join(model_path, model_fn) + else: + smpl_path = model_path + assert osp.exists(smpl_path), 'Path {} does not exist!'.format( + smpl_path) + + with open(smpl_path, 'rb') as smpl_file: + data_struct = Struct(**pickle.load(smpl_file, + encoding='latin1')) + + super(SMPL, self).__init__() + self.batch_size = batch_size + shapedirs = data_struct.shapedirs + if (shapedirs.shape[-1] < self.SHAPE_SPACE_DIM): + print(f'WARNING: You are using a {self.name()} model, with only' + ' 10 shape coefficients.') + num_betas = min(num_betas, 10) + else: + num_betas = min(num_betas, self.SHAPE_SPACE_DIM) + + if self.age=='kid': + v_template_smil = np.load(kid_template_path) + v_template_smil -= np.mean(v_template_smil, axis=0) + v_template_diff = np.expand_dims(v_template_smil - data_struct.v_template, axis=2) + shapedirs = np.concatenate((shapedirs[:, :, :num_betas], v_template_diff), axis=2) + num_betas = num_betas + 1 + + self._num_betas = num_betas + shapedirs = shapedirs[:, :, :num_betas] + # The shape components + self.register_buffer( + 'shapedirs', + to_tensor(to_np(shapedirs), dtype=dtype)) + + if vertex_ids is None: + # SMPL and SMPL-H share the same topology, so any extra joints can + # be drawn from the same place + vertex_ids = VERTEX_IDS['smplh'] + + self.dtype = dtype + + self.joint_mapper = joint_mapper + + self.vertex_joint_selector = VertexJointSelector( + vertex_ids=vertex_ids, **kwargs) + + self.faces = data_struct.f + self.register_buffer('faces_tensor', + to_tensor(to_np(self.faces, dtype=np.int64), + dtype=torch.long)) + + if create_betas: + if betas is None: + default_betas = torch.zeros( + [batch_size, self.num_betas], dtype=dtype) + else: + if torch.is_tensor(betas): + default_betas = betas.clone().detach() + else: + default_betas = torch.tensor(betas, dtype=dtype) + + self.register_parameter( + 'betas', nn.Parameter(default_betas, requires_grad=True)) + + # The tensor that contains the global rotation of the model + # It is separated from the pose of the joints in case we wish to + # optimize only over one of them + if create_global_orient: + if global_orient is None: + default_global_orient = torch.zeros( + [batch_size, 3], dtype=dtype) + else: + if torch.is_tensor(global_orient): + default_global_orient = global_orient.clone().detach() + else: + default_global_orient = torch.tensor( + global_orient, dtype=dtype) + + global_orient = nn.Parameter(default_global_orient, + requires_grad=True) + self.register_parameter('global_orient', global_orient) + + if create_body_pose: + if body_pose is None: + default_body_pose = torch.zeros( + [batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype) + else: + if torch.is_tensor(body_pose): + default_body_pose = body_pose.clone().detach() + else: + default_body_pose = torch.tensor(body_pose, + dtype=dtype) + self.register_parameter( + 'body_pose', + nn.Parameter(default_body_pose, requires_grad=True)) + + if create_transl: + if transl is None: + default_transl = torch.zeros([batch_size, 3], + dtype=dtype, + requires_grad=True) + else: + default_transl = torch.tensor(transl, dtype=dtype) + self.register_parameter( + 'transl', nn.Parameter(default_transl, requires_grad=True)) + + 
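+        # betas, global_orient, body_pose and transl above are registered as
+        # nn.Parameters so they can be optimised directly; the template mesh,
+        # joint regressor, pose blend shapes and skinning weights below are
+        # registered as fixed buffers.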
if v_template is None: + v_template = data_struct.v_template + if not torch.is_tensor(v_template): + v_template = to_tensor(to_np(v_template), dtype=dtype) + # The vertices of the template model + self.register_buffer('v_template', v_template) + + j_regressor = to_tensor(to_np( + data_struct.J_regressor), dtype=dtype) + self.register_buffer('J_regressor', j_regressor) + + # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207 + num_pose_basis = data_struct.posedirs.shape[-1] + # 207 x 20670 + posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T + self.register_buffer('posedirs', + to_tensor(to_np(posedirs), dtype=dtype)) + + # indices of parents for each joints + parents = to_tensor(to_np(data_struct.kintree_table[0])).long() + parents[0] = -1 + self.register_buffer('parents', parents) + + lbs_weights = to_tensor(to_np(data_struct.weights), dtype=dtype) + self.register_buffer('lbs_weights', lbs_weights) + + @property + def num_betas(self): + return self._num_betas + + @property + def num_expression_coeffs(self): + return 0 + + def create_mean_pose(self, data_struct) -> Tensor: + pass + + def name(self) -> str: + return 'SMPL' + + @torch.no_grad() + def reset_params(self, **params_dict) -> None: + for param_name, param in self.named_parameters(): + if param_name in params_dict: + param[:] = torch.tensor(params_dict[param_name]) + else: + param.fill_(0) + + def get_num_verts(self) -> int: + return self.v_template.shape[0] + + def get_num_faces(self) -> int: + return self.faces.shape[0] + + def extra_repr(self) -> str: + msg = [ + f'Gender: {self.gender.upper()}', + f'Number of joints: {self.J_regressor.shape[0]}', + f'Betas: {self.num_betas}', + ] + return '\n'.join(msg) + + def forward_shape( + self, + betas: Optional[Tensor] = None, + ) -> SMPLOutput: + betas = betas if betas is not None else self.betas + v_shaped = self.v_template + blend_shapes(betas, self.shapedirs) + return SMPLOutput(vertices=v_shaped, betas=betas, v_shaped=v_shaped) + + def forward( + self, + betas: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts=True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs + ) -> SMPLOutput: + ''' Forward pass for the SMPL model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3 + If given, ignore the member variable and use it as the global + rotation of the body. Useful if someone wishes to predicts this + with an external model. (default=None) + betas: torch.tensor, optional, shape BxN_b + If given, ignore the member variable `betas` and use it + instead. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + body_pose: torch.tensor, optional, shape Bx(J*3) + If given, ignore the member variable `body_pose` and use it + instead. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + axis-angle format. (default=None) + transl: torch.tensor, optional, shape Bx3 + If given, ignore the member variable `transl` and use it + instead. For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. 
(default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + ''' + # If no shape and pose parameters are passed along, then use the + # ones from the module + global_orient = (global_orient if global_orient is not None else + self.global_orient) + body_pose = body_pose if body_pose is not None else self.body_pose + betas = betas if betas is not None else self.betas + + apply_trans = transl is not None or hasattr(self, 'transl') + if transl is None and hasattr(self, 'transl'): + transl = self.transl + + full_pose = torch.cat([global_orient, body_pose], dim=1) + + batch_size = max(betas.shape[0], global_orient.shape[0], + body_pose.shape[0]) + + if betas.shape[0] != batch_size: + num_repeats = int(batch_size / betas.shape[0]) + betas = betas.expand(num_repeats, -1) + + vertices, joints = lbs(betas, full_pose, self.v_template, + self.shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, pose2rot=pose2rot) + + joints = self.vertex_joint_selector(vertices, joints) + # Map the joints to the current dataset + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if apply_trans: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = SMPLOutput(vertices=vertices if return_verts else None, + global_orient=global_orient, + body_pose=body_pose, + joints=joints, + betas=betas, + full_pose=full_pose if return_full_pose else None) + + return output + + +class SMPLLayer(SMPL): + def __init__( + self, + *args, + **kwargs + ) -> None: + # Just create a SMPL module without any member variables + super(SMPLLayer, self).__init__( + create_body_pose=False, + create_betas=False, + create_global_orient=False, + create_transl=False, + *args, + **kwargs, + ) + + def forward( + self, + betas: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts=True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs + ) -> SMPLOutput: + ''' Forward pass for the SMPL model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3x3 + Global rotation of the body. Useful if someone wishes to + predicts this with an external model. It is expected to be in + rotation matrix format. (default=None) + betas: torch.tensor, optional, shape BxN_b + Shape parameters. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + body_pose: torch.tensor, optional, shape BxJx3x3 + Body pose. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + transl: torch.tensor, optional, shape Bx3 + Translation vector of the body. + For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. 
(default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + ''' + model_vars = [betas, global_orient, body_pose, transl] + batch_size = 1 + for var in model_vars: + if var is None: + continue + batch_size = max(batch_size, len(var)) + device, dtype = self.shapedirs.device, self.shapedirs.dtype + if global_orient is None: + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if body_pose is None: + body_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand( + batch_size, self.NUM_BODY_JOINTS, -1, -1).contiguous() + if betas is None: + betas = torch.zeros([batch_size, self.num_betas], + dtype=dtype, device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + full_pose = torch.cat( + [global_orient.reshape(-1, 1, 3, 3), + body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3)], + dim=1) + + vertices, joints = lbs(betas, full_pose, self.v_template, + self.shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, + pose2rot=False) + + joints = self.vertex_joint_selector(vertices, joints) + # Map the joints to the current dataset + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if transl is not None: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = SMPLOutput(vertices=vertices if return_verts else None, + global_orient=global_orient, + body_pose=body_pose, + joints=joints, + betas=betas, + full_pose=full_pose if return_full_pose else None) + + return output + + +class SMPLH(SMPL): + + # The hand joints are replaced by MANO + NUM_BODY_JOINTS = SMPL.NUM_JOINTS - 2 + NUM_HAND_JOINTS = 15 + NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + + def __init__( + self, model_path, + kid_template_path: str = '', + data_struct: Optional[Struct] = None, + create_left_hand_pose: bool = True, + left_hand_pose: Optional[Tensor] = None, + create_right_hand_pose: bool = True, + right_hand_pose: Optional[Tensor] = None, + use_pca: bool = True, + num_pca_comps: int = 6, + flat_hand_mean: bool = False, + batch_size: int = 1, + gender: str = 'neutral', + age: str = 'adult', + dtype=torch.float32, + vertex_ids=None, + use_compressed: bool = True, + ext: str = 'pkl', + **kwargs + ) -> None: + ''' SMPLH model constructor + + Parameters + ---------- + model_path: str + The path to the folder or to the file where the model + parameters are stored + data_struct: Strct + A struct object. If given, then the parameters of the model are + read from the object. Otherwise, the model tries to read the + parameters from the given `model_path`. (default = None) + create_left_hand_pose: bool, optional + Flag for creating a member variable for the pose of the left + hand. (default = True) + left_hand_pose: torch.tensor, optional, BxP + The default value for the left hand pose member variable. + (default = None) + create_right_hand_pose: bool, optional + Flag for creating a member variable for the pose of the right + hand. (default = True) + right_hand_pose: torch.tensor, optional, BxP + The default value for the right hand pose member variable. + (default = None) + num_pca_comps: int, optional + The number of PCA components to use for each hand. + (default = 6) + flat_hand_mean: bool, optional + If False, then the pose of the hand is initialized to False. 
+ batch_size: int, optional + The batch size used for creating the member variables + gender: str, optional + Which gender to load + dtype: torch.dtype, optional + The data type for the created variables + vertex_ids: dict, optional + A dictionary containing the indices of the extra vertices that + will be selected + ''' + + self.num_pca_comps = num_pca_comps + # If no data structure is passed, then load the data from the given + # model folder + if data_struct is None: + # Load the model + if osp.isdir(model_path): + model_fn = 'SMPLH_{}.{ext}'.format(gender.upper(), ext=ext) + smplh_path = os.path.join(model_path, model_fn) + else: + smplh_path = model_path + assert osp.exists(smplh_path), 'Path {} does not exist!'.format( + smplh_path) + + if ext == 'pkl': + with open(smplh_path, 'rb') as smplh_file: + model_data = pickle.load(smplh_file, encoding='latin1') + elif ext == 'npz': + model_data = np.load(smplh_path, allow_pickle=True) + else: + raise ValueError('Unknown extension: {}'.format(ext)) + data_struct = Struct(**model_data) + + if vertex_ids is None: + vertex_ids = VERTEX_IDS['smplh'] + + super(SMPLH, self).__init__( + model_path=model_path, + kid_template_path=kid_template_path, + data_struct=data_struct, + batch_size=batch_size, vertex_ids=vertex_ids, gender=gender, age=age, + use_compressed=use_compressed, dtype=dtype, ext=ext, **kwargs) + + self.use_pca = use_pca + self.num_pca_comps = num_pca_comps + self.flat_hand_mean = flat_hand_mean + + left_hand_components = data_struct.hands_componentsl[:num_pca_comps] + right_hand_components = data_struct.hands_componentsr[:num_pca_comps] + + self.np_left_hand_components = left_hand_components + self.np_right_hand_components = right_hand_components + if self.use_pca: + self.register_buffer( + 'left_hand_components', + torch.tensor(left_hand_components, dtype=dtype)) + self.register_buffer( + 'right_hand_components', + torch.tensor(right_hand_components, dtype=dtype)) + + if self.flat_hand_mean: + left_hand_mean = np.zeros_like(data_struct.hands_meanl) + else: + left_hand_mean = data_struct.hands_meanl + + if self.flat_hand_mean: + right_hand_mean = np.zeros_like(data_struct.hands_meanr) + else: + right_hand_mean = data_struct.hands_meanr + + self.register_buffer('left_hand_mean', + to_tensor(left_hand_mean, dtype=self.dtype)) + self.register_buffer('right_hand_mean', + to_tensor(right_hand_mean, dtype=self.dtype)) + + # Create the buffers for the pose of the left hand + hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS + if create_left_hand_pose: + if left_hand_pose is None: + default_lhand_pose = torch.zeros([batch_size, hand_pose_dim], + dtype=dtype) + else: + default_lhand_pose = torch.tensor(left_hand_pose, dtype=dtype) + + left_hand_pose_param = nn.Parameter(default_lhand_pose, + requires_grad=True) + self.register_parameter('left_hand_pose', + left_hand_pose_param) + + if create_right_hand_pose: + if right_hand_pose is None: + default_rhand_pose = torch.zeros([batch_size, hand_pose_dim], + dtype=dtype) + else: + default_rhand_pose = torch.tensor(right_hand_pose, dtype=dtype) + + right_hand_pose_param = nn.Parameter(default_rhand_pose, + requires_grad=True) + self.register_parameter('right_hand_pose', + right_hand_pose_param) + + # Create the buffer for the mean pose. 
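+        # pose_mean concatenates zero global/body means with the (possibly
+        # non-flat) hand means and is added to full_pose in forward(), so the
+        # learnable hand-pose parameters are offsets from that mean.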
+ pose_mean_tensor = self.create_mean_pose( + data_struct, flat_hand_mean=flat_hand_mean) + if not torch.is_tensor(pose_mean_tensor): + pose_mean_tensor = torch.tensor(pose_mean_tensor, dtype=dtype) + self.register_buffer('pose_mean', pose_mean_tensor) + + def create_mean_pose(self, data_struct, flat_hand_mean=False): + # Create the array for the mean pose. If flat_hand is false, then use + # the mean that is given by the data, rather than the flat open hand + global_orient_mean = torch.zeros([3], dtype=self.dtype) + body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3], + dtype=self.dtype) + + pose_mean = torch.cat([global_orient_mean, body_pose_mean, + self.left_hand_mean, + self.right_hand_mean], dim=0) + return pose_mean + + def name(self) -> str: + return 'SMPL+H' + + def extra_repr(self): + msg = super(SMPLH, self).extra_repr() + msg = [msg] + if self.use_pca: + msg.append(f'Number of PCA components: {self.num_pca_comps}') + msg.append(f'Flat hand mean: {self.flat_hand_mean}') + return '\n'.join(msg) + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + left_hand_pose: Optional[Tensor] = None, + right_hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs + ) -> SMPLHOutput: + ''' + ''' + + # If no shape and pose parameters are passed along, then use the + # ones from the module + global_orient = (global_orient if global_orient is not None else + self.global_orient) + body_pose = body_pose if body_pose is not None else self.body_pose + betas = betas if betas is not None else self.betas + left_hand_pose = (left_hand_pose if left_hand_pose is not None else + self.left_hand_pose) + right_hand_pose = (right_hand_pose if right_hand_pose is not None else + self.right_hand_pose) + + apply_trans = transl is not None or hasattr(self, 'transl') + if transl is None: + if hasattr(self, 'transl'): + transl = self.transl + + if self.use_pca: + left_hand_pose = torch.einsum( + 'bi,ij->bj', [left_hand_pose, self.left_hand_components]) + right_hand_pose = torch.einsum( + 'bi,ij->bj', [right_hand_pose, self.right_hand_components]) + + full_pose = torch.cat([global_orient, body_pose, + left_hand_pose, + right_hand_pose], dim=1) + full_pose += self.pose_mean + + vertices, joints = lbs(betas, full_pose, self.v_template, + self.shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, pose2rot=pose2rot) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if apply_trans: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = SMPLHOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + global_orient=global_orient, + body_pose=body_pose, + left_hand_pose=left_hand_pose, + right_hand_pose=right_hand_pose, + full_pose=full_pose if return_full_pose else None) + + return output + + +class SMPLHLayer(SMPLH): + + def __init__( + self, *args, **kwargs + ) -> None: + ''' SMPL+H as a layer model constructor + ''' + super(SMPLHLayer, self).__init__( + create_global_orient=False, + create_body_pose=False, + create_left_hand_pose=False, + create_right_hand_pose=False, + create_betas=False, + create_transl=False, + *args, + **kwargs) + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: 
Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + left_hand_pose: Optional[Tensor] = None, + right_hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs + ) -> SMPLHOutput: + ''' Forward pass for the SMPL+H model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3x3 + Global rotation of the body. Useful if someone wishes to + predicts this with an external model. It is expected to be in + rotation matrix format. (default=None) + betas: torch.tensor, optional, shape BxN_b + Shape parameters. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + body_pose: torch.tensor, optional, shape BxJx3x3 + If given, ignore the member variable `body_pose` and use it + instead. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + left_hand_pose: torch.tensor, optional, shape Bx15x3x3 + If given, contains the pose of the left hand. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + right_hand_pose: torch.tensor, optional, shape Bx15x3x3 + If given, contains the pose of the right hand. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + transl: torch.tensor, optional, shape Bx3 + Translation vector of the body. + For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. (default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + ''' + model_vars = [betas, global_orient, body_pose, transl, left_hand_pose, + right_hand_pose] + batch_size = 1 + for var in model_vars: + if var is None: + continue + batch_size = max(batch_size, len(var)) + device, dtype = self.shapedirs.device, self.shapedirs.dtype + if global_orient is None: + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if body_pose is None: + body_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 21, -1, -1).contiguous() + if left_hand_pose is None: + left_hand_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + if right_hand_pose is None: + right_hand_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + if betas is None: + betas = torch.zeros([batch_size, self.num_betas], + dtype=dtype, device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + # Concatenate all pose vectors + full_pose = torch.cat( + [global_orient.reshape(-1, 1, 3, 3), + body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3), + left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3), + right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)], + dim=1) + + vertices, joints = lbs(betas, full_pose, self.v_template, + self.shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, pose2rot=False) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + if self.joint_mapper is not None: + joints = 
self.joint_mapper(joints) + + if transl is not None: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = SMPLHOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + global_orient=global_orient, + body_pose=body_pose, + left_hand_pose=left_hand_pose, + right_hand_pose=right_hand_pose, + full_pose=full_pose if return_full_pose else None) + + return output + + +class SMPLX(SMPLH): + ''' + SMPL-X (SMPL eXpressive) is a unified body model, with shape parameters + trained jointly for the face, hands and body. + SMPL-X uses standard vertex based linear blend skinning with learned + corrective blend shapes, has N=10475 vertices and K=54 joints, + which includes joints for the neck, jaw, eyeballs and fingers. + ''' + + NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS + NUM_HAND_JOINTS = 15 + NUM_FACE_JOINTS = 3 + NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS + EXPRESSION_SPACE_DIM = 100 + NECK_IDX = 12 + + def __init__( + self, model_path: str, + kid_template_path: str = '', + num_expression_coeffs: int = 10, + create_expression: bool = True, + expression: Optional[Tensor] = None, + create_jaw_pose: bool = True, + jaw_pose: Optional[Tensor] = None, + create_leye_pose: bool = True, + leye_pose: Optional[Tensor] = None, + create_reye_pose=True, + reye_pose: Optional[Tensor] = None, + use_face_contour: bool = False, + batch_size: int = 1, + gender: str = 'neutral', + age: str = 'adult', + dtype=torch.float32, + ext: str = 'npz', + **kwargs + ) -> None: + ''' SMPLX model constructor + + Parameters + ---------- + model_path: str + The path to the folder or to the file where the model + parameters are stored + num_expression_coeffs: int, optional + Number of expression components to use + (default = 10). + create_expression: bool, optional + Flag for creating a member variable for the expression space + (default = True). + expression: torch.tensor, optional, Bx10 + The default value for the expression member variable. + (default = None) + create_jaw_pose: bool, optional + Flag for creating a member variable for the jaw pose. + (default = False) + jaw_pose: torch.tensor, optional, Bx3 + The default value for the jaw pose variable. + (default = None) + create_leye_pose: bool, optional + Flag for creating a member variable for the left eye pose. + (default = False) + leye_pose: torch.tensor, optional, Bx10 + The default value for the left eye pose variable. + (default = None) + create_reye_pose: bool, optional + Flag for creating a member variable for the right eye pose. + (default = False) + reye_pose: torch.tensor, optional, Bx10 + The default value for the right eye pose variable. 
+ (default = None) + use_face_contour: bool, optional + Whether to compute the keypoints that form the facial contour + batch_size: int, optional + The batch size used for creating the member variables + gender: str, optional + Which gender to load + dtype: torch.dtype + The data type for the created variables + ''' + + # Load the model + if osp.isdir(model_path): + model_fn = 'SMPLX_{}.{ext}'.format(gender.upper(), ext=ext) + smplx_path = os.path.join(model_path, model_fn) + else: + smplx_path = model_path + assert osp.exists(smplx_path), 'Path {} does not exist!'.format( + smplx_path) + + if ext == 'pkl': + with open(smplx_path, 'rb') as smplx_file: + model_data = pickle.load(smplx_file, encoding='latin1') + elif ext == 'npz': + model_data = np.load(smplx_path, allow_pickle=True) + else: + raise ValueError('Unknown extension: {}'.format(ext)) + + data_struct = Struct(**model_data) + + super(SMPLX, self).__init__( + model_path=model_path, + kid_template_path=kid_template_path, + data_struct=data_struct, + dtype=dtype, + batch_size=batch_size, + vertex_ids=VERTEX_IDS['smplx'], + gender=gender, age=age, ext=ext, + **kwargs) + + lmk_faces_idx = data_struct.lmk_faces_idx + self.register_buffer('lmk_faces_idx', + torch.tensor(lmk_faces_idx, dtype=torch.long)) + lmk_bary_coords = data_struct.lmk_bary_coords + self.register_buffer('lmk_bary_coords', + torch.tensor(lmk_bary_coords, dtype=dtype)) + + self.use_face_contour = use_face_contour + if self.use_face_contour: + dynamic_lmk_faces_idx = data_struct.dynamic_lmk_faces_idx + dynamic_lmk_faces_idx = torch.tensor( + dynamic_lmk_faces_idx, + dtype=torch.long) + self.register_buffer('dynamic_lmk_faces_idx', + dynamic_lmk_faces_idx) + + dynamic_lmk_bary_coords = data_struct.dynamic_lmk_bary_coords + dynamic_lmk_bary_coords = torch.tensor( + dynamic_lmk_bary_coords, dtype=dtype) + self.register_buffer('dynamic_lmk_bary_coords', + dynamic_lmk_bary_coords) + + neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents) + self.register_buffer( + 'neck_kin_chain', + torch.tensor(neck_kin_chain, dtype=torch.long)) + + if create_jaw_pose: + if jaw_pose is None: + default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype) + jaw_pose_param = nn.Parameter(default_jaw_pose, + requires_grad=True) + self.register_parameter('jaw_pose', jaw_pose_param) + + if create_leye_pose: + if leye_pose is None: + default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_leye_pose = torch.tensor(leye_pose, dtype=dtype) + leye_pose_param = nn.Parameter(default_leye_pose, + requires_grad=True) + self.register_parameter('leye_pose', leye_pose_param) + + if create_reye_pose: + if reye_pose is None: + default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_reye_pose = torch.tensor(reye_pose, dtype=dtype) + reye_pose_param = nn.Parameter(default_reye_pose, + requires_grad=True) + self.register_parameter('reye_pose', reye_pose_param) + + shapedirs = data_struct.shapedirs + if len(shapedirs.shape) < 3: + shapedirs = shapedirs[:, :, None] + if (shapedirs.shape[-1] < self.SHAPE_SPACE_DIM + + self.EXPRESSION_SPACE_DIM): + print(f'WARNING: You are using a {self.name()} model, with only' + ' 10 shape and 10 expression coefficients.') + expr_start_idx = 10 + expr_end_idx = 20 + num_expression_coeffs = min(num_expression_coeffs, 10) + else: + expr_start_idx = self.SHAPE_SPACE_DIM + expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs + num_expression_coeffs = min( 
+ num_expression_coeffs, self.EXPRESSION_SPACE_DIM) + + self._num_expression_coeffs = num_expression_coeffs + + expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx] + self.register_buffer( + 'expr_dirs', to_tensor(to_np(expr_dirs), dtype=dtype)) + + if create_expression: + if expression is None: + default_expression = torch.zeros( + [batch_size, self.num_expression_coeffs], dtype=dtype) + else: + default_expression = torch.tensor(expression, dtype=dtype) + expression_param = nn.Parameter(default_expression, + requires_grad=True) + self.register_parameter('expression', expression_param) + + def name(self) -> str: + return 'SMPL-X' + + @property + def num_expression_coeffs(self): + return self._num_expression_coeffs + + def create_mean_pose(self, data_struct, flat_hand_mean=False): + # Create the array for the mean pose. If flat_hand is false, then use + # the mean that is given by the data, rather than the flat open hand + global_orient_mean = torch.zeros([3], dtype=self.dtype) + body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3], + dtype=self.dtype) + jaw_pose_mean = torch.zeros([3], dtype=self.dtype) + leye_pose_mean = torch.zeros([3], dtype=self.dtype) + reye_pose_mean = torch.zeros([3], dtype=self.dtype) + + pose_mean = np.concatenate([global_orient_mean, body_pose_mean, + jaw_pose_mean, + leye_pose_mean, reye_pose_mean, + self.left_hand_mean, self.right_hand_mean], + axis=0) + + return pose_mean + + def extra_repr(self): + msg = super(SMPLX, self).extra_repr() + msg = [ + msg, + f'Number of Expression Coefficients: {self.num_expression_coeffs}' + ] + return '\n'.join(msg) + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + left_hand_pose: Optional[Tensor] = None, + right_hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + expression: Optional[Tensor] = None, + jaw_pose: Optional[Tensor] = None, + leye_pose: Optional[Tensor] = None, + reye_pose: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + pose2rot: bool = True, + return_shaped: bool = True, + v_template: Optional[Tensor] = None, + **kwargs + ) -> SMPLXOutput: + ''' + Forward pass for the SMPLX model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3 + If given, ignore the member variable and use it as the global + rotation of the body. Useful if someone wishes to predicts this + with an external model. (default=None) + betas: torch.tensor, optional, shape BxN_b + If given, ignore the member variable `betas` and use it + instead. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + expression: torch.tensor, optional, shape BxN_e + If given, ignore the member variable `expression` and use it + instead. For example, it can used if expression parameters + `expression` are predicted from some external model. + body_pose: torch.tensor, optional, shape Bx(J*3) + If given, ignore the member variable `body_pose` and use it + instead. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + axis-angle format. (default=None) + left_hand_pose: torch.tensor, optional, shape BxP + If given, ignore the member variable `left_hand_pose` and + use this instead. It should either contain PCA coefficients or + joint rotations in axis-angle format. 
+        right_hand_pose: torch.tensor, optional, shape BxP
+            If given, ignore the member variable `right_hand_pose` and
+            use this instead. It should either contain PCA coefficients or
+            joint rotations in axis-angle format.
+        jaw_pose: torch.tensor, optional, shape Bx3
+            If given, ignore the member variable `jaw_pose` and
+            use this instead. It should contain joint rotations in
+            axis-angle format.
+        transl: torch.tensor, optional, shape Bx3
+            If given, ignore the member variable `transl` and use it
+            instead. For example, it can be used if the translation
+            `transl` is predicted from some external model.
+            (default=None)
+        return_verts: bool, optional
+            Return the vertices. (default=True)
+        return_full_pose: bool, optional
+            Returns the full axis-angle pose vector (default=False)
+
+        Returns
+        -------
+        output: ModelOutput
+            A named tuple of type `ModelOutput`
+        '''
+
+        # If no shape and pose parameters are passed along, then use the
+        # ones from the module
+        global_orient = (global_orient if global_orient is not None else
+                         self.global_orient)
+        body_pose = body_pose if body_pose is not None else self.body_pose
+        betas = betas if betas is not None else self.betas
+
+        left_hand_pose = (left_hand_pose if left_hand_pose is not None else
+                          self.left_hand_pose)
+        right_hand_pose = (right_hand_pose if right_hand_pose is not None else
+                           self.right_hand_pose)
+        jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose
+        leye_pose = leye_pose if leye_pose is not None else self.leye_pose
+        reye_pose = reye_pose if reye_pose is not None else self.reye_pose
+        expression = expression if expression is not None else self.expression
+        v_template = v_template if v_template is not None else self.v_template
+
+        apply_trans = transl is not None or hasattr(self, 'transl')
+        if transl is None:
+            if hasattr(self, 'transl'):
+                transl = self.transl
+
+        if self.use_pca:
+            left_hand_pose = torch.einsum(
+                'bi,ij->bj', [left_hand_pose, self.left_hand_components])
+            right_hand_pose = torch.einsum(
+                'bi,ij->bj', [right_hand_pose, self.right_hand_components])
+
+        full_pose = torch.cat([global_orient.reshape(-1, 1, 3),
+                               body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3),
+                               jaw_pose.reshape(-1, 1, 3),
+                               leye_pose.reshape(-1, 1, 3),
+                               reye_pose.reshape(-1, 1, 3),
+                               left_hand_pose.reshape(-1, 15, 3),
+                               right_hand_pose.reshape(-1, 15, 3)],
+                              dim=1).reshape(-1, 165)
+
+        # Add the mean pose of the model.
Does not affect the body, only the + # hands when flat_hand_mean == False + full_pose += self.pose_mean + + batch_size = max(betas.shape[0], global_orient.shape[0], + body_pose.shape[0]) + # Concatenate the shape and expression coefficients + scale = int(batch_size / betas.shape[0]) + if scale > 1: + betas = betas.expand(scale, -1) + shape_components = torch.cat([betas, expression], dim=-1) + + shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) + + vertices, joints, vT, jT, v_shaped, v_posed = lbs(shape_components, full_pose, v_template, + shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, pose2rot=pose2rot, + custom_out=True + ) + + lmk_faces_idx = self.lmk_faces_idx.unsqueeze( + dim=0).expand(batch_size, -1).contiguous() + lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat( + self.batch_size, 1, 1) + if self.use_face_contour: + lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( + vertices, full_pose, self.dynamic_lmk_faces_idx, + self.dynamic_lmk_bary_coords, + self.neck_kin_chain, + pose2rot=True, + ) + dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords + + lmk_faces_idx = torch.cat([lmk_faces_idx, + dyn_lmk_faces_idx], 1) + lmk_bary_coords = torch.cat( + [lmk_bary_coords.expand(batch_size, -1, -1), + dyn_lmk_bary_coords], 1) + + landmarks = vertices2landmarks(vertices, self.faces_tensor, + lmk_faces_idx, + lmk_bary_coords) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + joints_transform = self.vertex_joint_selector(vT, jT) + + # Add the landmarks to the joints + joints = torch.cat([joints, landmarks], dim=1) + # Map the joints to the current dataset + + if self.joint_mapper is not None: + joints = self.joint_mapper(joints=joints, vertices=vertices) + + if apply_trans: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + joints_transform[:, :, :3, 3] += transl.unsqueeze(dim=1) + + output = SMPLXOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + expression=expression, + global_orient=global_orient, + body_pose=body_pose, + left_hand_pose=left_hand_pose, + right_hand_pose=right_hand_pose, + jaw_pose=jaw_pose, + v_shaped=v_shaped, + v_posed=v_posed, + joints_transform=joints_transform, + full_pose=full_pose if return_full_pose else None) + return output + + +class SMPLXLayer(SMPLX): + def __init__( + self, + *args, + **kwargs + ) -> None: + # Just create a SMPLX module without any member variables + super(SMPLXLayer, self).__init__( + create_global_orient=False, + create_body_pose=False, + create_left_hand_pose=False, + create_right_hand_pose=False, + create_jaw_pose=False, + create_leye_pose=False, + create_reye_pose=False, + create_betas=False, + create_expression=False, + create_transl=False, + *args, **kwargs, + ) + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + body_pose: Optional[Tensor] = None, + left_hand_pose: Optional[Tensor] = None, + right_hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + expression: Optional[Tensor] = None, + jaw_pose: Optional[Tensor] = None, + leye_pose: Optional[Tensor] = None, + reye_pose: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + **kwargs + ) -> SMPLXOutput: + ''' + Forward pass for the SMPLX model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3x3 + If given, ignore the member variable and use it as the global + rotation 
of the body. Useful if someone wishes to predicts this + with an external model. It is expected to be in rotation matrix + format. (default=None) + betas: torch.tensor, optional, shape BxN_b + If given, ignore the member variable `betas` and use it + instead. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + expression: torch.tensor, optional, shape BxN_e + Expression coefficients. + For example, it can used if expression parameters + `expression` are predicted from some external model. + body_pose: torch.tensor, optional, shape BxJx3x3 + If given, ignore the member variable `body_pose` and use it + instead. For example, it can used if someone predicts the + pose of the body joints are predicted from some external model. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + left_hand_pose: torch.tensor, optional, shape Bx15x3x3 + If given, contains the pose of the left hand. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + right_hand_pose: torch.tensor, optional, shape Bx15x3x3 + If given, contains the pose of the right hand. + It should be a tensor that contains joint rotations in + rotation matrix format. (default=None) + jaw_pose: torch.tensor, optional, shape Bx3x3 + Jaw pose. It should either joint rotations in + rotation matrix format. + transl: torch.tensor, optional, shape Bx3 + Translation vector of the body. + For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. (default=True) + return_full_pose: bool, optional + Returns the full pose vector (default=False) + Returns + ------- + output: ModelOutput + A data class that contains the posed vertices and joints + ''' + device, dtype = self.shapedirs.device, self.shapedirs.dtype + + model_vars = [betas, global_orient, body_pose, transl, + expression, left_hand_pose, right_hand_pose, jaw_pose] + batch_size = 1 + for var in model_vars: + if var is None: + continue + batch_size = max(batch_size, len(var)) + + if global_orient is None: + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if body_pose is None: + body_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand( + batch_size, self.NUM_BODY_JOINTS, -1, -1).contiguous() + if left_hand_pose is None: + left_hand_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + if right_hand_pose is None: + right_hand_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + if jaw_pose is None: + jaw_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if leye_pose is None: + leye_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if reye_pose is None: + reye_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if expression is None: + expression = torch.zeros([batch_size, self.num_expression_coeffs], + dtype=dtype, device=device) + if betas is None: + betas = torch.zeros([batch_size, self.num_betas], + dtype=dtype, device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + # Concatenate all 
pose vectors
+        full_pose = torch.cat(
+            [global_orient.reshape(-1, 1, 3, 3),
+             body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
+             jaw_pose.reshape(-1, 1, 3, 3),
+             leye_pose.reshape(-1, 1, 3, 3),
+             reye_pose.reshape(-1, 1, 3, 3),
+             left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
+             right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)],
+            dim=1)
+        shape_components = torch.cat([betas, expression], dim=-1)
+
+        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)
+
+        vertices, joints = lbs(shape_components, full_pose, self.v_template,
+                               shapedirs, self.posedirs,
+                               self.J_regressor, self.parents,
+                               self.lbs_weights,
+                               pose2rot=False,
+                               )
+
+        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(
+            dim=0).expand(batch_size, -1).contiguous()
+        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
+            batch_size, 1, 1)
+        if self.use_face_contour:
+            lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
+                vertices, full_pose,
+                self.dynamic_lmk_faces_idx,
+                self.dynamic_lmk_bary_coords,
+                self.neck_kin_chain,
+                pose2rot=False,
+            )
+            dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
+
+            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
+            lmk_bary_coords = torch.cat(
+                [lmk_bary_coords.expand(batch_size, -1, -1),
+                 dyn_lmk_bary_coords], 1)
+
+        landmarks = vertices2landmarks(vertices, self.faces_tensor,
+                                       lmk_faces_idx,
+                                       lmk_bary_coords)
+
+        # Add any extra joints that might be needed
+        joints = self.vertex_joint_selector(vertices, joints)
+        # Add the landmarks to the joints
+        joints = torch.cat([joints, landmarks], dim=1)
+        # Map the joints to the current dataset
+
+        if self.joint_mapper is not None:
+            joints = self.joint_mapper(joints=joints, vertices=vertices)
+
+        if transl is not None:
+            joints += transl.unsqueeze(dim=1)
+            vertices += transl.unsqueeze(dim=1)
+
+        output = SMPLXOutput(vertices=vertices if return_verts else None,
+                             joints=joints,
+                             betas=betas,
+                             expression=expression,
+                             global_orient=global_orient,
+                             body_pose=body_pose,
+                             left_hand_pose=left_hand_pose,
+                             right_hand_pose=right_hand_pose,
+                             jaw_pose=jaw_pose,
+                             transl=transl,
+                             full_pose=full_pose if return_full_pose else None)
+        return output
+
+
+class MANO(SMPL):
+    # The hand joints are replaced by MANO
+    NUM_BODY_JOINTS = 1
+    NUM_HAND_JOINTS = 15
+    NUM_JOINTS = NUM_BODY_JOINTS + NUM_HAND_JOINTS
+
+    def __init__(
+        self,
+        model_path: str,
+        is_rhand: bool = True,
+        data_struct: Optional[Struct] = None,
+        create_hand_pose: bool = True,
+        hand_pose: Optional[Tensor] = None,
+        use_pca: bool = True,
+        num_pca_comps: int = 6,
+        flat_hand_mean: bool = False,
+        batch_size: int = 1,
+        dtype=torch.float32,
+        vertex_ids=None,
+        use_compressed: bool = True,
+        ext: str = 'pkl',
+        **kwargs
+    ) -> None:
+        ''' MANO model constructor
+
+        Parameters
+        ----------
+        model_path: str
+            The path to the folder or to the file where the model
+            parameters are stored
+        data_struct: Struct
+            A struct object. If given, then the parameters of the model are
+            read from the object. Otherwise, the model tries to read the
+            parameters from the given `model_path`. (default = None)
+        create_hand_pose: bool, optional
+            Flag for creating a member variable for the pose of the right
+            hand. (default = True)
+        hand_pose: torch.tensor, optional, BxP
+            The default value for the right hand pose member variable.
+            (default = None)
+        num_pca_comps: int, optional
+            The number of PCA components to use for each hand.
+            (default = 6)
+        flat_hand_mean: bool, optional
+            If False, then the hand pose is initialized to the mean hand
+            pose provided with the model data; if True, a flat (all-zero)
+            hand pose is used instead. (default = False)
+ batch_size: int, optional + The batch size used for creating the member variables + dtype: torch.dtype, optional + The data type for the created variables + vertex_ids: dict, optional + A dictionary containing the indices of the extra vertices that + will be selected + ''' + + self.num_pca_comps = num_pca_comps + self.is_rhand = is_rhand + # If no data structure is passed, then load the data from the given + # model folder + if data_struct is None: + # Load the model + if osp.isdir(model_path): + model_fn = 'MANO_{}.{ext}'.format( + 'RIGHT' if is_rhand else 'LEFT', ext=ext) + mano_path = os.path.join(model_path, model_fn) + else: + mano_path = model_path + self.is_rhand = True if 'RIGHT' in os.path.basename( + model_path) else False + assert osp.exists(mano_path), 'Path {} does not exist!'.format( + mano_path) + + if ext == 'pkl': + with open(mano_path, 'rb') as mano_file: + model_data = pickle.load(mano_file, encoding='latin1') + elif ext == 'npz': + model_data = np.load(mano_path, allow_pickle=True) + else: + raise ValueError('Unknown extension: {}'.format(ext)) + data_struct = Struct(**model_data) + + if vertex_ids is None: + vertex_ids = VERTEX_IDS['smplh'] + + super(MANO, self).__init__( + model_path=model_path, data_struct=data_struct, + batch_size=batch_size, vertex_ids=vertex_ids, + use_compressed=use_compressed, dtype=dtype, ext=ext, **kwargs) + + # add only MANO tips to the extra joints + self.vertex_joint_selector.extra_joints_idxs = to_tensor( + list(VERTEX_IDS['mano'].values()), dtype=torch.long) + + self.use_pca = use_pca + self.num_pca_comps = num_pca_comps + if self.num_pca_comps == 45: + self.use_pca = False + self.flat_hand_mean = flat_hand_mean + + hand_components = data_struct.hands_components[:num_pca_comps] + + self.np_hand_components = hand_components + + if self.use_pca: + self.register_buffer( + 'hand_components', + torch.tensor(hand_components, dtype=dtype)) + + if self.flat_hand_mean: + hand_mean = np.zeros_like(data_struct.hands_mean) + else: + hand_mean = data_struct.hands_mean + + self.register_buffer('hand_mean', + to_tensor(hand_mean, dtype=self.dtype)) + + # Create the buffers for the pose of the left hand + hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS + if create_hand_pose: + if hand_pose is None: + default_hand_pose = torch.zeros([batch_size, hand_pose_dim], + dtype=dtype) + else: + default_hand_pose = torch.tensor(hand_pose, dtype=dtype) + + hand_pose_param = nn.Parameter(default_hand_pose, + requires_grad=True) + self.register_parameter('hand_pose', + hand_pose_param) + + # Create the buffer for the mean pose. + pose_mean = self.create_mean_pose( + data_struct, flat_hand_mean=flat_hand_mean) + pose_mean_tensor = pose_mean.clone().to(dtype) + # pose_mean_tensor = torch.tensor(pose_mean, dtype=dtype) + self.register_buffer('pose_mean', pose_mean_tensor) + + def name(self) -> str: + return 'MANO' + + def create_mean_pose(self, data_struct, flat_hand_mean=False): + # Create the array for the mean pose. 
If flat_hand is false, then use + # the mean that is given by the data, rather than the flat open hand + global_orient_mean = torch.zeros([3], dtype=self.dtype) + pose_mean = torch.cat([global_orient_mean, self.hand_mean], dim=0) + return pose_mean + + def extra_repr(self): + msg = [super(MANO, self).extra_repr()] + if self.use_pca: + msg.append(f'Number of PCA components: {self.num_pca_comps}') + msg.append(f'Flat hand mean: {self.flat_hand_mean}') + return '\n'.join(msg) + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + **kwargs + ) -> MANOOutput: + ''' Forward pass for the MANO model + ''' + # If no shape and pose parameters are passed along, then use the + # ones from the module + global_orient = (global_orient if global_orient is not None else + self.global_orient) + betas = betas if betas is not None else self.betas + hand_pose = (hand_pose if hand_pose is not None else + self.hand_pose) + + apply_trans = transl is not None or hasattr(self, 'transl') + if transl is None: + if hasattr(self, 'transl'): + transl = self.transl + + if self.use_pca: + hand_pose = torch.einsum( + 'bi,ij->bj', [hand_pose, self.hand_components]) + + full_pose = torch.cat([global_orient, hand_pose], dim=1) + full_pose += self.pose_mean + + vertices, joints = lbs(betas, full_pose, self.v_template, + self.shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, pose2rot=True, + ) + + # # Add pre-selected extra joints that might be needed + # joints = self.vertex_joint_selector(vertices, joints) + + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if apply_trans: + joints = joints + transl.unsqueeze(dim=1) + vertices = vertices + transl.unsqueeze(dim=1) + + output = MANOOutput(vertices=vertices if return_verts else None, + joints=joints if return_verts else None, + betas=betas, + global_orient=global_orient, + hand_pose=hand_pose, + full_pose=full_pose if return_full_pose else None) + + return output + + +class MANOLayer(MANO): + def __init__(self, *args, **kwargs) -> None: + ''' MANO as a layer model constructor + ''' + super(MANOLayer, self).__init__( + create_global_orient=False, + create_hand_pose=False, + create_betas=False, + create_transl=False, + *args, **kwargs) + + def name(self) -> str: + return 'MANO' + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + hand_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + **kwargs + ) -> MANOOutput: + ''' Forward pass for the MANO model + ''' + device, dtype = self.shapedirs.device, self.shapedirs.dtype + if global_orient is None: + batch_size = 1 + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + else: + batch_size = global_orient.shape[0] + if hand_pose is None: + hand_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + if betas is None: + betas = torch.zeros( + [batch_size, self.num_betas], dtype=dtype, device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + full_pose = torch.cat([global_orient, hand_pose], dim=1) + vertices, joints = lbs(betas, full_pose, self.v_template, + self.shapedirs, self.posedirs, + 
self.J_regressor, self.parents, + self.lbs_weights, pose2rot=False) + + if self.joint_mapper is not None: + joints = self.joint_mapper(joints) + + if transl is not None: + joints = joints + transl.unsqueeze(dim=1) + vertices = vertices + transl.unsqueeze(dim=1) + + output = MANOOutput( + vertices=vertices if return_verts else None, + joints=joints if return_verts else None, + betas=betas, + global_orient=global_orient, + hand_pose=hand_pose, + full_pose=full_pose if return_full_pose else None) + + return output + + +class FLAME(SMPL): + NUM_JOINTS = 5 + SHAPE_SPACE_DIM = 300 + EXPRESSION_SPACE_DIM = 100 + NECK_IDX = 0 + + def __init__( + self, + model_path: str, + data_struct=None, + num_expression_coeffs=10, + create_expression: bool = True, + expression: Optional[Tensor] = None, + create_neck_pose: bool = True, + neck_pose: Optional[Tensor] = None, + create_jaw_pose: bool = True, + jaw_pose: Optional[Tensor] = None, + create_leye_pose: bool = True, + leye_pose: Optional[Tensor] = None, + create_reye_pose=True, + reye_pose: Optional[Tensor] = None, + use_face_contour=False, + batch_size: int = 1, + gender: str = 'neutral', + dtype: torch.dtype = torch.float32, + ext='pkl', + **kwargs + ) -> None: + ''' FLAME model constructor + + Parameters + ---------- + model_path: str + The path to the folder or to the file where the model + parameters are stored + num_expression_coeffs: int, optional + Number of expression components to use + (default = 10). + create_expression: bool, optional + Flag for creating a member variable for the expression space + (default = True). + expression: torch.tensor, optional, Bx10 + The default value for the expression member variable. + (default = None) + create_neck_pose: bool, optional + Flag for creating a member variable for the neck pose. + (default = False) + neck_pose: torch.tensor, optional, Bx3 + The default value for the neck pose variable. + (default = None) + create_jaw_pose: bool, optional + Flag for creating a member variable for the jaw pose. + (default = False) + jaw_pose: torch.tensor, optional, Bx3 + The default value for the jaw pose variable. + (default = None) + create_leye_pose: bool, optional + Flag for creating a member variable for the left eye pose. + (default = False) + leye_pose: torch.tensor, optional, Bx10 + The default value for the left eye pose variable. + (default = None) + create_reye_pose: bool, optional + Flag for creating a member variable for the right eye pose. + (default = False) + reye_pose: torch.tensor, optional, Bx10 + The default value for the right eye pose variable. 
+ (default = None) + use_face_contour: bool, optional + Whether to compute the keypoints that form the facial contour + batch_size: int, optional + The batch size used for creating the member variables + gender: str, optional + Which gender to load + dtype: torch.dtype + The data type for the created variables + ''' + model_fn = f'FLAME_{gender.upper()}.{ext}' + flame_path = os.path.join(model_path, model_fn) + assert osp.exists(flame_path), 'Path {} does not exist!'.format( + flame_path) + if ext == 'npz': + file_data = np.load(flame_path, allow_pickle=True) + elif ext == 'pkl': + with open(flame_path, 'rb') as smpl_file: + file_data = pickle.load(smpl_file, encoding='latin1') + else: + raise ValueError('Unknown extension: {}'.format(ext)) + data_struct = Struct(**file_data) + + super(FLAME, self).__init__( + model_path=model_path, + data_struct=data_struct, + dtype=dtype, + batch_size=batch_size, + gender=gender, + ext=ext, + **kwargs) + + self.use_face_contour = use_face_contour + + self.vertex_joint_selector.extra_joints_idxs = to_tensor( + [], dtype=torch.long) + + if create_neck_pose: + if neck_pose is None: + default_neck_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_neck_pose = torch.tensor(neck_pose, dtype=dtype) + neck_pose_param = nn.Parameter( + default_neck_pose, requires_grad=True) + self.register_parameter('neck_pose', neck_pose_param) + + if create_jaw_pose: + if jaw_pose is None: + default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype) + jaw_pose_param = nn.Parameter(default_jaw_pose, + requires_grad=True) + self.register_parameter('jaw_pose', jaw_pose_param) + + if create_leye_pose: + if leye_pose is None: + default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_leye_pose = torch.tensor(leye_pose, dtype=dtype) + leye_pose_param = nn.Parameter(default_leye_pose, + requires_grad=True) + self.register_parameter('leye_pose', leye_pose_param) + + if create_reye_pose: + if reye_pose is None: + default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype) + else: + default_reye_pose = torch.tensor(reye_pose, dtype=dtype) + reye_pose_param = nn.Parameter(default_reye_pose, + requires_grad=True) + self.register_parameter('reye_pose', reye_pose_param) + + shapedirs = data_struct.shapedirs + if len(shapedirs.shape) < 3: + shapedirs = shapedirs[:, :, None] + if (shapedirs.shape[-1] < self.SHAPE_SPACE_DIM + + self.EXPRESSION_SPACE_DIM): + print(f'WARNING: You are using a {self.name()} model, with only' + ' 10 shape and 10 expression coefficients.') + expr_start_idx = 10 + expr_end_idx = 20 + num_expression_coeffs = min(num_expression_coeffs, 10) + else: + expr_start_idx = self.SHAPE_SPACE_DIM + expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs + num_expression_coeffs = min( + num_expression_coeffs, self.EXPRESSION_SPACE_DIM) + + self._num_expression_coeffs = num_expression_coeffs + + expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx] + self.register_buffer( + 'expr_dirs', to_tensor(to_np(expr_dirs), dtype=dtype)) + + if create_expression: + if expression is None: + default_expression = torch.zeros( + [batch_size, self.num_expression_coeffs], dtype=dtype) + else: + default_expression = torch.tensor(expression, dtype=dtype) + expression_param = nn.Parameter(default_expression, + requires_grad=True) + self.register_parameter('expression', expression_param) + + # The pickle file that contains the barycentric coordinates for + # regressing the landmarks + 
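+        # Descriptive note (added comment, not in the upstream source): the
+        # static embedding is read from <model_path>/flame_static_embedding.pkl
+        # and provides, for each fixed facial landmark, a face index
+        # ('lmk_face_idx') and barycentric weights ('lmk_b_coords') used to
+        # interpolate the landmark from the mesh; the dynamic contour
+        # embedding below is only loaded when `use_face_contour` is enabled.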
+        landmark_bcoord_filename = osp.join(
+            model_path, 'flame_static_embedding.pkl')
+
+        with open(landmark_bcoord_filename, 'rb') as fp:
+            landmarks_data = pickle.load(fp, encoding='latin1')
+
+        lmk_faces_idx = landmarks_data['lmk_face_idx'].astype(np.int64)
+        self.register_buffer('lmk_faces_idx',
+                             torch.tensor(lmk_faces_idx, dtype=torch.long))
+        lmk_bary_coords = landmarks_data['lmk_b_coords']
+        self.register_buffer('lmk_bary_coords',
+                             torch.tensor(lmk_bary_coords, dtype=dtype))
+        if self.use_face_contour:
+            face_contour_path = os.path.join(
+                model_path, 'flame_dynamic_embedding.npy')
+            contour_embeddings = np.load(face_contour_path,
+                                         allow_pickle=True,
+                                         encoding='latin1')[()]
+
+            dynamic_lmk_faces_idx = np.array(
+                contour_embeddings['lmk_face_idx'], dtype=np.int64)
+            dynamic_lmk_faces_idx = torch.tensor(
+                dynamic_lmk_faces_idx,
+                dtype=torch.long)
+            self.register_buffer('dynamic_lmk_faces_idx',
+                                 dynamic_lmk_faces_idx)
+
+            dynamic_lmk_b_coords = torch.tensor(
+                contour_embeddings['lmk_b_coords'], dtype=dtype)
+            self.register_buffer(
+                'dynamic_lmk_bary_coords', dynamic_lmk_b_coords)
+
+            neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents)
+            self.register_buffer(
+                'neck_kin_chain',
+                torch.tensor(neck_kin_chain, dtype=torch.long))
+
+    @property
+    def num_expression_coeffs(self):
+        return self._num_expression_coeffs
+
+    def name(self) -> str:
+        return 'FLAME'
+
+    def extra_repr(self):
+        msg = [
+            super(FLAME, self).extra_repr(),
+            f'Number of Expression Coefficients: {self.num_expression_coeffs}',
+            f'Use face contour: {self.use_face_contour}',
+        ]
+        return '\n'.join(msg)
+
+    def forward(
+        self,
+        betas: Optional[Tensor] = None,
+        global_orient: Optional[Tensor] = None,
+        neck_pose: Optional[Tensor] = None,
+        transl: Optional[Tensor] = None,
+        expression: Optional[Tensor] = None,
+        jaw_pose: Optional[Tensor] = None,
+        leye_pose: Optional[Tensor] = None,
+        reye_pose: Optional[Tensor] = None,
+        return_verts: bool = True,
+        return_full_pose: bool = False,
+        pose2rot: bool = True,
+        **kwargs
+    ) -> FLAMEOutput:
+        '''
+        Forward pass for the FLAME model
+
+        Parameters
+        ----------
+        global_orient: torch.tensor, optional, shape Bx3
+            If given, ignore the member variable and use it as the global
+            rotation of the body. Useful if someone wishes to predict this
+            with an external model. (default=None)
+        betas: torch.tensor, optional, shape Bx10
+            If given, ignore the member variable `betas` and use it
+            instead. For example, it can be used if shape parameters
+            `betas` are predicted from some external model.
+            (default=None)
+        expression: torch.tensor, optional, shape Bx10
+            If given, ignore the member variable `expression` and use it
+            instead. For example, it can be used if expression parameters
+            `expression` are predicted from some external model.
+        neck_pose: torch.tensor, optional, shape Bx3
+            If given, ignore the member variable `neck_pose` and
+            use this instead. It should contain joint rotations in
+            axis-angle format.
+        jaw_pose: torch.tensor, optional, shape Bx3
+            If given, ignore the member variable `jaw_pose` and
+            use this instead. It should contain joint rotations in
+            axis-angle format.
+        transl: torch.tensor, optional, shape Bx3
+            If given, ignore the member variable `transl` and use it
+            instead. For example, it can be used if the translation
+            `transl` is predicted from some external model.
+            (default=None)
+        return_verts: bool, optional
+            Return the vertices.
(default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + output: ModelOutput + A named tuple of type `ModelOutput` + ''' + + # If no shape and pose parameters are passed along, then use the + # ones from the module + global_orient = (global_orient if global_orient is not None else + self.global_orient) + jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose + neck_pose = neck_pose if neck_pose is not None else self.neck_pose + + leye_pose = leye_pose if leye_pose is not None else self.leye_pose + reye_pose = reye_pose if reye_pose is not None else self.reye_pose + + betas = betas if betas is not None else self.betas + expression = expression if expression is not None else self.expression + + apply_trans = transl is not None or hasattr(self, 'transl') + if transl is None: + if hasattr(self, 'transl'): + transl = self.transl + + full_pose = torch.cat( + [global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1) + + batch_size = max(betas.shape[0], global_orient.shape[0], + jaw_pose.shape[0]) + # Concatenate the shape and expression coefficients + scale = int(batch_size / betas.shape[0]) + if scale > 1: + betas = betas.expand(scale, -1) + shape_components = torch.cat([betas, expression], dim=-1) + shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) + + vertices, joints = lbs(shape_components, full_pose, self.v_template, + shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, pose2rot=pose2rot, + ) + + lmk_faces_idx = self.lmk_faces_idx.unsqueeze( + dim=0).expand(batch_size, -1).contiguous() + lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat( + batch_size, 1, 1) + if self.use_face_contour: + lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( + vertices, full_pose, self.dynamic_lmk_faces_idx, + self.dynamic_lmk_bary_coords, + self.neck_kin_chain, + pose2rot=True, + ) + dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords + lmk_faces_idx = torch.cat([lmk_faces_idx, + dyn_lmk_faces_idx], 1) + lmk_bary_coords = torch.cat( + [lmk_bary_coords.expand(batch_size, -1, -1), + dyn_lmk_bary_coords], 1) + + landmarks = vertices2landmarks(vertices, self.faces_tensor, + lmk_faces_idx, + lmk_bary_coords) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + # Add the landmarks to the joints + joints = torch.cat([joints, landmarks], dim=1) + + # Map the joints to the current dataset + if self.joint_mapper is not None: + joints = self.joint_mapper(joints=joints, vertices=vertices) + + if apply_trans: + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = FLAMEOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + expression=expression, + global_orient=global_orient, + neck_pose=neck_pose, + jaw_pose=jaw_pose, + full_pose=full_pose if return_full_pose else None) + return output + + +class FLAMELayer(FLAME): + def __init__(self, *args, **kwargs) -> None: + ''' FLAME as a layer model constructor ''' + super(FLAMELayer, self).__init__( + create_betas=False, + create_expression=False, + create_global_orient=False, + create_neck_pose=False, + create_jaw_pose=False, + create_leye_pose=False, + create_reye_pose=False, + *args, + **kwargs) + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + neck_pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + expression: Optional[Tensor] = None, + 
jaw_pose: Optional[Tensor] = None, + leye_pose: Optional[Tensor] = None, + reye_pose: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + pose2rot: bool = True, + **kwargs + ) -> FLAMEOutput: + ''' + Forward pass for the SMPLX model + + Parameters + ---------- + global_orient: torch.tensor, optional, shape Bx3x3 + Global rotation of the body. Useful if someone wishes to + predicts this with an external model. It is expected to be in + rotation matrix format. (default=None) + betas: torch.tensor, optional, shape BxN_b + Shape parameters. For example, it can used if shape parameters + `betas` are predicted from some external model. + (default=None) + expression: torch.tensor, optional, shape BxN_e + If given, ignore the member variable `expression` and use it + instead. For example, it can used if expression parameters + `expression` are predicted from some external model. + jaw_pose: torch.tensor, optional, shape Bx3x3 + Jaw pose. It should either joint rotations in + rotation matrix format. + transl: torch.tensor, optional, shape Bx3 + Translation vector of the body. + For example, it can used if the translation + `transl` is predicted from some external model. + (default=None) + return_verts: bool, optional + Return the vertices. (default=True) + return_full_pose: bool, optional + Returns the full axis-angle pose vector (default=False) + + Returns + ------- + output: ModelOutput + A named tuple of type `ModelOutput` + ''' + device, dtype = self.shapedirs.device, self.shapedirs.dtype + if global_orient is None: + batch_size = 1 + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + else: + batch_size = global_orient.shape[0] + if neck_pose is None: + neck_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 1, -1, -1).contiguous() + if jaw_pose is None: + jaw_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if leye_pose is None: + leye_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if reye_pose is None: + reye_pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + if betas is None: + betas = torch.zeros([batch_size, self.num_betas], + dtype=dtype, device=device) + if expression is None: + expression = torch.zeros([batch_size, self.num_expression_coeffs], + dtype=dtype, device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + full_pose = torch.cat( + [global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1) + + shape_components = torch.cat([betas, expression], dim=-1) + shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) + + vertices, joints = lbs(shape_components, full_pose, self.v_template, + shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, pose2rot=False, + ) + + lmk_faces_idx = self.lmk_faces_idx.unsqueeze( + dim=0).expand(batch_size, -1).contiguous() + lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat( + batch_size, 1, 1) + if self.use_face_contour: + lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( + vertices, full_pose, self.dynamic_lmk_faces_idx, + self.dynamic_lmk_bary_coords, + self.neck_kin_chain, + pose2rot=False, + ) + dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords + lmk_faces_idx = torch.cat([lmk_faces_idx, + 
dyn_lmk_faces_idx], 1) + lmk_bary_coords = torch.cat( + [lmk_bary_coords.expand(batch_size, -1, -1), + dyn_lmk_bary_coords], 1) + + landmarks = vertices2landmarks(vertices, self.faces_tensor, + lmk_faces_idx, + lmk_bary_coords) + + # Add any extra joints that might be needed + joints = self.vertex_joint_selector(vertices, joints) + # Add the landmarks to the joints + joints = torch.cat([joints, landmarks], dim=1) + + # Map the joints to the current dataset + if self.joint_mapper is not None: + joints = self.joint_mapper(joints=joints, vertices=vertices) + + joints += transl.unsqueeze(dim=1) + vertices += transl.unsqueeze(dim=1) + + output = FLAMEOutput(vertices=vertices if return_verts else None, + joints=joints, + betas=betas, + expression=expression, + global_orient=global_orient, + neck_pose=neck_pose, + jaw_pose=jaw_pose, + full_pose=full_pose if return_full_pose else None) + return output + + +def build_layer( + model_path: str, + model_type: str = 'smpl', + **kwargs +) -> Union[SMPLLayer, SMPLHLayer, SMPLXLayer, MANOLayer, FLAMELayer]: + ''' Method for creating a model from a path and a model type + + Parameters + ---------- + model_path: str + Either the path to the model you wish to load or a folder, + where each subfolder contains the differents types, i.e.: + model_path: + | + |-- smpl + |-- SMPL_FEMALE + |-- SMPL_NEUTRAL + |-- SMPL_MALE + |-- smplh + |-- SMPLH_FEMALE + |-- SMPLH_MALE + |-- smplx + |-- SMPLX_FEMALE + |-- SMPLX_NEUTRAL + |-- SMPLX_MALE + |-- mano + |-- MANO RIGHT + |-- MANO LEFT + |-- flame + |-- FLAME_FEMALE + |-- FLAME_MALE + |-- FLAME_NEUTRAL + + model_type: str, optional + When model_path is a folder, then this parameter specifies the + type of model to be loaded + **kwargs: dict + Keyword arguments + + Returns + ------- + body_model: nn.Module + The PyTorch module that implements the corresponding body model + Raises + ------ + ValueError: In case the model type is not one of SMPL, SMPLH, + SMPLX, MANO or FLAME + ''' + + if osp.isdir(model_path): + model_path = os.path.join(model_path, model_type) + else: + model_type = osp.basename(model_path).split('_')[0].lower() + + if model_type.lower() == 'smpl': + return SMPLLayer(model_path, **kwargs) + elif model_type.lower() == 'smplh': + return SMPLHLayer(model_path, **kwargs) + elif model_type.lower() == 'smplx': + return SMPLXLayer(model_path, **kwargs) + elif 'mano' in model_type.lower(): + return MANOLayer(model_path, **kwargs) + elif 'flame' in model_type.lower(): + return FLAMELayer(model_path, **kwargs) + else: + raise ValueError(f'Unknown model type {model_type}, exiting!') + + +def create( + model_path: str, + model_type: str = 'smpl', + **kwargs +) -> Union[SMPL, SMPLH, SMPLX, MANO, FLAME]: + ''' Method for creating a model from a path and a model type + + Parameters + ---------- + model_path: str + Either the path to the model you wish to load or a folder, + where each subfolder contains the differents types, i.e.: + model_path: + | + |-- smpl + |-- SMPL_FEMALE + |-- SMPL_NEUTRAL + |-- SMPL_MALE + |-- smplh + |-- SMPLH_FEMALE + |-- SMPLH_MALE + |-- smplx + |-- SMPLX_FEMALE + |-- SMPLX_NEUTRAL + |-- SMPLX_MALE + |-- mano + |-- MANO RIGHT + |-- MANO LEFT + + model_type: str, optional + When model_path is a folder, then this parameter specifies the + type of model to be loaded + **kwargs: dict + Keyword arguments + + Returns + ------- + body_model: nn.Module + The PyTorch module that implements the corresponding body model + Raises + ------ + ValueError: In case the model type is not one of SMPL, SMPLH, + 
SMPLX, MANO or FLAME + ''' + + # If it's a folder, assume + if osp.isdir(model_path): + model_path = os.path.join(model_path, model_type) + else: + model_type = osp.basename(model_path).split('_')[0].lower() + + if model_type.lower() == 'smpl': + return SMPL(model_path, **kwargs) + elif model_type.lower() == 'smplh': + return SMPLH(model_path, **kwargs) + elif model_type.lower() == 'smplx': + return SMPLX(model_path, **kwargs) + elif 'mano' in model_type.lower(): + return MANO(model_path, **kwargs) + elif 'flame' in model_type.lower(): + return FLAME(model_path, **kwargs) + else: + raise ValueError(f'Unknown model type {model_type}, exiting!') diff --git a/TADA/smplx/joint_names.py b/TADA/smplx/joint_names.py new file mode 100644 index 0000000000000000000000000000000000000000..b7326ffdeaf5b61e616c31cfddd18249143e2118 --- /dev/null +++ b/TADA/smplx/joint_names.py @@ -0,0 +1,240 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +JOINT_NAMES = [ + 'pelvis', + 'left_hip', + 'right_hip', + 'spine1', + 'left_knee', + 'right_knee', + 'spine2', + 'left_ankle', + 'right_ankle', + 'spine3', + 'left_foot', + 'right_foot', + 'neck', + 'left_collar', + 'right_collar', + 'head', + 'left_shoulder', + 'right_shoulder', + 'left_elbow', + 'right_elbow', + 'left_wrist', + 'right_wrist', + 'jaw', + 'left_eye_smplhf', + 'right_eye_smplhf', + 'left_index1', + 'left_index2', + 'left_index3', + 'left_middle1', + 'left_middle2', + 'left_middle3', + 'left_pinky1', + 'left_pinky2', + 'left_pinky3', + 'left_ring1', + 'left_ring2', + 'left_ring3', + 'left_thumb1', + 'left_thumb2', + 'left_thumb3', + 'right_index1', + 'right_index2', + 'right_index3', + 'right_middle1', + 'right_middle2', + 'right_middle3', + 'right_pinky1', + 'right_pinky2', + 'right_pinky3', + 'right_ring1', + 'right_ring2', + 'right_ring3', + 'right_thumb1', + 'right_thumb2', + 'right_thumb3', + 'nose', + 'right_eye', + 'left_eye', + 'right_ear', + 'left_ear', + 'left_big_toe', + 'left_small_toe', + 'left_heel', + 'right_big_toe', + 'right_small_toe', + 'right_heel', + 'left_thumb', + 'left_index', + 'left_middle', + 'left_ring', + 'left_pinky', + 'right_thumb', + 'right_index', + 'right_middle', + 'right_ring', + 'right_pinky', + 'right_eye_brow1', + 'right_eye_brow2', + 'right_eye_brow3', + 'right_eye_brow4', + 'right_eye_brow5', + 'left_eye_brow5', + 'left_eye_brow4', + 'left_eye_brow3', + 'left_eye_brow2', + 'left_eye_brow1', + 'nose1', + 'nose2', + 'nose3', + 'nose4', + 'right_nose_2', + 'right_nose_1', + 'nose_middle', + 'left_nose_1', + 'left_nose_2', + 'right_eye1', + 'right_eye2', + 'right_eye3', + 'right_eye4', + 'right_eye5', + 'right_eye6', + 'left_eye4', + 'left_eye3', + 'left_eye2', + 'left_eye1', + 'left_eye6', + 'left_eye5', + 'right_mouth_1', + 'right_mouth_2', + 'right_mouth_3', + 'mouth_top', + 'left_mouth_3', + 'left_mouth_2', + 'left_mouth_1', + 'left_mouth_5', # 59 in OpenPose output + 
'left_mouth_4', # 58 in OpenPose output + 'mouth_bottom', + 'right_mouth_4', + 'right_mouth_5', + 'right_lip_1', + 'right_lip_2', + 'lip_top', + 'left_lip_2', + 'left_lip_1', + 'left_lip_3', + 'lip_bottom', + 'right_lip_3', + # Face contour + 'right_contour_1', + 'right_contour_2', + 'right_contour_3', + 'right_contour_4', + 'right_contour_5', + 'right_contour_6', + 'right_contour_7', + 'right_contour_8', + 'contour_middle', + 'left_contour_8', + 'left_contour_7', + 'left_contour_6', + 'left_contour_5', + 'left_contour_4', + 'left_contour_3', + 'left_contour_2', + 'left_contour_1', +] + + +SMPLH_JOINT_NAMES = [ + 'pelvis', + 'left_hip', + 'right_hip', + 'spine1', + 'left_knee', + 'right_knee', + 'spine2', + 'left_ankle', + 'right_ankle', + 'spine3', + 'left_foot', + 'right_foot', + 'neck', + 'left_collar', + 'right_collar', + 'head', + 'left_shoulder', + 'right_shoulder', + 'left_elbow', + 'right_elbow', + 'left_wrist', + 'right_wrist', + 'left_index1', + 'left_index2', + 'left_index3', + 'left_middle1', + 'left_middle2', + 'left_middle3', + 'left_pinky1', + 'left_pinky2', + 'left_pinky3', + 'left_ring1', + 'left_ring2', + 'left_ring3', + 'left_thumb1', + 'left_thumb2', + 'left_thumb3', + 'right_index1', + 'right_index2', + 'right_index3', + 'right_middle1', + 'right_middle2', + 'right_middle3', + 'right_pinky1', + 'right_pinky2', + 'right_pinky3', + 'right_ring1', + 'right_ring2', + 'right_ring3', + 'right_thumb1', + 'right_thumb2', + 'right_thumb3', + 'nose', + 'right_eye', + 'left_eye', + 'right_ear', + 'left_ear', + 'left_big_toe', + 'left_small_toe', + 'left_heel', + 'right_big_toe', + 'right_small_toe', + 'right_heel', + 'left_thumb', + 'left_index', + 'left_middle', + 'left_ring', + 'left_pinky', + 'right_thumb', + 'right_index', + 'right_middle', + 'right_ring', + 'right_pinky', +] diff --git a/TADA/smplx/lbs.py b/TADA/smplx/lbs.py new file mode 100644 index 0000000000000000000000000000000000000000..c42fe1bc007fb3163e3a223f73d62d6046e29638 --- /dev/null +++ b/TADA/smplx/lbs.py @@ -0,0 +1,408 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +from typing import Tuple, List +import numpy as np + +import torch +import torch.nn.functional as F + +from .utils import rot_mat_to_euler, Tensor + + +def find_dynamic_lmk_idx_and_bcoords( + vertices: Tensor, + pose: Tensor, + dynamic_lmk_faces_idx: Tensor, + dynamic_lmk_b_coords: Tensor, + neck_kin_chain: List[int], + pose2rot: bool = True, +) -> Tuple[Tensor, Tensor]: + ''' Compute the faces, barycentric coordinates for the dynamic landmarks + + + To do so, we first compute the rotation of the neck around the y-axis + and then use a pre-computed look-up table to find the faces and the + barycentric coordinates that will be used. 
+ + Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de) + for providing the original TensorFlow implementation and for the LUT. + + Parameters + ---------- + vertices: torch.tensor BxVx3, dtype = torch.float32 + The tensor of input vertices + pose: torch.tensor Bx(Jx3), dtype = torch.float32 + The current pose of the body model + dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long + The look-up table from neck rotation to faces + dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32 + The look-up table from neck rotation to barycentric coordinates + neck_kin_chain: list + A python list that contains the indices of the joints that form the + kinematic chain of the neck. + dtype: torch.dtype, optional + + Returns + ------- + dyn_lmk_faces_idx: torch.tensor, dtype = torch.long + A tensor of size BxL that contains the indices of the faces that + will be used to compute the current dynamic landmarks. + dyn_lmk_b_coords: torch.tensor, dtype = torch.float32 + A tensor of size BxL that contains the indices of the faces that + will be used to compute the current dynamic landmarks. + ''' + + dtype = vertices.dtype + batch_size = vertices.shape[0] + + if pose2rot: + aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, + neck_kin_chain) + rot_mats = batch_rodrigues( + aa_pose.view(-1, 3)).view(batch_size, -1, 3, 3) + else: + rot_mats = torch.index_select( + pose.view(batch_size, -1, 3, 3), 1, neck_kin_chain) + + rel_rot_mat = torch.eye( + 3, device=vertices.device, dtype=dtype).unsqueeze_(dim=0).repeat( + batch_size, 1, 1) + for idx in range(len(neck_kin_chain)): + rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) + + y_rot_angle = torch.round( + torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, + max=39)).to(dtype=torch.long) + neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) + mask = y_rot_angle.lt(-39).to(dtype=torch.long) + neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) + y_rot_angle = (neg_mask * neg_vals + + (1 - neg_mask) * y_rot_angle) + + dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, + 0, y_rot_angle) + dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, + 0, y_rot_angle) + + return dyn_lmk_faces_idx, dyn_lmk_b_coords + + +def vertices2landmarks( + vertices: Tensor, + faces: Tensor, + lmk_faces_idx: Tensor, + lmk_bary_coords: Tensor +) -> Tensor: + ''' Calculates landmarks by barycentric interpolation + + Parameters + ---------- + vertices: torch.tensor BxVx3, dtype = torch.float32 + The tensor of input vertices + faces: torch.tensor Fx3, dtype = torch.long + The faces of the mesh + lmk_faces_idx: torch.tensor L, dtype = torch.long + The tensor with the indices of the faces used to calculate the + landmarks. 
+ lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32 + The tensor of barycentric coordinates that are used to interpolate + the landmarks + + Returns + ------- + landmarks: torch.tensor BxLx3, dtype = torch.float32 + The coordinates of the landmarks for each mesh in the batch + ''' + # Extract the indices of the vertices for each face + # BxLx3 + batch_size, num_verts = vertices.shape[:2] + device = vertices.device + + lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( + batch_size, -1, 3) + + lmk_faces += torch.arange( + batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts + + lmk_vertices = vertices.view(-1, 3)[lmk_faces].view( + batch_size, -1, 3, 3) + + landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) + return landmarks + + +def lbs( + betas: Tensor, + pose: Tensor, + v_template: Tensor, + shapedirs: Tensor, + posedirs: Tensor, + J_regressor: Tensor, + parents: Tensor, + lbs_weights: Tensor, + pose2rot: bool = True, + custom_out: bool = False, +): + ''' Performs Linear Blend Skinning with the given shape and pose parameters + + Parameters + ---------- + betas : torch.tensor BxNB + The tensor of shape parameters + pose : torch.tensor Bx(J + 1) * 3 + The pose parameters in axis-angle format + v_template torch.tensor BxVx3 + The template mesh that will be deformed + shapedirs : torch.tensor 1xNB + The tensor of PCA shape displacements + posedirs : torch.tensor Px(V * 3) + The pose PCA coefficients + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from + the position of the vertices + parents: torch.tensor J + The array that describes the kinematic tree for the model + lbs_weights: torch.tensor N x V x (J + 1) + The linear blend skinning weights that represent how much the + rotation matrix of each part affects each vertex + pose2rot: bool, optional + Flag on whether to convert the input pose tensor to rotation + matrices. The default value is True. If False, then the pose tensor + should already contain rotation matrices and have a size of + Bx(J + 1)x9 + dtype: torch.dtype, optional + + custom_out: return A T if true + + Returns + ------- + verts: torch.tensor BxVx3 + The vertices of the mesh after applying the shape and pose + displacements. + joints: torch.tensor BxJx3 + The joints of the model + + ''' + + batch_size = max(betas.shape[0], pose.shape[0]) + device, dtype = betas.device, betas.dtype + + # Add shape contribution + v_shaped = v_template + blend_shapes(betas, shapedirs) + + # Get the joints + # NxJx3 array + J = vertices2joints(J_regressor, v_shaped) + + # 3. Add pose blend shapes + # N x J x 3 x 3 + ident = torch.eye(3, dtype=dtype, device=device) + if pose2rot: + rot_mats = batch_rodrigues(pose.view(-1, 3)).view( + [batch_size, -1, 3, 3]) + + pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) + # (N x P) x (P, V * 3) -> N x V x 3 + pose_offsets = torch.matmul( + pose_feature, posedirs).view(batch_size, -1, 3) + else: + pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident + rot_mats = pose.view(batch_size, -1, 3, 3) + + pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), + posedirs).view(batch_size, -1, 3) + + v_posed = pose_offsets + v_shaped + # 4. Get the global joint location + J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) + + # 5. 
Do skinning: + # W is N x V x (J + 1) + W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) + # (N x V x (J + 1)) x (N x (J + 1) x 16) + num_joints = J_regressor.shape[0] + T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \ + .view(batch_size, -1, 4, 4) + + homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], + dtype=dtype, device=device) + v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) + v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) + + verts = v_homo[:, :, :3, 0] + + if custom_out: + return verts, J_transformed, T, A, v_shaped, v_posed + + return verts, J_transformed, v_shaped, v_posed + + +def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor: + ''' Calculates the 3D joint locations from the vertices + + Parameters + ---------- + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from the + position of the vertices + vertices : torch.tensor BxVx3 + The tensor of mesh vertices + + Returns + ------- + torch.tensor BxJx3 + The location of the joints + ''' + + return torch.einsum('bik,ji->bjk', [vertices, J_regressor]) + + +def blend_shapes(betas: Tensor, shape_disps: Tensor) -> Tensor: + ''' Calculates the per vertex displacement due to the blend shapes + + + Parameters + ---------- + betas : torch.tensor Bx(num_betas) + Blend shape coefficients + shape_disps: torch.tensor Vx3x(num_betas) + Blend shapes + + Returns + ------- + torch.tensor BxVx3 + The per-vertex displacement due to shape deformation + ''' + + # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] + # i.e. Multiply each shape displacement by its corresponding beta and + # then sum them. + blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) + return blend_shape + + +def batch_rodrigues( + rot_vecs: Tensor, + epsilon: float = 1e-8, +) -> Tensor: + ''' Calculates the rotation matrices for a batch of rotation vectors + Parameters + ---------- + rot_vecs: torch.tensor Nx3 + array of N axis-angle vectors + Returns + ------- + R: torch.tensor Nx3x3 + The rotation matrices for the given axis-angle parameters + ''' + + batch_size = rot_vecs.shape[0] + device, dtype = rot_vecs.device, rot_vecs.dtype + + angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) + rot_dir = rot_vecs / angle + + cos = torch.unsqueeze(torch.cos(angle), dim=1) + sin = torch.unsqueeze(torch.sin(angle), dim=1) + + # Bx1 arrays + rx, ry, rz = torch.split(rot_dir, 1, dim=1) + K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) + + zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((batch_size, 3, 3)) + + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) + return rot_mat + + +def transform_mat(R: Tensor, t: Tensor) -> Tensor: + ''' Creates a batch of transformation matrices + Args: + - R: Bx3x3 array of a batch of rotation matrices + - t: Bx3x1 array of a batch of translation vectors + Returns: + - T: Bx4x4 Transformation matrix + ''' + # No padding left or right, only add an extra row + return torch.cat([F.pad(R, [0, 0, 0, 1]), + F.pad(t, [0, 0, 0, 1], value=1)], dim=2) + + +def batch_rigid_transform( + rot_mats: Tensor, + joints: Tensor, + parents: Tensor, + dtype=torch.float32 +) -> Tensor: + """ + Applies a batch of rigid transformations to the joints + + Parameters + ---------- + rot_mats : torch.tensor BxNx3x3 + Tensor of rotation matrices + 
joints : torch.tensor BxNx3 + Locations of joints + parents : torch.tensor BxN + The kinematic tree of each object + dtype : torch.dtype, optional: + The data type of the created tensors, the default is torch.float32 + + Returns + ------- + posed_joints : torch.tensor BxNx3 + The locations of the joints after applying the pose rotations + rel_transforms : torch.tensor BxNx4x4 + The relative (with respect to the root joint) rigid transformations + for all the joints + """ + + joints = torch.unsqueeze(joints, dim=-1) + + rel_joints = joints.clone() + rel_joints[:, 1:] -= joints[:, parents[1:]] + + transforms_mat = transform_mat( + rot_mats.reshape(-1, 3, 3), + rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4) + + transform_chain = [transforms_mat[:, 0]] + for i in range(1, parents.shape[0]): + # Subtract the joint location at the rest pose + # No need for rotation, since it's identity when at rest + curr_res = torch.matmul(transform_chain[parents[i]], + transforms_mat[:, i]) + transform_chain.append(curr_res) + + transforms = torch.stack(transform_chain, dim=1) + + # The last column of the transformations contains the posed joints + posed_joints = transforms[:, :, :3, 3] + + joints_homogen = F.pad(joints, [0, 0, 0, 1]) + + rel_transforms = transforms - F.pad( + torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) + + return posed_joints, rel_transforms diff --git a/TADA/smplx/utils.py b/TADA/smplx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3387e1713d55e78d345c952302ef42965d045912 --- /dev/null +++ b/TADA/smplx/utils.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from typing import NewType, Union, Optional +from dataclasses import dataclass, asdict, fields +import numpy as np +import torch + +Tensor = NewType('Tensor', torch.Tensor) +Array = NewType('Array', np.ndarray) + + +@dataclass +class ModelOutput: + vertices: Optional[Tensor] = None + joints: Optional[Tensor] = None + full_pose: Optional[Tensor] = None + global_orient: Optional[Tensor] = None + transl: Optional[Tensor] = None + v_shaped: Optional[Tensor] = None + + def __getitem__(self, key): + return getattr(self, key) + + def get(self, key, default=None): + return getattr(self, key, default) + + def __iter__(self): + return self.keys() + + def keys(self): + keys = [t.name for t in fields(self)] + return iter(keys) + + def values(self): + values = [getattr(self, t.name) for t in fields(self)] + return iter(values) + + def items(self): + data = [(t.name, getattr(self, t.name)) for t in fields(self)] + return iter(data) + + +@dataclass +class SMPLOutput(ModelOutput): + betas: Optional[Tensor] = None + body_pose: Optional[Tensor] = None + + +@dataclass +class SMPLHOutput(SMPLOutput): + left_hand_pose: Optional[Tensor] = None + right_hand_pose: Optional[Tensor] = None + transl: Optional[Tensor] = None + + +@dataclass +class SMPLXOutput(SMPLHOutput): + expression: Optional[Tensor] = None + jaw_pose: Optional[Tensor] = None + joints_transform: Optional[Tensor] = None + v_posed: Optional[Tensor] = None + + +@dataclass +class MANOOutput(ModelOutput): + betas: Optional[Tensor] = None + hand_pose: Optional[Tensor] = None + + +@dataclass +class FLAMEOutput(ModelOutput): + betas: Optional[Tensor] = None + expression: Optional[Tensor] = None + jaw_pose: Optional[Tensor] = None + neck_pose: Optional[Tensor] = None + + +def find_joint_kin_chain(joint_id, kinematic_tree): + kin_chain = [] + curr_idx = joint_id + while curr_idx != -1: + kin_chain.append(curr_idx) + curr_idx = kinematic_tree[curr_idx] + return kin_chain + + +def to_tensor( + array: Union[Array, Tensor], dtype=torch.float32 +) -> Tensor: + if torch.is_tensor(array): + return array + else: + return torch.tensor(array, dtype=dtype) + + +class Struct(object): + def __init__(self, **kwargs): + for key, val in kwargs.items(): + setattr(self, key, val) + + +def to_np(array, dtype=np.float32): + if 'scipy.sparse' in str(type(array)): + array = array.todense() + return np.array(array, dtype=dtype) + + +def rot_mat_to_euler(rot_mats): + # Calculates rotation matrix to euler angles + # Careful for extreme cases of eular angles like [0.0, pi, 0.0] + + sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + + rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) + return torch.atan2(-rot_mats[:, 2, 0], sy) diff --git a/TADA/smplx/vertex_ids.py b/TADA/smplx/vertex_ids.py new file mode 100644 index 0000000000000000000000000000000000000000..0e7a4c36700f002da54a9e181eabbd47af2a95bc --- /dev/null +++ b/TADA/smplx/vertex_ids.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). 
acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from __future__ import print_function +from __future__ import absolute_import +from __future__ import division + +# Joint name to vertex mapping. SMPL/SMPL-H/SMPL-X vertices that correspond to +# MSCOCO and OpenPose joints +vertex_ids = { + 'smplh': { + 'nose': 332, + 'reye': 6260, + 'leye': 2800, + 'rear': 4071, + 'lear': 583, + 'rthumb': 6191, + 'rindex': 5782, + 'rmiddle': 5905, + 'rring': 6016, + 'rpinky': 6133, + 'lthumb': 2746, + 'lindex': 2319, + 'lmiddle': 2445, + 'lring': 2556, + 'lpinky': 2673, + 'LBigToe': 3216, + 'LSmallToe': 3226, + 'LHeel': 3387, + 'RBigToe': 6617, + 'RSmallToe': 6624, + 'RHeel': 6787 + }, + 'smplx': { + 'nose': 9120, + 'reye': 9929, + 'leye': 9448, + 'rear': 616, + 'lear': 6, + 'rthumb': 8079, + 'rindex': 7669, + 'rmiddle': 7794, + 'rring': 7905, + 'rpinky': 8022, + 'lthumb': 5361, + 'lindex': 4933, + 'lmiddle': 5058, + 'lring': 5169, + 'lpinky': 5286, + 'LBigToe': 5770, + 'LSmallToe': 5780, + 'LHeel': 8846, + 'RBigToe': 8463, + 'RSmallToe': 8474, + 'RHeel': 8635 + }, + 'mano': { + 'thumb': 744, + 'index': 320, + 'middle': 443, + 'ring': 554, + 'pinky': 671, + } +} diff --git a/TADA/smplx/vertex_joint_selector.py b/TADA/smplx/vertex_joint_selector.py new file mode 100644 index 0000000000000000000000000000000000000000..4b8298bd5e087731f86c1c699703b5219e046c5c --- /dev/null +++ b/TADA/smplx/vertex_joint_selector.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import numpy as np + +import torch +import torch.nn as nn + +from .utils import to_tensor + + +class VertexJointSelector(nn.Module): + + def __init__(self, vertex_ids=None, + use_hands=True, + use_feet_keypoints=True, **kwargs): + super(VertexJointSelector, self).__init__() + + extra_joints_idxs = [] + + face_keyp_idxs = np.array([ + vertex_ids['nose'], + vertex_ids['reye'], + vertex_ids['leye'], + vertex_ids['rear'], + vertex_ids['lear']], dtype=np.int64) + + extra_joints_idxs = np.concatenate([extra_joints_idxs, + face_keyp_idxs]) + + if use_feet_keypoints: + feet_keyp_idxs = np.array([vertex_ids['LBigToe'], + vertex_ids['LSmallToe'], + vertex_ids['LHeel'], + vertex_ids['RBigToe'], + vertex_ids['RSmallToe'], + vertex_ids['RHeel']], dtype=np.int32) + + extra_joints_idxs = np.concatenate( + [extra_joints_idxs, feet_keyp_idxs]) + + if use_hands: + self.tip_names = ['thumb', 'index', 'middle', 'ring', 'pinky'] + + tips_idxs = [] + for hand_id in ['l', 'r']: + for tip_name in self.tip_names: + tips_idxs.append(vertex_ids[hand_id + tip_name]) + + extra_joints_idxs = np.concatenate( + [extra_joints_idxs, tips_idxs]) + + self.register_buffer('extra_joints_idxs', + to_tensor(extra_joints_idxs, dtype=torch.long)) + + def forward(self, vertices, joints): + extra_joints = torch.index_select(vertices, 1, self.extra_joints_idxs) + joints = torch.cat([joints, extra_joints], dim=1) + + return joints diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac73feb32af51996db999fe4044d6db400cb3c1 --- /dev/null +++ b/app.py @@ -0,0 +1,182 @@ +import os, sys +import gradio as gr +from huggingface_hub import snapshot_download +css = """ +.dfile {height: 85px} +.ov {height: 185px} +""" + + +from huggingface_hub import snapshot_download +from motion.visual_api import Visualize +import moviepy.editor as mpy +import torch +import json + +with open("motion/path.json", "r") as f: + json_dict = json.load(f) + +def ref_video_fn(path_of_ref_video): + if path_of_ref_video is not None: + return gr.update(value=True) + else: + return gr.update(value=False) + +def prepare(): + if not os.path.exists("body_models") or not os.path.exists("weights"): + REPO_ID = 'Kleinhe/CAMD' + snapshot_download(repo_id=REPO_ID, local_dir='./', local_dir_use_symlinks=False) + + if not os.path.exists("tada-extend"): + import subprocess + import platform + command = "bash scripts/tada_google.sh" + subprocess.call(command, shell=platform.system() != 'Windows') + +def demo(prompt, mode, condition, render_mode="joints", skip_steps=0, out_size=1024, tada_role=None): + prompt = prompt + if prompt is None: + prompt = "" + + path = None + out_paths = [None, None, None] + joints_paths = [None, None, None] + smpl_paths = [None, None, None] + + if tada_role == "None": + tada_role = None + + for i in range(len(mode)): + kargs = { + "mode":mode[i], + "device":"cuda" if torch.cuda.is_available() else "cpu", + "condition":condition, + "smpl_path":json_dict["smpl_path"], + "skip_steps":skip_steps, + "path":json_dict, + "tada_base":json_dict["tada_base"], + "tada_role":tada_role + } + visual = Visualize(**kargs) + render_mode = render_mode + + joint_path = "results/joints/{}_joint.npy".format(mode[i]) + smpl_path = "results/joints/{}_smpl.npy".format(mode[i]) + + output = visual.predict(prompt, path, render_mode, joint_path, smpl_path) + + if render_mode == 
"joints": + pics = visual.joints_process(output, prompt) + elif render_mode == "pyrender": + meshes, _ = visual.get_mesh(output) + pics = visual.pyrender_process(meshes, out_size, out_size) + + out_path = "results/motion/temp{}.mp4".format(i) + vid = mpy.ImageSequenceClip([x[:, :, :] for x in pics], fps=20) + vid.write_videofile(out_path, remove_temp=True) + + if mode[i] == "cadm": + out_paths[0] = out_path + joints_paths[0] = joint_path + smpl_paths[0] = smpl_path + elif mode[i] == "cadm-augment": + out_paths[1] = out_path + joints_paths[1] = joint_path + smpl_paths[1] = smpl_path + elif mode[i] == "mdm": + out_paths[2] = out_path + joints_paths[2] = joint_path + smpl_paths[2] = smpl_path + + return out_paths + joints_paths + smpl_paths + + +def t2m_demo(): + prepare() + os.makedirs("results/motion", exist_ok=True) + os.makedirs("results/joints", exist_ok=True) + os.makedirs("results/smpls", exist_ok=True) + + tada_base = json_dict["tada_base"] + files = os.listdir(os.path.join(tada_base, "MESH")) + files = sorted(files) + if files[0].startswith("."): + files.pop(0) + files = ["None"] + files + + with gr.Blocks(analytics_enabled=False, css=css) as t2m_interface: + gr.Markdown("

🤷‍♂️ SemanticBoost: Elevating Motion Generation with Augmented Textual Cues

\ + Arxiv       \ + Homepage       \ + Github
") + + with gr.Row().style(equal_height=True): + with gr.Column(variant='panel'): + with gr.Tabs(): + with gr.TabItem('Settings'): + with gr.Column(variant='panel'): + with gr.Row(): + demo_mode = gr.CheckboxGroup(choices=['cadm', 'cadm-augment','mdm'], default=["cadm"], label='Mode', info="Choose models to run demos, more models cost more time.") + skip_steps = gr.Number(value=0, label="Skip-Steps", info="The number of skip-steps during diffusion process (0 -> 999)", minimum=0, maximum=999, precision=0) + + with gr.Row(): + condition = gr.Radio(['text', 'uncond'], value='text', label='Condition', info="If sythesize motion with prompt?") + out_size = gr.Number(value=1024, label="Resolution", info="The resolution of output videos", minimum=224, maximum=2048, precision=0) + + with gr.Row(): + render_mode = gr.Radio(['joints','pyrender'], value='joints', label='Render', info="If render results to 3D meshes? Pyrender need more time.") + tada_role = gr.Dropdown(files, value="None", multiselect=False, label="TADA Role", info="Choose 3D role to render") + + with gr.Row(): + prompt = gr.Textbox(value=None, placeholder="120,A person walks forward and does a handstand.", label="Prompt for Model -> (Length,Text)") + + submit = gr.Button('Visualize', variant='primary') + + with gr.Column(variant='panel'): + with gr.Tabs(): + with gr.TabItem('Results'): + with gr.Row(): + with gr.Column(): + gen_video = gr.Video(label="CADM", format="mp4", autoplay=True, elem_classes="ov") + with gr.Column(): + joint_file = gr.File(label="CADM-Joints", value=None, elem_classes="dfile") + smpl_file = gr.File(label="CADM-SMPL", value=None, elem_classes="dfile") + + with gr.Row(): + with gr.Column(): + gen_video1 = gr.Video(label="CADM-Augment", format="mp4", autoplay=True, elem_classes="ov") + with gr.Column(): + joint_file1 = gr.File(label="CADM-Augment-Joints", value=None, elem_classes="dfile") + smpl_file1 = gr.File(label="CADM-Augment-SMPL", value=None, elem_classes="dfile") + + with gr.Row(): + with gr.Column(): + gen_video2 = gr.Video(label="MDM", format="mp4", autoplay=True, elem_classes="ov") + with gr.Column(): + joint_file2 = gr.File(label="MDM-Joints", value=None, elem_classes="dfile") + smpl_file2 = gr.File(label="MDM-SMPL", value=None, elem_classes="dfile") + + + submit.click( + fn=demo, + inputs=[prompt, + demo_mode, + condition, + render_mode, + skip_steps, + out_size, + tada_role + ], + outputs=[gen_video, gen_video1, gen_video2, joint_file, joint_file1, joint_file2, smpl_file, smpl_file1, smpl_file2] + ) + + return t2m_interface + + +if __name__ == "__main__": + demo = t2m_demo() + demo.queue(max_size=10) + demo.launch(debug=True) + + + diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..0dfdd2c80bbcd1425e4ba80157bc26c2b56c979f --- /dev/null +++ b/inference.py @@ -0,0 +1,86 @@ +import torch +from motion.visual_api import Visualize +import moviepy.editor as mpy +import os, sys +import time +import json +import imageio +import argparse + +def interface(prompt, mode="cadm", render_mode="pyrender", out_size=1024, tada_role=None): + os.makedirs("results/motion", exist_ok=True) + os.makedirs("results/joints", exist_ok=True) + os.makedirs("results/smpls", exist_ok=True) + + name = prompt.replace("/", "_").replace(" ", "_").replace(",", "_").replace("#", "_").replace("|", "_").replace(".npy", "").replace(".txt", "").replace(".csv", "").replace(".", "").replace("'", "_") + name = "_".join(name.split("_")[:25]) + out_path = os.path.join("results/motion", 
name + ".mp4") + gif_path = os.path.join("results/motion", name + ".gif") + joint_path = os.path.join("results/jo ints", name + ".npy") + smpl_path = os.path.join("results/smpls", name + ".npy") + + ''' + prompt 输入为 length, prompt, 如果只输入 prompt, length 默认为 196 + mode 指不同的模型 + ''' + + assert mode in ["cadm", "cadm-augment", "mdm"] + assert render_mode in ["joints", "pyrender_fast", "pyrender_slow"] + path = None + + with open("motion/path.json", "r") as f: + json_dict = json.load(f) + + t1 = time.time() + + kargs = { + "mode":mode, + "device":"cuda" if torch.cuda.is_available() else "cpu", + "rotate":0, + "condition":"text", + "smpl_path":json_dict["smpl_path"], + "skip_steps":0, + "path":json_dict, + "tada_base":json_dict["tada_base"], + "tada_role":tada_role + } + visual = Visualize(**kargs) + + t2 = time.time() + + output = visual.predict(prompt, path, render_mode, joint_path, smpl_path) + + t3 = time.time() + + if render_mode == "joints": + pics = visual.joints_process(output, prompt, out_size, out_size) + elif render_mode.startswith("pyrender"): + meshes, _ = visual.get_mesh(output) + pics = visual.pyrender_process(meshes, out_size, out_size) + + vid = mpy.ImageSequenceClip([x[:, :, :] for x in pics], fps=20) + vid.write_videofile(out_path, remove_temp=True) + imageio.mimsave(gif_path, pics, duration= 1000 / 20, loop=0) + + t4 = time.time() + + cost_init = t2 - t1 + cost_infer = t3 - t2 + cost_render = t4 - t3 + + print("initial model cost time: %.4f, infer and fit cost time: %.4f, render cost time: %.4f, total cost time: %.4f"%(cost_init, cost_infer, cost_render, t4 - t1)) + + return out_path + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='visualize demo') + ############################ basic_setings ######################## + parser.add_argument('--prompt', type=str, default="120, A person walks forward and does a handstand.") + parser.add_argument('--mode', type=str, default="cadm", choices=['cadm', 'cadm-augment', "mdm"], help="choose model") + parser.add_argument("--render_mode", default="pyrender_slow", type=str, choices=["pyrender_slow", "pyrender_fast", "joints"]) + parser.add_argument("--size", default=1024, type=int) + parser.add_argument("--tada_role", default=None, type=str) + opt = parser.parse_args() + + + out_path = interface(opt.prompt, mode=opt.mode, render_mode=opt.render_mode, out_size=opt.size, tada_role=opt.tada_role) \ No newline at end of file diff --git a/motion/__pycache__/double_take.cpython-39.pyc b/motion/__pycache__/double_take.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74a3f4269c7347ecda4d6f4acf97bc66cbe0c09d Binary files /dev/null and b/motion/__pycache__/double_take.cpython-39.pyc differ diff --git a/motion/__pycache__/hybrik_loc2rot.cpython-310.pyc b/motion/__pycache__/hybrik_loc2rot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8833cd812a2d6a21848dcd1f603a45cb1ac82b49 Binary files /dev/null and b/motion/__pycache__/hybrik_loc2rot.cpython-310.pyc differ diff --git a/motion/__pycache__/hybrik_loc2rot.cpython-39.pyc b/motion/__pycache__/hybrik_loc2rot.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99b178406388738842f044603e24c8bfb690bc69 Binary files /dev/null and b/motion/__pycache__/hybrik_loc2rot.cpython-39.pyc differ diff --git a/motion/__pycache__/model_util.cpython-310.pyc b/motion/__pycache__/model_util.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..0121914e251cf04c25fe0ba6d31238db63d652ae Binary files /dev/null and b/motion/__pycache__/model_util.cpython-310.pyc differ diff --git a/motion/__pycache__/model_util.cpython-311.pyc b/motion/__pycache__/model_util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7acd60cb60258f6294b75141bec151821f09df90 Binary files /dev/null and b/motion/__pycache__/model_util.cpython-311.pyc differ diff --git a/motion/__pycache__/model_util.cpython-39.pyc b/motion/__pycache__/model_util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7905f7810af265f050892d51cde7b146120ac400 Binary files /dev/null and b/motion/__pycache__/model_util.cpython-39.pyc differ diff --git a/motion/__pycache__/plot3d.cpython-310.pyc b/motion/__pycache__/plot3d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1152942bf39044be1b9537497087063529fb9501 Binary files /dev/null and b/motion/__pycache__/plot3d.cpython-310.pyc differ diff --git a/motion/__pycache__/plot3d.cpython-311.pyc b/motion/__pycache__/plot3d.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cb5e31cdee586dbe94ad4ea42eb96212fba3a46 Binary files /dev/null and b/motion/__pycache__/plot3d.cpython-311.pyc differ diff --git a/motion/__pycache__/plot3d.cpython-39.pyc b/motion/__pycache__/plot3d.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da07950e4af9d076d7672b4cf250b72bb0905dc8 Binary files /dev/null and b/motion/__pycache__/plot3d.cpython-39.pyc differ diff --git a/motion/__pycache__/sample.cpython-310.pyc b/motion/__pycache__/sample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8283022d0cb310e2c06b11f4a4c0eaf4d680876 Binary files /dev/null and b/motion/__pycache__/sample.cpython-310.pyc differ diff --git a/motion/__pycache__/sample.cpython-311.pyc b/motion/__pycache__/sample.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab2ea84b3d114fe00469ac954a83e704f7a92fbb Binary files /dev/null and b/motion/__pycache__/sample.cpython-311.pyc differ diff --git a/motion/__pycache__/sample.cpython-39.pyc b/motion/__pycache__/sample.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34541907e555a426a7e767107f0862a8e42a61e3 Binary files /dev/null and b/motion/__pycache__/sample.cpython-39.pyc differ diff --git a/motion/__pycache__/visual_api.cpython-310.pyc b/motion/__pycache__/visual_api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e0b7889694796105e253510fc0e691c1738a219 Binary files /dev/null and b/motion/__pycache__/visual_api.cpython-310.pyc differ diff --git a/motion/__pycache__/visual_api.cpython-311.pyc b/motion/__pycache__/visual_api.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..223bda3a58b15ea35b1e3e81948ecc10ef24008f Binary files /dev/null and b/motion/__pycache__/visual_api.cpython-311.pyc differ diff --git a/motion/__pycache__/visual_api.cpython-39.pyc b/motion/__pycache__/visual_api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96a2d745055a2a16f192c9f730ae5757cb365f74 Binary files /dev/null and b/motion/__pycache__/visual_api.cpython-39.pyc differ diff --git a/motion/dataset/Mean.npy b/motion/dataset/Mean.npy new file mode 100644 index 0000000000000000000000000000000000000000..35a4c9b8b10070212013885b979d9abb4978da41 --- 
/dev/null +++ b/motion/dataset/Mean.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4718db19ef1a19033f94becdce3b351fbb3ed4a9bfc7a8653c7cf4d0c8fa9696 +size 1180 diff --git a/motion/dataset/Mean_smr.npy b/motion/dataset/Mean_smr.npy new file mode 100644 index 0000000000000000000000000000000000000000..478c54a5544030aef608b6d07331e04f1dde647a --- /dev/null +++ b/motion/dataset/Mean_smr.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f6b3d28ff33a650d68508c6ff46004810d76df298044339b5847e25de689c43a +size 2280 diff --git a/motion/dataset/Std.npy b/motion/dataset/Std.npy new file mode 100644 index 0000000000000000000000000000000000000000..6c7a179afdd4759ee623f7f817d2b1e890199fb2 --- /dev/null +++ b/motion/dataset/Std.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:43e126128c6fa871b55a6b0af75cd722bd715ee164dc5e2f4c5e000b0bd1ffc8 +size 1180 diff --git a/motion/dataset/Std_smr.npy b/motion/dataset/Std_smr.npy new file mode 100644 index 0000000000000000000000000000000000000000..a7586bbae540ed3405e552f13f2885b25d028ee0 --- /dev/null +++ b/motion/dataset/Std_smr.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f6643cdd1b819bb67e9d2537f170e58fce1f36a20ae8c4c0ae716874a9089a0 +size 2280 diff --git a/motion/dataset/__pycache__/paramUtil.cpython-310.pyc b/motion/dataset/__pycache__/paramUtil.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8ec5eb188b31165af937aad83952135b700ce83 Binary files /dev/null and b/motion/dataset/__pycache__/paramUtil.cpython-310.pyc differ diff --git a/motion/dataset/__pycache__/paramUtil.cpython-311.pyc b/motion/dataset/__pycache__/paramUtil.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1791b08ebd7c175aab0f49cc937a23e2a3135235 Binary files /dev/null and b/motion/dataset/__pycache__/paramUtil.cpython-311.pyc differ diff --git a/motion/dataset/__pycache__/paramUtil.cpython-39.pyc b/motion/dataset/__pycache__/paramUtil.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f29125764d89363d6ae3181ef2f2ecd75c68e00 Binary files /dev/null and b/motion/dataset/__pycache__/paramUtil.cpython-39.pyc differ diff --git a/motion/dataset/__pycache__/recover_joints.cpython-310.pyc b/motion/dataset/__pycache__/recover_joints.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b17202138e763b07d9a75a6f2f499b1c90dd0809 Binary files /dev/null and b/motion/dataset/__pycache__/recover_joints.cpython-310.pyc differ diff --git a/motion/dataset/__pycache__/recover_joints.cpython-311.pyc b/motion/dataset/__pycache__/recover_joints.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bcf353021b293c2692fa6e953020da0ca1be2d8 Binary files /dev/null and b/motion/dataset/__pycache__/recover_joints.cpython-311.pyc differ diff --git a/motion/dataset/__pycache__/recover_joints.cpython-39.pyc b/motion/dataset/__pycache__/recover_joints.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cbe143b6027b4426dfbb8f1acbcf1fadfa3db88 Binary files /dev/null and b/motion/dataset/__pycache__/recover_joints.cpython-39.pyc differ diff --git a/motion/dataset/__pycache__/recover_smr.cpython-310.pyc b/motion/dataset/__pycache__/recover_smr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e5d1d6496678a124a5b201aa48e31e1c073aaaa Binary files /dev/null and 
b/motion/dataset/__pycache__/recover_smr.cpython-310.pyc differ diff --git a/motion/dataset/__pycache__/recover_smr.cpython-311.pyc b/motion/dataset/__pycache__/recover_smr.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07750129e029eaa95e8bd34ac763daf1bf1301c4 Binary files /dev/null and b/motion/dataset/__pycache__/recover_smr.cpython-311.pyc differ diff --git a/motion/dataset/__pycache__/recover_smr.cpython-39.pyc b/motion/dataset/__pycache__/recover_smr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07951273e6593d29f087ff44ebe8de1d20a29985 Binary files /dev/null and b/motion/dataset/__pycache__/recover_smr.cpython-39.pyc differ diff --git a/motion/dataset/paramUtil.py b/motion/dataset/paramUtil.py new file mode 100644 index 0000000000000000000000000000000000000000..a9f1708b85ca80a9051cb3675cec9b999a0d0e2b --- /dev/null +++ b/motion/dataset/paramUtil.py @@ -0,0 +1,63 @@ +import numpy as np + +# Define a kinematic tree for the skeletal struture +kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] + +kit_raw_offsets = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1] + ] +) + +t2m_raw_offsets = np.array([[0,0,0], + [1,0,0], + [-1,0,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,0,1], + [0,0,1], + [0,1,0], + [1,0,0], + [-1,0,0], + [0,0,1], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0]]) + +t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] +t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] +t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] + + +kit_tgt_skel_id = '03950' + +t2m_tgt_skel_id = '000021' + diff --git a/motion/dataset/recover_joints.py b/motion/dataset/recover_joints.py new file mode 100644 index 0000000000000000000000000000000000000000..f9a72833efa15c70119a2da7282178cb0b7c9aa2 --- /dev/null +++ b/motion/dataset/recover_joints.py @@ -0,0 +1,188 @@ +import torch +# Recover global angle and positions for rotation data +# root_rot_velocity (B, seq_len, 1) +# root_linear_velocity (B, seq_len, 2) +# root_y (B, seq_len, 1) +# ric_data (B, seq_len, (joint_num - 1)*3) +# rot_data (B, seq_len, (joint_num - 1)*6) +# local_velocity (B, seq_len, joint_num*3) +# foot contact (B, seq_len, 4) +import numpy as np +from SMPLX.rotation_conversions import * + +def qinv(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + mask = torch.ones_like(q) + mask[..., 1:] = -mask[..., 1:] + return q * mask + +def qrot(q, v): + """ + Rotate vector(s) v about the rotation described by quaternion(s) q. + Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, + where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). 
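+ + A minimal sanity check (illustrative values only; quaternions are expected real part first, i.e. (w, x, y, z), as the code below assumes): the identity quaternion leaves the vector unchanged. + >>> q = torch.tensor([[1.0, 0.0, 0.0, 0.0]]) + >>> v = torch.tensor([[1.0, 2.0, 3.0]]) + >>> qrot(q, v) + tensor([[1., 2., 3.]])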
+ """ + assert q.shape[-1] == 4 + assert v.shape[-1] == 3 + assert q.shape[:-1] == v.shape[:-1] + + original_shape = list(v.shape) + # print(q.shape) + q = q.contiguous().view(-1, 4) + v = v.contiguous().view(-1, 3) + + qvec = q[:, 1:] + uv = torch.cross(qvec, v, dim=1) + uuv = torch.cross(qvec, uv, dim=1) + return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) + + +def recover_root_rot_pos(data): + rot_vel = data[..., 0] + r_rot_ang = torch.zeros_like(rot_vel).to(data.device) + '''Get Y-axis rotation from rotation velocity''' + r_rot_ang[..., 1:] = rot_vel[..., :-1] + r_rot_ang = torch.cumsum(r_rot_ang, dim=-1) + + r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device) + r_rot_quat[..., 0] = torch.cos(r_rot_ang) + r_rot_quat[..., 2] = torch.sin(r_rot_ang) + + r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device) + r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3] + '''Add Y-axis rotation to root position''' + r_pos = qrot(qinv(r_rot_quat), r_pos) + + r_pos = torch.cumsum(r_pos, dim=-2) + + r_pos[..., 1] = data[..., 3] + return r_rot_quat, r_pos + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + +def quaternion_to_cont6d(quaternions): + rotation_mat = quaternion_to_matrix(quaternions) + cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1) + return cont_6d + +def recover_from_rot(data, joints_num, skeleton): + r_rot_quat, r_pos = recover_root_rot_pos(data) + + r_rot_cont6d = quaternion_to_cont6d(r_rot_quat) + + start_indx = 1 + 2 + 1 + (joints_num - 1) * 3 + end_indx = start_indx + (joints_num - 1) * 6 + cont6d_params = data[..., start_indx:end_indx] + # print(r_rot_cont6d.shape, cont6d_params.shape, r_pos.shape) + cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1) + cont6d_params = cont6d_params.view(-1, joints_num, 6) + + positions = skeleton.forward_kinematics_cont6d(cont6d_params, r_pos) + + return positions + + +def recover_from_ric(data, joints_num): + if isinstance(data, np.ndarray): + data = torch.from_numpy(data).float() + dtype = "numpy" + else: + data = data.float() + dtype = "tensor" + + r_rot_quat, r_pos = recover_root_rot_pos(data) + positions = data[..., 4:(joints_num - 1) * 3 + 4] + positions = positions.view(positions.shape[:-1] + (-1, 3)) + + '''Add Y-axis rotation to local joints''' + positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions) + + '''Add root XZ to joints''' + positions[..., 0] += r_pos[..., 0:1] + positions[..., 2] += r_pos[..., 2:3] + + '''Concate root and joints''' + positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2) + + if dtype == "numpy": + positions = positions.numpy() + + return positions + +def t2m_to_eval_rep(data, joint_num=22): + bs, nframes, length = data.shape + if isinstance(data, np.ndarray): + data = torch.from_numpy(data).float() + elif isinstance(data, torch.Tensor): + data = 
data.float() + joints = recover_from_ric(data, joint_num) + translation = joints[:, :, 0, :] - joints[:, 0:1, 0, :] ### [bs, nframes, 3] + + joints -= translation.unsqueeze(2) + joints = torch.cat([translation.unsqueeze(2), joints], dim=2) #### [bs, nframes, 23, 3] + data = joints.reshape(bs, nframes, -1).cpu().numpy() + return data + +def recover_pose_from_t2m(data, njoints=22): + joints = recover_from_ric(data, njoints) + trans = joints[:, 0, :] - joints[0:1, 0, :] + + pose = data[:, 4 + (njoints - 1) * 3:4 + (njoints - 1) * 9] + pose = pose.reshape(pose.shape[0], njoints-1, 6) + ptype = type(pose) + if ptype == np.ndarray: + pose = torch.from_numpy(pose).float() + pose = rotation_6d_to_matrix(pose) + pose = matrix_to_axis_angle(pose) + pose = pose.numpy() + root_vel = np.zeros([pose.shape[0], 1, 3]) + pose = np.concatenate([root_vel, pose], axis=1) + elif ptype == torch.Tensor: + pose = rotation_6d_to_matrix(pose) + pose = matrix_to_axis_angle(pose) + root_vel = torch.zeros([pose.shape[0], 1, 3]) + pose = torch.cat([root_vel, pose], dim=1) + + pose = pose.reshape(pose.shape[0], -1) + + if njoints < 24: + if ptype == np.ndarray: + addition = np.zeros([pose.shape[0], 72-njoints*3]) + pose = np.concatenate([pose, addition], axis=1) + elif ptype == torch.Tensor: + addition = torch.zeros([pose.shape[0], 72-njoints*3], dtype=pose.dtype, device=pose.device) + pose = torch.cat([pose, addition], dim=1) + + if ptype == np.ndarray: + pose = np.concatenate([pose, trans], axis=1) + elif ptype == torch.Tensor: + pose = torch.cat([pose, trans], dim=1) + + return pose diff --git a/motion/dataset/recover_smr.py b/motion/dataset/recover_smr.py new file mode 100644 index 0000000000000000000000000000000000000000..0efba1294345f4f827581fd8ac9e310b5f7c4265 --- /dev/null +++ b/motion/dataset/recover_smr.py @@ -0,0 +1,111 @@ +import numpy as np +import torch +from SMPLX.rotation_conversions import rotation_6d_to_matrix, matrix_to_axis_angle + +def qinv(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + mask = torch.ones_like(q) + mask[..., 1:] = -mask[..., 1:] + return q * mask + +def qrot(q, v): + """ + Rotate vector(s) v about the rotation described by quaternion(s) q. + Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, + where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). 
+ """ + assert q.shape[-1] == 4 + assert v.shape[-1] == 3 + assert q.shape[:-1] == v.shape[:-1] + + original_shape = list(v.shape) + # print(q.shape) + q = q.contiguous().view(-1, 4) + v = v.contiguous().view(-1, 3) + + qvec = q[:, 1:] + uv = torch.cross(qvec, v, dim=1) + uuv = torch.cross(qvec, uv, dim=1) + return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) + + +def recover_root_rot_pos(data): + rot_vel = data[..., 0] + r_rot_ang = torch.zeros_like(rot_vel).to(data.device) + '''Get Y-axis rotation from rotation velocity''' + r_rot_ang[..., 1:] = rot_vel[..., :-1] + r_rot_ang = torch.cumsum(r_rot_ang, dim=-1) + + r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device) + r_rot_quat[..., 0] = torch.cos(r_rot_ang) + r_rot_quat[..., 2] = torch.sin(r_rot_ang) + + r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device) + r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3] + '''Add Y-axis rotation to root position''' + r_pos = qrot(qinv(r_rot_quat), r_pos) + + r_pos = torch.cumsum(r_pos, dim=-2) + + r_pos[..., 1] = data[..., 3] + return r_rot_quat, r_pos + +def recover_from_ric(data, joints_num): + if isinstance(data, np.ndarray): + data = torch.from_numpy(data).float() + dtype = "numpy" + else: + data = data.float() + dtype = "tensor" + r_rot_quat, r_pos = recover_root_rot_pos(data) + positions = data[..., 4:(joints_num - 1) * 3 + 4] + positions = positions.view(positions.shape[:-1] + (-1, 3)) + + '''Add Y-axis rotation to local joints''' + positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions) + + '''Add root XZ to joints''' + positions[..., 0] += r_pos[..., 0:1] + positions[..., 2] += r_pos[..., 2:3] + + '''Concate root and joints''' + positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2) + + if dtype == "numpy": + positions = positions.numpy() + + return positions + +def recover_pose_from_smr(data, njoints=22): + joints = recover_from_ric(data, njoints) + trans = joints[:, 0, :] - joints[0:1, 0, :] + + pose = data[:, 4 + (njoints - 1) * 3:10 + (njoints - 1) * 9] + pose = pose.reshape(pose.shape[0], njoints, 6) + ptype = type(pose) + if ptype == np.ndarray: + pose = torch.from_numpy(pose).float() + pose = rotation_6d_to_matrix(pose) + pose = matrix_to_axis_angle(pose) + pose = pose.numpy() + elif ptype == torch.Tensor: + pose = rotation_6d_to_matrix(pose) + pose = matrix_to_axis_angle(pose) + + pose = pose.reshape(pose.shape[0], -1) + + if njoints < 24: + if ptype == np.ndarray: + addition = np.zeros([pose.shape[0], 72-njoints*3]) + pose = np.concatenate([pose, addition], axis=1) + elif ptype == torch.Tensor: + addition = torch.zeros([pose.shape[0], 72-njoints*3], dtype=pose.dtype, device=pose.device) + pose = torch.cat([pose, addition], dim=1) + + if ptype == np.ndarray: + pose = np.concatenate([pose, trans], axis=1) + elif ptype == torch.Tensor: + pose = torch.cat([pose, trans], dim=1) + + return pose \ No newline at end of file diff --git a/motion/dataset/smplh.faces b/motion/dataset/smplh.faces new file mode 100644 index 0000000000000000000000000000000000000000..20c84e393f281fa5d98769a90a33a89cd00d8fea Binary files /dev/null and b/motion/dataset/smplh.faces differ diff --git a/motion/diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc b/motion/diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a05f7f72db9bac013eb93dccb86ca2fa05ba89f Binary files /dev/null and b/motion/diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc differ diff --git 
a/motion/diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc b/motion/diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08674ac5bfe10e56c085b73fbfbbc4b83c9e57a9 Binary files /dev/null and b/motion/diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc differ diff --git a/motion/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc b/motion/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c770839be796be6f27523c760ef37285c6b56bcb Binary files /dev/null and b/motion/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc differ diff --git a/motion/diffusion/__pycache__/nn.cpython-310.pyc b/motion/diffusion/__pycache__/nn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..065bc095ca1923ae4c40102c71d8116ce76a4054 Binary files /dev/null and b/motion/diffusion/__pycache__/nn.cpython-310.pyc differ diff --git a/motion/diffusion/__pycache__/nn.cpython-311.pyc b/motion/diffusion/__pycache__/nn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a8affc0a3c1a5c4270a6afa89e0a68aa72109d9 Binary files /dev/null and b/motion/diffusion/__pycache__/nn.cpython-311.pyc differ diff --git a/motion/diffusion/__pycache__/nn.cpython-39.pyc b/motion/diffusion/__pycache__/nn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8668bfd4434ae4d5ae380cf4fe56091f9f460679 Binary files /dev/null and b/motion/diffusion/__pycache__/nn.cpython-39.pyc differ diff --git a/motion/diffusion/__pycache__/respace.cpython-310.pyc b/motion/diffusion/__pycache__/respace.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fad272e7c67fd3bc6617df5414d0fe5419837bc3 Binary files /dev/null and b/motion/diffusion/__pycache__/respace.cpython-310.pyc differ diff --git a/motion/diffusion/__pycache__/respace.cpython-311.pyc b/motion/diffusion/__pycache__/respace.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1534e129b85c1b5d698255db5c18f2062679d5b Binary files /dev/null and b/motion/diffusion/__pycache__/respace.cpython-311.pyc differ diff --git a/motion/diffusion/__pycache__/respace.cpython-39.pyc b/motion/diffusion/__pycache__/respace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..640ae943096c0bd036bd68eff4663629e5d716a8 Binary files /dev/null and b/motion/diffusion/__pycache__/respace.cpython-39.pyc differ diff --git a/motion/diffusion/gaussian_diffusion.py b/motion/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..a83230409a01ab61e1ad417aab7a037364569666 --- /dev/null +++ b/motion/diffusion/gaussian_diffusion.py @@ -0,0 +1,651 @@ +# This code is based on https://github.com/openai/guided-diffusion +""" +This code started out as a PyTorch port of Ho et al's diffusion models: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py + +Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. 
+""" + +import enum +import math + +import numpy as np +import torch +import torch as th +from copy import deepcopy +from motion.diffusion.nn import sum_flat +from motion.dataset.recover_smr import * +from SMPLX.rotation_conversions import rotation_6d_to_matrix, matrix_to_axis_angle +# os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, scale_betas=1.): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = scale_betas * 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, ### t=0->1, t=1->0, t=2->1, t=3->0, 近似于 0,1 交替输入 + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. 
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param loss_type: a LossType determining the loss function to use. + :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, + rep="t2m" + ): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + self.rep = rep + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) #### cumulative product of alphas gives alpha_bar + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) ### alpha_bar shifted to t-1: drop the last entry and prepend 1.0 for t=0 + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) #### alpha_bar shifted to t+1: drop the first entry and append 0.0 at the end + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( ###### variance of the posterior q(x_{t-1} | x_t, x_0) + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) + ) + + self.l2_loss = lambda a, b: (a - b) ** 2 # th.nn.MSELoss(reduction='none') # reduction must stay 'none' so the mask can be applied later on. + + def masked_l2(self, a, b, mask, addition_rotate_mask): + loss = self.l2_loss(a, b) #### [bs, 263, 1, num_frames] + loss = sum_flat(loss * mask.float() * addition_rotate_mask.float()) # gives \sigma_euclidean over unmasked elements ### [Batch] + + n_entries = a.shape[1] * a.shape[2] ##### BS * 263 * 1 * num_frame -> 263 + non_zero_elements = sum_flat(mask) * n_entries + mse_loss_val = loss / non_zero_elements + return mse_loss_val + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
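+ + Concretely, following the code below, mean = sqrt(alpha_bar_t) * x_start and variance = 1 - alpha_bar_t, both broadcast to x_start's shape.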
+ """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + ) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None, model_kwargs=None): + """ + Diffuse the dataset for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial dataset batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + + return ( ######### 前向传播 xt = self.sqrt_alphas_cumprod[t] * x0 + self.sqrt_one_minus_alphas_cumprod[t] * \epsilon + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None + ): + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + + model_output = model_output["output"] + + x_t = x + + if 'inpainting_mask' in model_kwargs['y'].keys() and 'inpainted_motion' in model_kwargs['y'].keys(): + inpainting_mask, inpainted_motion = model_kwargs['y']['inpainting_mask'], model_kwargs['y']['inpainted_motion'] + assert self.model_mean_type == ModelMeanType.START_X, 'This feature supports only X_start pred for mow!' + assert model_output.shape == inpainting_mask.shape == inpainted_motion.shape + + ones = torch.ones_like(inpainting_mask, dtype=torch.float, device=inpainting_mask.device) + inpainting_mask = ones * inpainting_mask + model_output = (model_output * (1 - inpainting_mask)) + (inpainted_motion * inpainting_mask) + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x.shape + ) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. 
+ frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( ############ USE IT + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + + model_variance = _extract_into_tensor(model_variance, t, x_t.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x_t.shape) + + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + # print('clip_denoised', clip_denoised) + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x_t, t=t, xprev=model_output) + ) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: # THIS IS US! + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x_t, t=t, eps=model_output)) + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x_t, t=t) + else: + raise NotImplementedError(self.model_mean_type) + + assert (model_mean.shape == model_log_variance.shape == pred_xstart.shape == x_t.shape) + + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev + - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape + ) + * x_t + ) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * (1000.0 / self.num_timesteps) + return t + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + const_noise=False, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. 
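+ + Illustrative usage (a sketch only; `diffusion`, `x_t`, `t` and `model_kwargs` are assumed names for a constructed GaussianDiffusion instance, the current noisy batch, a timestep tensor and prepared conditioning): + >>> out = diffusion.p_sample(model, x_t, t, model_kwargs=model_kwargs) + >>> x_prev, x0_hat = out["sample"], out["pred_xstart"]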
+ """ + out = self.p_mean_variance( + model, + x, #### x 列表 + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + + noise = th.randn_like(out["mean"]) + if const_noise: + noise = noise[[0]].repeat(out["mean"].shape[0], 1, 1, 1) + nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(out["mean"].shape) - 1)))) # no noise when t == 0 + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise ## \mu + nonzero_mask * \std * noise + + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + dump_steps=None, + const_noise=False, + unfolding_handshake=0, # 0 means no unfolding + eval_mask=None + + ): + """ + Generate samples from the model. + + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :param const_noise: If True, will noise all samples with the same noise throughout sampling + :return: a non-differentiable batch of samples. 
+ """ + final = None + if dump_steps is not None: + dump = [] + + for i, sample in enumerate(self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + skip_timesteps=skip_timesteps, + init_image=init_image, + randomize_class=randomize_class, + cond_fn_with_grad=cond_fn_with_grad, + const_noise=const_noise, + eval_mask=eval_mask + )): + # unfolding + if unfolding_handshake > 0: + ''' + first take 点这里 + ''' + alpha = torch.arange(0, unfolding_handshake, 1, device=sample['sample'].device) / unfolding_handshake + for sample_i, len in zip(range(1, sample['sample'].shape[0]), model_kwargs['y']['lengths']): + _suffix = sample['sample'][sample_i - 1, :, :, -unfolding_handshake + len:len] + _prefix = sample['sample'][sample_i, :, :, :unfolding_handshake] + try: + _blend = (_suffix * (1 - alpha) + _prefix * alpha) + except(RuntimeError): + print("Error") + sample['sample'][sample_i - 1, :, :, -unfolding_handshake + len:len] = _blend #### 混合操作,保证下一帧的 left = 这一帧的 right, 这样 double take 的时候才能直接用 right 覆盖 left + sample['sample'][sample_i, :, :, :unfolding_handshake] = _blend + + if dump_steps is not None and i in dump_steps: + dump.append(deepcopy(sample["sample"])) + final = sample + if dump_steps is not None: + return dump + + res = {"output":final["sample"]} + return res + + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + const_noise=False, + eval_mask=None + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + + if skip_timesteps and init_image is None: + init_image = th.zeros_like(img) + + indices = list(range(self.num_timesteps - skip_timesteps))[::-1] #### [999, 998, ... 0] + + if init_image is not None: + my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] + img = self.q_sample(init_image, my_t, img, model_kwargs=model_kwargs) + ''' + 把 eval_mask 放在这里相当于初始化时若干帧的结果存在问题 + 如果把 eval_mask 放在循环中, 就相当于推理过程中指定位置一直在生成不同的错误帧 + ''' + if eval_mask is not None and img.shape[0] != 1: + rand_img = torch.randperm(img.shape[0]) + rand_img = img[rand_img] + img = img * (1 - eval_mask) + rand_img * eval_mask + elif eval_mask is not None and img.shape[0] == 1: + rand_img = th.randn(*shape, device=device) + img = img * (1 - eval_mask) + rand_img * eval_mask + + if progress: + # Lazy import so that we don't depend on tqdm. 
+            from tqdm.auto import tqdm
+
+            indices = tqdm(indices)
+
+        for i in indices:
+            t = th.tensor([i] * shape[0], device=device)  # e.g. t = [999, 999, ...]
+            if randomize_class and 'y' in model_kwargs:
+                model_kwargs['y'] = th.randint(low=0, high=model.num_classes,
+                                               size=model_kwargs['y'].shape,
+                                               device=model_kwargs['y'].device)
+            with th.no_grad():
+                sample_fn = self.p_sample
+                condition = deepcopy(model_kwargs)
+                out = sample_fn(
+                    model,
+                    img,
+                    t,
+                    clip_denoised=clip_denoised,
+                    denoised_fn=denoised_fn,
+                    cond_fn=cond_fn,
+                    model_kwargs=condition,
+                    const_noise=const_noise,
+                )
+
+                yield out
+                img = out["sample"]  # starts from pure noise; each iteration yields the sample for t=999, then t=998, ...; the last step yields the predicted x0
+
+    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
+        """
+        Compute training losses for a single timestep.
+
+        :param model: the model to evaluate loss on.
+        :param x_start: the [N x C x ...] tensor of inputs (the generation target x0).
+        :param t: a batch of timestep indices.
+        :param model_kwargs: if not None, a dict of extra keyword arguments to
+            pass to the model. This can be used for conditioning.
+        :param noise: if specified, the specific Gaussian noise to try to remove.
+        :return: a dict with the key "loss" containing a tensor of shape [N].
+                 Some mean or variance settings may also have other keys.
+        """
+
+        mask = model_kwargs['y']['mask']
+
+        if len(x_start.shape) == 3:
+            x_start = x_start.permute(0, 2, 1).unsqueeze(2)
+        elif len(x_start.shape) == 4:
+            x_start = x_start.permute(0, 2, 3, 1)
+
+        if self.rep == "smplx":
+            addition_rotate_mask = torch.ones_like(x_start)
+            # addition_rotate_mask = mask.repeat(1, x_start.shape[1], x_start.shape[2], 1)  # [bs, njoints, nfeats, nframes]
+            # speed = x_start[..., 1::] - x_start[..., :-1]  # [bs, njoints, nfeats, nframes-1]
+            # speed = speed.sum(dim=-1).sum(dim=-1)  # [bs, njoints]
+            # nosub = speed == 0  # find joints that do not change between frames and exclude them from the loss
+            # addition_rotate_mask[nosub] = 0
+        else:
+            addition_rotate_mask = torch.ones_like(x_start)
+
+        if noise is None:
+            noise = th.randn_like(x_start)
+        x_t = self.q_sample(x_start, t, noise=noise, model_kwargs=model_kwargs)  # forward (noising) process: diffuse x0 to x_t
+
+        terms = {}
+
+        if self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:  # MSE loss is used by default
+
+            model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)  # mixup_res
+
+            model_output = model_output["output"]  # [bs, 263, 1, nframes] -> [nframes, bs, 512] -> [bs, 263, 1, nframes]
+
+            if self.model_mean_type == ModelMeanType.START_X:
+                target = x_start
+            elif self.model_mean_type == ModelMeanType.EPSILON:
+                target = noise
+            elif self.model_mean_type == ModelMeanType.PREVIOUS_X:
+                target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0]
+
+            assert model_output.shape == target.shape == x_start.shape
+
+            terms["rot_mse"] = self.masked_l2(target, model_output, mask, addition_rotate_mask=addition_rotate_mask)
+
+            terms["loss"] = terms["rot_mse"]
+        else:
+            raise NotImplementedError(self.loss_type)
+
+        return terms
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+    """
+    Extract values from a 1-D numpy array for a batch of indices.
+
+    :param arr: the 1-D numpy array.
+    :param timesteps: a tensor of indices into the array to extract.
+    :param broadcast_shape: a larger shape of K dimensions with the batch
+                            dimension equal to the length of timesteps.
+    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
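+
+    Example (illustrative sketch, mirroring how this helper is used inside
+    q_sample elsewhere in this file):
+
+        coef = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
+        # coef carries one per-timestep scalar per batch element, broadcast over
+        # all non-batch dimensions, so `coef * x_start` scales each sample by
+        # its own alpha-bar value.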
+ """ + + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) diff --git a/motion/diffusion/nn.py b/motion/diffusion/nn.py new file mode 100644 index 0000000000000000000000000000000000000000..41c18e7dd3d8cae1e719638e87c27f718f6a94e6 --- /dev/null +++ b/motion/diffusion/nn.py @@ -0,0 +1,197 @@ +# This code is based on https://github.com/openai/guided-diffusion +""" +Various utilities for neural networks. +""" + +import math + +import torch as th +import torch.nn as nn + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * th.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + +def sum_flat(tensor): + """ + Take the sum over all non-batch dimensions. + """ + return tensor.sum(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
+ """ + half = dim // 2 + freqs = th.exp( + -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) + if dim % 2: + embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(th.autograd.Function): + @staticmethod + @th.cuda.amp.custom_fwd + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_length = length + ctx.save_for_backward(*args) + with th.no_grad(): + output_tensors = ctx.run_function(*args[:length]) + return output_tensors + + @staticmethod + @th.cuda.amp.custom_bwd + def backward(ctx, *output_grads): + args = list(ctx.saved_tensors) + + # Filter for inputs that require grad. If none, exit early. + input_indices = [i for (i, x) in enumerate(args) if x.requires_grad] + if not input_indices: + return (None, None) + tuple(None for _ in args) + + with th.enable_grad(): + for i in input_indices: + if i < ctx.input_length: + # Not sure why the OAI code does this little + # dance. It might not be necessary. + args[i] = args[i].detach().requires_grad_() + args[i] = args[i].view_as(args[i]) + output_tensors = ctx.run_function(*args[:ctx.input_length]) + + if isinstance(output_tensors, th.Tensor): + output_tensors = [output_tensors] + + # Filter for outputs that require grad. If none, exit early. + out_and_grads = [(o, g) for (o, g) in zip(output_tensors, output_grads) if o.requires_grad] + if not out_and_grads: + return (None, None) + tuple(None for _ in args) + + # Compute gradients on the filtered tensors. + computed_grads = th.autograd.grad( + [o for (o, g) in out_and_grads], + [args[i] for i in input_indices], + [g for (o, g) in out_and_grads] + ) + + # Reassemble the complete gradient tuple. + input_grads = [None for _ in args] + for (i, g) in zip(input_indices, computed_grads): + input_grads[i] = g + return (None, None) + tuple(input_grads) diff --git a/motion/diffusion/resample.py b/motion/diffusion/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..1be123e22f0ca90b61c1bdbe13ef14a8199f9045 --- /dev/null +++ b/motion/diffusion/resample.py @@ -0,0 +1,155 @@ +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. 
+ """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size): + """ + Importance-sample timesteps for a batch. + + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long() + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float() + + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + + Sub-classes should override this method to update the reweighting + using losses from the model. + + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. 
+ """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/motion/diffusion/respace.py b/motion/diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..277895b5de81fdbfbf8c517f1e148d079a3dc871 --- /dev/null +++ b/motion/diffusion/respace.py @@ -0,0 +1,183 @@ +# This code is based on https://github.com/openai/guided-diffusion +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion +from .gaussian_diffusion import _extract_into_tensor + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. 
+ """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.rescale_timesteps, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) + + +class InpaintingGaussianDiffusion(SpacedDiffusion): + def q_sample(self, x_start, t, noise=None, model_kwargs=None): + """ + overrides q_sample to use the inpainting mask + + same usage as in GaussianDiffusion + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + + bs, feat, _, frames = noise.shape + noise *= 1. 
- model_kwargs['y']['inpainting_mask']
+
+        return (
+            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+            * noise
+        )
+
+    def p_sample(
+        self,
+        model,
+        x,
+        t,
+        clip_denoised=True,
+        denoised_fn=None,
+        cond_fn=None,
+        model_kwargs=None,
+        const_noise=False,
+    ):
+        """
+        overrides p_sample to use the inpainting mask
+
+        same usage as in GaussianDiffusion
+        """
+        out = self.p_mean_variance(
+            model,
+            x,
+            t,
+            clip_denoised=clip_denoised,
+            denoised_fn=denoised_fn,
+            model_kwargs=model_kwargs,
+        )
+        noise = th.randn_like(x)
+        if const_noise:
+            noise = noise[[0]].repeat(x.shape[0], 1, 1, 1)
+        noise *= 1. - model_kwargs['y']['inpainting_mask']
+
+        nonzero_mask = (
+            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+        )  # no noise when t == 0
+        if cond_fn is not None:
+            out["mean"] = self.condition_mean(
+                cond_fn, out, x, t, model_kwargs=model_kwargs
+            )
+        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+        return {"sample": sample, "pred_xstart": out["pred_xstart"]}
\ No newline at end of file
diff --git a/motion/double_take.py b/motion/double_take.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba9a21245a48863ace30bb412a2a96b34ce9ed5e
--- /dev/null
+++ b/motion/double_take.py
@@ -0,0 +1,270 @@
+from copy import deepcopy
+import torch
+import pandas as pd
+import numpy as np
+
+def pad_sample_with_zeros(sample, max_len=250):
+    # pad the input up to max_len with zeros (the padding acts as transition frames) and adjust lengths accordingly
+    seq_len, n_feats = sample.shape
+    len_to_pad = max_len - seq_len
+    sample_padding = np.zeros((len_to_pad, n_feats))
+    sample = np.concatenate((sample, sample_padding))
+    return sample
+
+def split2subs(motions, step_sizes, batch_size, blend_len, max_motion_length):
+    #### motions [1, 263, 1, Nlength] -> [263, Nlength] -> [NLength, 263]
+    new_motions = []
+    new_lengths = []
+    new_motions.append(pad_sample_with_zeros(motions[..., :step_sizes[0] - blend_len].squeeze().permute(1, 0).cpu().numpy(), max_motion_length))
+    new_lengths.append(step_sizes[0] - blend_len)
+    for i in range(1, batch_size-1):
+        curr = pad_sample_with_zeros(motions[..., step_sizes[i-1]-blend_len:step_sizes[i]-blend_len].squeeze().permute(1, 0).cpu().numpy(), max_motion_length)
+        new_motions.append(curr)
+        new_lengths.append(step_sizes[i] - step_sizes[i-1])
+
+    new_motions.append(pad_sample_with_zeros(motions[..., step_sizes[-1]-blend_len:].squeeze().permute(1, 0).cpu().numpy(), max_motion_length))
+    new_lengths.append(step_sizes[-1]-step_sizes[-2]+blend_len)
+
+    new_motions = np.stack(new_motions, axis=0)
+    new_motions = torch.from_numpy(new_motions)
+    new_lengths = np.stack(new_lengths, axis=0)
+    new_lengths = torch.from_numpy(new_lengths).long()
+    return new_motions, new_lengths
+
+
+def unfold_sample_arb_len(sample, handshake_size, step_sizes, final_n_frames, model_kwargs):
+    old_sample = deepcopy(sample)
+    new_shape = list(old_sample.shape)
+    new_shape[0] = 1
+    new_shape[-1] = final_n_frames
+    sample = torch.zeros(new_shape, dtype=sample.dtype, device=sample.device)
+    sample[0, :, :, :model_kwargs['y']['lengths'][0]] = old_sample[0, :, :, :model_kwargs['y']['lengths'][0]]
+    for sample_i, len_i in enumerate(step_sizes):
+        if sample_i == 0:
+            continue
+        start = step_sizes[sample_i-1]
+        sample[0, :, :, start:len_i] = old_sample[sample_i, :, :, handshake_size:model_kwargs['y']['lengths'][sample_i]]
+    return sample
+
+
+def double_take_arb_len(diffusion, model,
                         model_kwargs, n_frames, blend_len=10, handshake_size=20, device="cpu", progress=True):
+    sample_fn = diffusion.p_sample_loop
+    blend_len = blend_len
+    handshake_size = handshake_size
+
+    batch_size = len(model_kwargs['y']['text'])
+
+    # Unfolding - orig
+    sample = sample_fn(
+        model,
+        (batch_size, model.njoints, model.nfeats, n_frames),
+        clip_denoised=False,
+        model_kwargs=model_kwargs,
+        skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
+        init_image=None,
+        progress=progress,
+        dump_steps=None,
+        noise=None,
+        const_noise=False,
+        unfolding_handshake=handshake_size,
+    )
+
+    model_kwargs['y']['scale'] = torch.ones(batch_size-1, device=device) * 0
+    sample = sample["output"]  # [bs, 263, 1, 196]
+
+    # Next, for the double take:
+    #   1. rebuild `sample` into overlapping pairs
+    #   2. update model_kwargs['y']['lengths'] accordingly
+
+    new_sample_seq_len = (sample.shape[-1] - 2 * handshake_size) * 2 + handshake_size
+
+    bs, feats, joints, seq_len = sample.shape
+    new_sample = torch.zeros((bs-1, feats, joints, new_sample_seq_len), dtype=sample.dtype, device=sample.device)
+
+    generated_motion = []
+    right_constraint = []
+    left_constraint = []
+
+    for ii in range(bs):  # split each motion into left handshake / middle / right handshake
+        generated_motion.append(deepcopy(sample[ii, :, :, handshake_size: model_kwargs['y']['lengths'][ii]-handshake_size]))  # w/o start and end
+        left_constraint.append(deepcopy(sample[ii, :, :, :handshake_size]))  # left side
+        right_constraint.append(deepcopy(sample[ii, :, :, model_kwargs['y']['lengths'][ii] - handshake_size: model_kwargs['y']['lengths'][ii]]))
+
+    buffer = []  # lengths of the remaining (middle) part of each motion, i.e. the length of each generated_motion entry
+    for ii in range(bs):
+        buffer.append(int(model_kwargs['y']['lengths'][ii]) - 2*handshake_size)
+    for ii in range(bs - 1):  # run over bs: merge N prompts into N-1 pairs; each new motion is [generated_motion[ii], right_constraint[ii], generated_motion[ii+1]], of length 2 * middle_length + handshake_size
+        new_sample[ii, :, :, :buffer[ii]] = generated_motion[ii]
+        new_sample[ii, :, :, buffer[ii]: buffer[ii]+handshake_size] = right_constraint[ii]  # add transition
+        new_sample[ii, :, :, buffer[ii]+handshake_size : buffer[ii]+handshake_size+buffer[ii+1]] = generated_motion[ii + 1]
+
+    # "in between"
+    model_kwargs['y']['inpainted_motion'] = new_sample
+    model_kwargs['y']['inpainting_mask'] = torch.ones_like(new_sample, dtype=torch.float,
+                                                           device=new_sample.device)
+
+    for ii in range(bs - 1):  # run over bs
+        if blend_len >= 2:
+            # Gradual blending:
+            #   1. over the last blend_len frames of the left motion, keep the original content with a decaying weight
+            #   2. over the first blend_len frames of the right motion, keep the original content with a growing weight
+            #   3. the handshake itself is kept entirely, i.e. the end of the previous motion serves as the start of the next one
+            model_kwargs['y']['inpainting_mask'][ii, :, :, buffer[ii] - blend_len: buffer[ii]] = \
+                torch.arange(0.85, 0.0, -0.85 / int(blend_len))
+            model_kwargs['y']['inpainting_mask'][ii, :, :, buffer[ii] + handshake_size: buffer[ii] + handshake_size + blend_len] = \
+                torch.arange(0.0, 0.85, 0.85 / int(blend_len))
+
+    model_kwargs['y']['uncond'] = 1.0  # after merging several prompts the text condition is no longer meaningful, and only a small region needs to be generated
+    model_kwargs['y']['text'] = model_kwargs['y']['text'][:bs-1]
+    sample_fn = diffusion.p_sample_loop  # double take sample function
+    n_frames = new_sample_seq_len
+    orig_lens = deepcopy(model_kwargs['y']['lengths'])
+    for ii in range(len(model_kwargs['y']['lengths'])-1):
+        model_kwargs['y']['lengths'][ii] = model_kwargs['y']['lengths'][ii] + model_kwargs['y']['lengths'][ii+1] - 3*handshake_size
+    model_kwargs['y']['lengths'] = model_kwargs['y']['lengths'][:-1]
+
+    double_take_sample = sample_fn(
+        model,
+        (batch_size-1, model.njoints, model.nfeats, n_frames),
+        clip_denoised=False,
+        model_kwargs=model_kwargs,
+        skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
+        init_image=new_sample,  # TODO!! check if plausible or not!
+        progress=progress,
+        dump_steps=None,
+        noise=None,
+        const_noise=False,
+    )
+    double_take_sample = double_take_sample["output"]
+    model_kwargs['y']['lengths'] = orig_lens
+    # rebuild_orig:
+    rebuild_sample = torch.zeros_like(sample)
+
+    # sample            -> left + motion + right
+    # double_take_sample -> motion1 + blend + handshake + blend + motion2,
+    #                      where (in length) motion1 + blend = motion2 + blend = motion
+
+    transitions, right_side, left_side = [], [], []
+    for ii in range(bs - 1):  # run over bs
+        transitions.append(double_take_sample[ii, :, :, buffer[ii]: buffer[ii]+handshake_size])
+        right_side.append(double_take_sample[ii, :, :, buffer[ii] + handshake_size: buffer[ii] + handshake_size + blend_len])  # M1 blending..
+        left_side.append(double_take_sample[ii, :, :, buffer[ii] - blend_len:buffer[ii]])  # M0 blending...
+
+    # transitions stores the handshake segments
+    # right_side stores the blend region to the right of each handshake
+    # left_side stores the blend region to the left of each handshake
+
+    rebuild_sample[0, :, :, :handshake_size] = left_constraint[0]  # Fill missing
+    rebuild_sample[-1, :, :, buffer[-1]+handshake_size: buffer[-1]+2*handshake_size] = right_constraint[-1]  # Fill missing
+
+    # unfold the double-take result back into the original per-segment layout, i.e. left + motion + right
+
+    for ii in range(bs - 1):
+        rebuild_sample[ii + 1, :, :, :handshake_size] = transitions[ii]
+        rebuild_sample[ii, :, :, handshake_size: buffer[ii]+handshake_size] = generated_motion[ii]
+        rebuild_sample[ii, :, :, buffer[ii]+handshake_size: buffer[ii]+2*handshake_size] = transitions[ii]  # motion1's right handshake = motion2's left handshake
+        rebuild_sample[ii, :, :, handshake_size + buffer[ii]-blend_len: handshake_size + buffer[ii]] = left_side[ii]
+        # if ii > 0:
+    rebuild_sample[-1, :, :, handshake_size: buffer[-1] + handshake_size] = generated_motion[-1]
+    for ii in range(bs - 1):
+        rebuild_sample[ii+1, :, :, handshake_size:handshake_size + blend_len] = right_side[ii]
+
+    double_take_sample = deepcopy(rebuild_sample)
+
+    return double_take_sample
+
+def double_take(prompt=None, path=None, num_repetitions=1, model=None, diffusion=None, handshake_size=20, blend_len=10, default_length=196, guidance_param=2.5, device="cpu", progress=True):
+    assert model is not None
+    assert diffusion is not None
+    if prompt is not None:
+        texts = prompt.split("|")
+        num_samples = len(texts)
+        length = []
+        captions = []
+        for i in range(len(texts)):
+            nframes = texts[i].split(",")[0]
+            try:
+                nframes = int(nframes)
+                curr_text = texts[i].split(",")[1::]
+                curr_text = ",".join(curr_text)
+            except ValueError:
+                nframes = default_length
+                curr_text = texts[i]
+
+            captions.append(curr_text)
+            length.append(nframes)
+
+        model_kwargs = {'y': {
+            'mask': torch.ones((len(texts), 1, 1, default_length)),  # 196 is humanml max frames number
+            'lengths': torch.tensor(length),
+            'text': captions,
+            'tokens': [''],
+            'scale': torch.ones(len(texts))*guidance_param
+        }}
+    elif path.split(".")[-1] == "csv":
+        df = pd.read_csv(path)
+        num_samples = len(list(df['text']))
+        model_kwargs = {'y': {
+            'mask': torch.ones((len(list(df['text'])), 1, 1, default_length)),  # 196 is humanml max frames number
+            'lengths': torch.tensor(list(df['length'])),
+            'text': list(df['text']),
+            'tokens': [''],
+            'scale': torch.ones(len(list(df['text'])))*guidance_param
+        }}
+    elif path.split(".")[-1] == "txt":
+        with open(path, 'r') as fr:
+            texts = fr.readlines()
+        texts = [s.replace('\n', '') for s in texts]
+        num_samples = len(texts)
+        model_kwargs = {'y': {
+            'mask': torch.ones((len(texts), 1, 1, default_length)),  # 196 is humanml max
frames number + 'lengths': torch.tensor([default_length]*len(texts)), + 'text': texts, + 'tokens': [''], + 'scale': torch.ones(len(texts))*guidance_param + }} + + all_motions = [] + + for rep_i in range(num_repetitions): + if guidance_param != 1: + model_kwargs['y']['scale'] = torch.ones(num_samples, device=device) * guidance_param + model_kwargs['y'] = {key: val.to(device) if torch.is_tensor(val) else val for key, val in model_kwargs['y'].items()} + + max_arb_len = model_kwargs['y']['lengths'].max() + min_arb_len = 2 * handshake_size + 2*blend_len + 10 + + for ii, len_s in enumerate(model_kwargs['y']['lengths']): + if len_s > max_arb_len: + model_kwargs['y']['lengths'][ii] = max_arb_len + if len_s < min_arb_len: + model_kwargs['y']['lengths'][ii] = min_arb_len + + sample = double_take_arb_len(diffusion, model, model_kwargs, max_arb_len, blend_len, handshake_size, device, progress=progress) + step_sizes = np.zeros(len(model_kwargs['y']['lengths']), dtype=int) + for ii, len_i in enumerate(model_kwargs['y']['lengths']): + if ii == 0: + step_sizes[ii] = len_i + continue + step_sizes[ii] = step_sizes[ii-1] + len_i - handshake_size + + final_n_frames = step_sizes[-1] + sample = unfold_sample_arb_len(sample, handshake_size, step_sizes, final_n_frames, model_kwargs) + + all_motions.append(sample) + + all_motions = torch.cat(all_motions, dim=0) + return all_motions, step_sizes + + \ No newline at end of file diff --git a/motion/hybrik_loc2rot.py b/motion/hybrik_loc2rot.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e8e4702886439034c1b4e55cf2a4c98461d134 --- /dev/null +++ b/motion/hybrik_loc2rot.py @@ -0,0 +1,140 @@ +import numpy as np + +SMPL_BODY_BONES = [-0.0018, -0.2233, 0.0282, 0.0695, -0.0914, -0.0068, -0.0677, -0.0905, -0.0043, + -0.0025, 0.1090, -0.0267, 0.0343, -0.3752, -0.0045, -0.0383, -0.3826, -0.0089, + 0.0055, 0.1352, 0.0011, -0.0136, -0.3980, -0.0437, 0.0158, -0.3984, -0.0423, + 0.0015, 0.0529, 0.0254, 0.0264, -0.0558, 0.1193, -0.0254, -0.0481, 0.1233, + -0.0028, 0.2139, -0.0429, 0.0788, 0.1217, -0.0341, -0.0818, 0.1188, -0.0386, + 0.0052, 0.0650, 0.0513, 0.0910, 0.0305, -0.0089, -0.0960, 0.0326, -0.0091, + 0.2596, -0.0128, -0.0275, -0.2537, -0.0133, -0.0214, 0.2492, 0.0090, -0.0012, + -0.2553, 0.0078, -0.0056, 0.0840, -0.0082, -0.0149, -0.0846, -0.0061, -0.0103] + + +class HybrIKJointsToRotmat: + def __init__(self): + self.naive_hybrik = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] + self.num_nodes = 22 + self.parents = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19] + self.child = [-1, 4, 5, 6, 7, 8, 9, 10, 11, -1, -2, -2, 15, + 16, 17, -2, 18, 19, 20, 21, -2, -2] + self.bones = np.reshape(np.array(SMPL_BODY_BONES), [24, 3])[:self.num_nodes] + + def multi_child_rot(self, t, p, + pose_global_parent): + """ + t: B x 3 x child_num + p: B x 3 x child_num + pose_global_parent: B x 3 x 3 + """ + m = np.matmul(t, np.transpose(np.matmul(np.linalg.inv(pose_global_parent), p), [0, 2, 1])) + u, s, vt = np.linalg.svd(m) + r = np.matmul(np.transpose(vt, [0, 2, 1]), np.transpose(u, [0, 2, 1])) + err_det_mask = (np.linalg.det(r) < 0.0).reshape(-1, 1, 1) + id_fix = np.reshape(np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]), + [1, 3, 3]) + r_fix = np.matmul(np.transpose(vt, [0, 2, 1]), + np.matmul(id_fix, + np.transpose(u, [0, 2, 1]))) + r = r * (1.0 - err_det_mask) + r_fix * err_det_mask + return r, np.matmul(pose_global_parent, r) + + def single_child_rot(self, t, p, pose_global_parent, twist=None): + """ 
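+        Solve for the local rotation that maps the template bone vector `t` onto
+        the observed child-joint direction `p` (expressed in the parent's frame),
+        optionally applying the given twist rotation about the bone axis.
+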
+ t: B x 3 x 1 + p: B x 3 x 1 + pose_global_parent: B x 3 x 3 + twist: B x 2 if given, default to None + """ + p_rot = np.matmul(np.linalg.inv(pose_global_parent), p) + cross = np.cross(t, p_rot, axisa=1, axisb=1, axisc=1) + sina = np.linalg.norm(cross, axis=1, keepdims=True) / (np.linalg.norm(t, axis=1, keepdims=True) * + np.linalg.norm(p_rot, axis=1, keepdims=True)) + cross = cross / np.linalg.norm(cross, axis=1, keepdims=True) + cosa = np.sum(t * p_rot, axis=1, keepdims=True) / (np.linalg.norm(t, axis=1, keepdims=True) * + np.linalg.norm(p_rot, axis=1, keepdims=True)) + sina = np.reshape(sina, [-1, 1, 1]) + cosa = np.reshape(cosa, [-1, 1, 1]) + skew_sym_t = np.stack([0.0 * cross[:, 0], -cross[:, 2], cross[:, 1], + cross[:, 2], 0.0 * cross[:, 0], -cross[:, 0], + -cross[:, 1], cross[:, 0], 0.0 * cross[:, 0]], 1) + skew_sym_t = np.reshape(skew_sym_t, [-1, 3, 3]) + dsw_rotmat = np.reshape(np.eye(3), [1, 3, 3] + ) + sina * skew_sym_t + (1.0 - cosa) * np.matmul(skew_sym_t, + skew_sym_t) + if twist is not None: + skew_sym_t = np.stack([0.0 * t[:, 0], -t[:, 2], t[:, 1], + t[:, 2], 0.0 * t[:, 0], -t[:, 0], + -t[:, 1], t[:, 0], 0.0 * t[:, 0]], 1) + skew_sym_t = np.reshape(skew_sym_t, [-1, 3, 3]) + sina = np.reshape(twist[:, 1], [-1, 1, 1]) + cosa = np.reshape(twist[:, 0], [-1, 1, 1]) + dtw_rotmat = np.reshape(np.eye(3), [1, 3, 3] + ) + sina * skew_sym_t + (1.0 - cosa) * np.matmul(skew_sym_t, + skew_sym_t) + dsw_rotmat = np.matmul(dsw_rotmat, dtw_rotmat) + return dsw_rotmat, np.matmul(pose_global_parent, dsw_rotmat) + + def __call__(self, joints, twist=None): + """ + joints: B x N x 3 + twist: B x N x 2 if given, default to None + """ + expand_dim = False + if len(joints.shape) == 2: + expand_dim = True + joints = np.expand_dims(joints, 0) + if twist is not None: + twist = np.expand_dims(twist, 0) + assert (len(joints.shape) == 3) + batch_size = np.shape(joints)[0] + joints_rel = joints - joints[:, self.parents] + joints_hybrik = 0.0 * joints_rel + pose_global = np.zeros([batch_size, self.num_nodes, 3, 3]) + pose = np.zeros([batch_size, self.num_nodes, 3, 3]) + for i in range(self.num_nodes): + if i == 0: + joints_hybrik[:, 0] = joints[:, 0] + else: + joints_hybrik[:, i] = np.matmul(pose_global[:, self.parents[i]], + np.reshape(self.bones[i], [1, 3, 1])).reshape(-1, 3) + \ + joints_hybrik[:, self.parents[i]] + if self.child[i] == -2: + pose[:, i] = pose[:, i] + np.eye(3).reshape(1, 3, 3) + pose_global[:, i] = pose_global[:, self.parents[i]] + continue + if i == 0: + r, rg = self.multi_child_rot(np.transpose(self.bones[[1, 2, 3]].reshape(1, 3, 3), [0, 2, 1]), + np.transpose(joints_rel[:, [1, 2, 3]], [0, 2, 1]), + np.eye(3).reshape(1, 3, 3)) + + elif i == 9: + r, rg = self.multi_child_rot(np.transpose(self.bones[[12, 13, 14]].reshape(1, 3, 3), [0, 2, 1]), + np.transpose(joints_rel[:, [12, 13, 14]], [0, 2, 1]), + pose_global[:, self.parents[9]]) + else: + p = joints_rel[:, self.child[i]] + if self.naive_hybrik[i] == 0: + p = joints[:, self.child[i]] - joints_hybrik[:, i] + twi = None + if twist is not None: + twi = twist[:, i] + r, rg = self.single_child_rot(self.bones[self.child[i]].reshape(1, 3, 1), + p.reshape(-1, 3, 1), + pose_global[:, self.parents[i]], + twi) + pose[:, i] = r + pose_global[:, i] = rg + if expand_dim: + pose = pose[0] + return pose + + +if __name__ == "__main__": + jts2rot_hybrik = HybrIKJointsToRotmat() + joints = np.array(SMPL_BODY_BONES).reshape(1, 24, 3)[:, :22] + parents = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19] + for i in range(1, 22): + 
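+        # accumulate each relative bone offset onto its parent joint to obtain absolute joint positions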
joints[:, i] = joints[:, i] + joints[:, parents[i]] + pose = jts2rot_hybrik(joints) + print(pose) \ No newline at end of file diff --git a/motion/model/Encode_Full.py b/motion/model/Encode_Full.py new file mode 100644 index 0000000000000000000000000000000000000000..522b5ebb167dbeebc5340a4468fc2bc286721cae --- /dev/null +++ b/motion/model/Encode_Full.py @@ -0,0 +1,80 @@ +from torch import nn +import torch +import torch.nn.functional as F +from motion.model.layer_norm_fp16 import RMSNorm, LayerNorm + +class ResConv1DBlock(nn.Module): + def __init__(self, n_in, n_state, bias, norm_type, activate_type): + super().__init__() + + if activate_type.lower() == "silu": + activate = nn.SiLU() + elif activate_type.lower() == "relu": + activate = nn.ReLU() + elif activate_type.lower() == "gelu": + activate = nn.GELU() + elif activate_type.lower() == "mish": + activate = nn.Mish() + + if norm_type.lower() == "rmsnorm": + norm = RMSNorm + elif norm_type.lower() == "layernorm": + norm = LayerNorm + + self.norm1 = norm(n_state) + self.norm2 = norm(n_in) + self.relu1 = activate + self.relu2 = activate + self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0, bias=bias) + + def forward(self, x): + x_orig = x + x = self.conv1(x) + x = self.norm1(x.transpose(-2, -1)) + x = self.relu1(x.transpose(-2, -1)) + + x = self.conv2(x) + x = self.norm2(x.transpose(-2, -1)) + x = self.relu2(x.transpose(-2, -1)) + + x = x + x_orig + return x + +class Encoder_Block(nn.Module): + def __init__(self, begin_channel=263, latent_dim=512, num_layers=6, TN=1, bias=False, norm_type="rmsnorm", activate_type="silu"): + super(Encoder_Block, self).__init__() + self.layers = [] + + begin_channel = begin_channel + target_channel = latent_dim + + if activate_type.lower() == "silu": + activate = nn.SiLU() + elif activate_type.lower() == "relu": + activate = nn.ReLU() + elif activate_type.lower() == "gelu": + activate = nn.GELU() + elif activate_type.lower() == "mish": + activate = nn.Mish() + + self.layers.append(nn.Conv1d(begin_channel, target_channel, 3, 2, 1, bias=bias)) + self.layers.append(activate) + + for _ in range(num_layers): ### 196 -> 98 -> 49 -> 24 -> 12 -> 6 -> 3 + self.layers.append(nn.Conv1d(target_channel, target_channel, 3, 2, 1, bias=bias)) + self.layers.append(activate) + self.layers.append(ResConv1DBlock(target_channel, target_channel, bias, norm_type, activate_type)) + + self.layers = nn.Sequential(*self.layers) + self.maxpool = nn.AdaptiveMaxPool1d(TN) + + def forward(self, x): + bs, njoints, nfeats, nframes = x.shape + reshaped_x = x.reshape(bs, njoints * nfeats, nframes) ### [bs, 263, seq] + + res1 = self.layers(reshaped_x) #### [bs, 512, 1] + res2 = self.maxpool(res1) + + res3 = res2.permute(2, 0, 1) + return res3 \ No newline at end of file diff --git a/motion/model/__pycache__/Encode_Full.cpython-310.pyc b/motion/model/__pycache__/Encode_Full.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f90cffc4903d1ce2fbecd3713046afe1001ff74a Binary files /dev/null and b/motion/model/__pycache__/Encode_Full.cpython-310.pyc differ diff --git a/motion/model/__pycache__/Encode_Full.cpython-311.pyc b/motion/model/__pycache__/Encode_Full.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..212e1fdab267539832fa520bdb90c1ae3179e87a Binary files /dev/null and b/motion/model/__pycache__/Encode_Full.cpython-311.pyc differ diff --git a/motion/model/__pycache__/Encode_Full.cpython-39.pyc 
b/motion/model/__pycache__/Encode_Full.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec0ad62293888fa13bb2d35847a58b909eb4943a Binary files /dev/null and b/motion/model/__pycache__/Encode_Full.cpython-39.pyc differ diff --git a/motion/model/__pycache__/base_transformer.cpython-310.pyc b/motion/model/__pycache__/base_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ab51d48b7c9441738ac4adc7b219a25fc842cb4 Binary files /dev/null and b/motion/model/__pycache__/base_transformer.cpython-310.pyc differ diff --git a/motion/model/__pycache__/base_transformer.cpython-311.pyc b/motion/model/__pycache__/base_transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..927cc47e091363f933e6cca3f7a20479187b96a6 Binary files /dev/null and b/motion/model/__pycache__/base_transformer.cpython-311.pyc differ diff --git a/motion/model/__pycache__/base_transformer.cpython-39.pyc b/motion/model/__pycache__/base_transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3ba15fe7e2bd2185a1256bc90ac9852fb4ce61b Binary files /dev/null and b/motion/model/__pycache__/base_transformer.cpython-39.pyc differ diff --git a/motion/model/__pycache__/cfg_sampler.cpython-310.pyc b/motion/model/__pycache__/cfg_sampler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b59f9c29b2bd323fa5ff8247ac03470fb2febe9 Binary files /dev/null and b/motion/model/__pycache__/cfg_sampler.cpython-310.pyc differ diff --git a/motion/model/__pycache__/cfg_sampler.cpython-311.pyc b/motion/model/__pycache__/cfg_sampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccddc06d982b1df3e8f782e19396a4a6f193f82c Binary files /dev/null and b/motion/model/__pycache__/cfg_sampler.cpython-311.pyc differ diff --git a/motion/model/__pycache__/cfg_sampler.cpython-39.pyc b/motion/model/__pycache__/cfg_sampler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef0469e0c5c7146abfc9e9ed650c08b87d440dae Binary files /dev/null and b/motion/model/__pycache__/cfg_sampler.cpython-39.pyc differ diff --git a/motion/model/__pycache__/layer_norm_fp16.cpython-310.pyc b/motion/model/__pycache__/layer_norm_fp16.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a65811ab7eab0f44f77fe4834ac271ccff94c965 Binary files /dev/null and b/motion/model/__pycache__/layer_norm_fp16.cpython-310.pyc differ diff --git a/motion/model/__pycache__/layer_norm_fp16.cpython-311.pyc b/motion/model/__pycache__/layer_norm_fp16.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4314ad1adaee4ce3a49cf06beb7d410bfdd75fa Binary files /dev/null and b/motion/model/__pycache__/layer_norm_fp16.cpython-311.pyc differ diff --git a/motion/model/__pycache__/layer_norm_fp16.cpython-39.pyc b/motion/model/__pycache__/layer_norm_fp16.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cc07a974938fce273f130c490e056caa4f8246e Binary files /dev/null and b/motion/model/__pycache__/layer_norm_fp16.cpython-39.pyc differ diff --git a/motion/model/__pycache__/mdm.cpython-310.pyc b/motion/model/__pycache__/mdm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5db91794359332a265fa72340a63bd2d6d9e2ed1 Binary files /dev/null and b/motion/model/__pycache__/mdm.cpython-310.pyc differ diff --git 
a/motion/model/__pycache__/mdm.cpython-311.pyc b/motion/model/__pycache__/mdm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96d175cf9457fcb5e36c6ff80032a3aeb4073664 Binary files /dev/null and b/motion/model/__pycache__/mdm.cpython-311.pyc differ diff --git a/motion/model/__pycache__/mdm.cpython-39.pyc b/motion/model/__pycache__/mdm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69b1a13f62b601f4f605ebfa54606f7a7f044471 Binary files /dev/null and b/motion/model/__pycache__/mdm.cpython-39.pyc differ diff --git a/motion/model/base_transformer.py b/motion/model/base_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..06bb1cedf6fa03121cc4b723495dca0ec8060ba0 --- /dev/null +++ b/motion/model/base_transformer.py @@ -0,0 +1,130 @@ +import torch +from torch import nn +import torch.nn.functional as F +import copy +from torch.nn import MultiheadAttention +from motion.model.layer_norm_fp16 import LayerNorm, RMSNorm +import numpy as np +import math + +class SwiGLU(nn.Module): + ''' + follow the structure of llama + ''' + def __init__(self, dim, hidden_dim, multiple_of = 256): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias= False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + +def _get_activation_fn(activation: str): + if activation.lower() == "relu": + return F.relu + elif activation.lower() == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + +class RefinedLayer(nn.Module): + __constants__ = ['batch_first', 'norm_first'] + + def __init__(self, d_model, nhead, dim_feedforward = 2048, dropout = 0.1, + activation = F.relu, layer_norm_eps = 1e-5, device=None, dtype=None, max_seq_len=196, position_type="static", word_tokens=False, norm_type="rmsnorm", attention_type="torch"): + factory_kwargs = {'device': device, 'dtype': dtype, "bias":False} + super().__init__() + if norm_type.lower() == "rmsnorm": + Norm = RMSNorm + elif norm_type.lower() == "layer": + Norm = LayerNorm + + self.attention_type = attention_type + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=False, **factory_kwargs) + + if word_tokens: + self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=False, **factory_kwargs) + self.norm3 = Norm(d_model, layer_norm_eps) + self.dropout3 = nn.Dropout(dropout) + self.word_tokens = word_tokens + # Implementation of Feedforward model + + self.norm1 = Norm(d_model, layer_norm_eps) + self.norm2 = Norm(d_model, layer_norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + # Legacy string support for activation function. 
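+        # The feed-forward block is either the classic Linear -> activation -> Linear
+        # stack (for "relu"/"gelu") or, when activation == "swiglu", the LLaMA-style
+        # SwiGLU module defined above.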
+        if isinstance(activation, str) and activation.lower() != "swiglu":
+            activation = _get_activation_fn(activation)
+            self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
+            self.dropout = nn.Dropout(dropout)
+            self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
+            self.ffn = self._ff_block
+        elif activation.lower() == "swiglu":
+            self.ffn = SwiGLU(d_model, dim_feedforward)
+
+        self.activation = activation
+
+    def forward(
+            self,
+            src,
+            word_tokens=None,
+            src_mask=None,
+            src_key_padding_mask=None):
+        x = src
+        x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
+        if self.word_tokens:
+            x = x + self._csa_block(self.norm3(x), word_tokens)
+        x = x + self.dropout2(self.ffn(self.norm2(x)))
+        return x
+
+    # self-attention block
+    def _sa_block(self, x, attn_mask, key_padding_mask):
+        x = self.self_attn(x, x, x,
+                           attn_mask=attn_mask,
+                           key_padding_mask=key_padding_mask,
+                           need_weights=False)[0]
+
+        return self.dropout1(x)
+
+    # cross-attention block
+    def _csa_block(self, x, mem, attn_mask=None, key_padding_mask=None):
+        x = self.cross_attn(x, mem, mem,
+                            attn_mask=attn_mask,
+                            key_padding_mask=key_padding_mask,
+                            need_weights=False)[0]
+
+        return self.dropout3(x)
+
+    # feed forward block
+    def _ff_block(self, x):
+        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+        return x
+
+class Refined_Transformer(nn.Module):
+    def __init__(self, refined_layer, num_layers):
+        super().__init__()
+        self.layers = _get_clones(refined_layer, num_layers)
+        self.num_layers = num_layers
+
+    def forward(
+            self,
+            src,
+            word_tokens=None,
+            src_mask=None,
+            src_key_padding_mask=None):
+        output = src
+        src_key_padding_mask_for_layers = src_key_padding_mask
+        for mod in self.layers:
+            output = mod(output, word_tokens=word_tokens, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask_for_layers)
+        return output
diff --git a/motion/model/cfg_sampler.py b/motion/model/cfg_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..49b69f13c25ee9fde4d37b10697e143f373aacde
--- /dev/null
+++ b/motion/model/cfg_sampler.py
@@ -0,0 +1,41 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from copy import deepcopy
+
+# A wrapper model for Classifier-free guidance **SAMPLING** only
+# https://arxiv.org/abs/2207.12598
+class ClassifierFreeSampleModel(nn.Module):
+
+    def __init__(self, model):
+        super().__init__()
+        self.model = model  # model is the actual model to run
+
+        # assert self.model.cond_mask_prob > 0, 'Cannot run a guided diffusion on a model that has not been trained with no conditions'
+
+        # pointers to inner model
+        self.njoints = self.model.njoints
+        self.nfeats = self.model.nfeats
+        self.cond_mode = self.model.cond_mode
+
+    def forward(self, x, timesteps, y=None):
+        cond_mode = self.model.cond_mode
+        assert cond_mode in ['text', 'action', "motion", "text-motion"]
+        y_uncond = deepcopy(y)
+        y_uncond['uncond'] = True
+
+        out = self.model(x, timesteps, y)  # fully conditional forward pass
+
+        if "predict_length" in out.keys():
+            y_uncond["predict_mask"] = out["predict_length"]
+
+        out_uncond = self.model(x, timesteps, y_uncond)  # fully unconditional forward pass
+
+        output = {}
+
+        y['scale'] = y['scale'].to(out_uncond["output"].device)
+
+        output["output"] = out_uncond["output"] + (y['scale'].view(-1, 1, 1, 1) * (out["output"] - out_uncond["output"]))
+
+        return output  # note: the wrapped model outputs features (x0-style predictions) here, not the noise \epsilon
+
diff --git a/motion/model/clip/__init__.py b/motion/model/clip/__init__.py
new file mode 100644 index
0000000000000000000000000000000000000000..dcc5619538c0f7c782508bdbd9587259d805e0d9 --- /dev/null +++ b/motion/model/clip/__init__.py @@ -0,0 +1 @@ +from .clip import * diff --git a/motion/model/clip/__pycache__/__init__.cpython-310.pyc b/motion/model/clip/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a0baea6d3208d89c83c68dc1ca38394f00b0e7a Binary files /dev/null and b/motion/model/clip/__pycache__/__init__.cpython-310.pyc differ diff --git a/motion/model/clip/__pycache__/__init__.cpython-311.pyc b/motion/model/clip/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea66b3fc2832f4540956dcf4ec8b18f966890f78 Binary files /dev/null and b/motion/model/clip/__pycache__/__init__.cpython-311.pyc differ diff --git a/motion/model/clip/__pycache__/__init__.cpython-39.pyc b/motion/model/clip/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd8cdb518d67e2f6a811d7a4c59d845ee2cbc683 Binary files /dev/null and b/motion/model/clip/__pycache__/__init__.cpython-39.pyc differ diff --git a/motion/model/clip/__pycache__/clip.cpython-310.pyc b/motion/model/clip/__pycache__/clip.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b386eeceb198195393e9afca2e30df636a1727c Binary files /dev/null and b/motion/model/clip/__pycache__/clip.cpython-310.pyc differ diff --git a/motion/model/clip/__pycache__/clip.cpython-311.pyc b/motion/model/clip/__pycache__/clip.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f20f70d35021ec2173903e39459b256f12137d64 Binary files /dev/null and b/motion/model/clip/__pycache__/clip.cpython-311.pyc differ diff --git a/motion/model/clip/__pycache__/clip.cpython-39.pyc b/motion/model/clip/__pycache__/clip.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b37893c7bb15da0e648e14829f48d65109041aa Binary files /dev/null and b/motion/model/clip/__pycache__/clip.cpython-39.pyc differ diff --git a/motion/model/clip/__pycache__/model.cpython-310.pyc b/motion/model/clip/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..badbee9aee791fbc00dd9ef2bcd3a3472aaf9e0e Binary files /dev/null and b/motion/model/clip/__pycache__/model.cpython-310.pyc differ diff --git a/motion/model/clip/__pycache__/model.cpython-311.pyc b/motion/model/clip/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..469f626e1932be95138601e2f0232af6b8796eb4 Binary files /dev/null and b/motion/model/clip/__pycache__/model.cpython-311.pyc differ diff --git a/motion/model/clip/__pycache__/model.cpython-39.pyc b/motion/model/clip/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..284deda4cf8f08f09d052a9cb0c546b6ba1aa4e6 Binary files /dev/null and b/motion/model/clip/__pycache__/model.cpython-39.pyc differ diff --git a/motion/model/clip/__pycache__/simple_tokenizer.cpython-310.pyc b/motion/model/clip/__pycache__/simple_tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9279abf45407ee2ab1c8f57808f167e6d9bab0d7 Binary files /dev/null and b/motion/model/clip/__pycache__/simple_tokenizer.cpython-310.pyc differ diff --git a/motion/model/clip/__pycache__/simple_tokenizer.cpython-311.pyc b/motion/model/clip/__pycache__/simple_tokenizer.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..334ddea60f1bbfaf9dd13f4ab76ee6251c4dcb4f Binary files /dev/null and b/motion/model/clip/__pycache__/simple_tokenizer.cpython-311.pyc differ diff --git a/motion/model/clip/__pycache__/simple_tokenizer.cpython-39.pyc b/motion/model/clip/__pycache__/simple_tokenizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84c431617259bd75b9542d4aec4f7607ee534234 Binary files /dev/null and b/motion/model/clip/__pycache__/simple_tokenizer.cpython-39.pyc differ diff --git a/motion/model/clip/bpe_simple_vocab_16e6.txt.gz b/motion/model/clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/motion/model/clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/motion/model/clip/clip.py b/motion/model/clip/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..f7a5da5e69e0a3b41383734711ccfff1923a9ef9 --- /dev/null +++ b/motion/model/clip/clip.py @@ -0,0 +1,245 @@ +import hashlib +import os +import urllib +import warnings +from typing import Any, Union, List +from pkg_resources import packaging + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise 
RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def _node_get(node: torch._C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type. 
+ + From https://github.com/pytorch/pytorch/pull/82628 + """ + sel = node.kindOf(key) + return getattr(node, sel)(key) + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if _node_get(inputs[i].node(), "value") == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
+ """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/motion/model/clip/model.py b/motion/model/clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ae8dcb623ac5a1df45d613229533c1cb1fbaed03 --- /dev/null +++ b/motion/model/clip/model.py @@ -0,0 +1,426 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from motion.model.layer_norm_fp16 import LayerNorm + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, 
+ embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, 
need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, 
std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD [global, w1, w2, ....] 
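+ # ln_final is the fp32-casting LayerNorm from motion.model.layer_norm_fp16, so normalization runs in float32 even when the rest of the model is fp16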
+ x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/motion/model/clip/simple_tokenizer.py b/motion/model/clip/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91 --- /dev/null +++ 
b/motion/model/clip/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = 
whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text diff --git a/motion/model/layer_norm_fp16.py b/motion/model/layer_norm_fp16.py new file mode 100644 index 0000000000000000000000000000000000000000..0a9afec4cc3924d7ae47f62d375b067330a16dd1 --- /dev/null +++ b/motion/model/layer_norm_fp16.py @@ -0,0 +1,38 @@ +from torch import nn +import torch +import torch.nn.functional as F + +class LayerNorm(nn.Module): + def __init__(self, normalized_shape, eps = 1e-5, elementwise_affine = True, + device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + if isinstance(normalized_shape, int): + normalized_shape = [normalized_shape] + self.normalized_shape = normalized_shape # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.parameter.Parameter(torch.ones(self.normalized_shape, **factory_kwargs)) + self.bias = nn.parameter.Parameter(torch.zeros(self.normalized_shape, **factory_kwargs)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + + def forward(self, input): + orig_type = input.dtype + ret = F.layer_norm(input.type(torch.float32), self.normalized_shape, self.weight.type(torch.float32), self.bias.type(torch.float32), self.eps) + return ret.type(orig_type) + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight \ No newline at end of file diff --git a/motion/model/mdm.py b/motion/model/mdm.py new file mode 100644 index 0000000000000000000000000000000000000000..03f7ffa49124d01144b2545c5de7028ab3b93a89 --- /dev/null +++ b/motion/model/mdm.py @@ -0,0 +1,387 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from motion.model import clip +import json +from motion.model.base_transformer import RefinedLayer, Refined_Transformer +from motion.model.Encode_Full import Encoder_Block + +class MDM(nn.Module): + def __init__(self, njoints, nfeats, latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1, + activation="gelu", dataset='amass', clip_dim=512, + arch='trans_enc', clip_version=None, **kargs): + super().__init__() + + self.encode_full = kargs.get("encode_full", 0) #### encode_full = 1 add tokens & encode_full = 2 model compress tokens + self.txt_tokens = kargs.get("txt_tokens", 0) #### txt_tokens = 1 add tokens & txt_tokens = 2 model compress tokens + self.frame_mask = kargs.get("frame_mask", 0) + self.dataset = dataset + self.condition_length = 77 + self.num_frames = kargs.get("num_frames", 196) + self.position_type = "static" #### static or rope only for llama arch + self.json_dict = kargs.get("json_dict") + + if isinstance(self.num_frames, list) or isinstance(self.num_frames, tuple): + self.num_frames = self.num_frames[0] + + self.njoints = njoints + 
self.nfeats = nfeats + + self.latent_dim = latent_dim + + self.ff_size = ff_size + self.num_layers = num_layers + self.num_heads = num_heads + self.dropout = dropout + + self.activation = activation + self.clip_dim = clip_dim + self.action_emb = kargs.get('action_emb', None) + + self.input_feats = self.njoints * self.nfeats + + self.cond_mode = kargs.get('cond_mode', 'no_cond') + self.cond_mask_prob = kargs.get('cond_mask_prob', 0.) + self.arch = arch + + self.input_process = InputProcess(self.input_feats, self.latent_dim) #### 输入 x 的 linear + self.output_process = OutputProcess(self.input_feats, self.latent_dim, self.njoints, + self.nfeats) + + self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) + + if self.arch == 'trans_enc': + print("TRANS_ENC init") + seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, + nhead=self.num_heads, + dim_feedforward=self.ff_size, + dropout=self.dropout, + activation=self.activation) + self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers) + + elif self.arch == "refined_encoder": + TransLayer = RefinedLayer(self.latent_dim, self.num_heads, self.ff_size, self.dropout, self.activation, max_seq_len=self.num_frames, norm_type="rmsnorm") + self.seqTransEncoder = Refined_Transformer(TransLayer, self.num_layers) + + elif self.arch == "refined_decoder": + TransLayer = RefinedLayer(self.latent_dim, self.num_heads, self.ff_size, self.dropout, self.activation, max_seq_len=self.num_frames, word_tokens=True, norm_type="rmsnorm") + self.seqTransEncoder = Refined_Transformer(TransLayer, self.num_layers) + + elif self.arch == "llama_encoder": + TransLayer = RefinedLayer(self.latent_dim, self.num_heads, self.ff_size, self.dropout, self.activation, max_seq_len=self.num_frames, position_type=self.position_type, norm_type="rmsnorm", attention_type="llama") + self.seqTransEncoder = Refined_Transformer(TransLayer, self.num_layers) + + elif self.arch == "llama_decoder": + TransLayer = RefinedLayer(self.latent_dim, self.num_heads, self.ff_size, self.dropout, self.activation, max_seq_len=self.num_frames, position_type=self.position_type, word_tokens=True, norm_type="rmsnorm", attention_type="llama") + self.seqTransEncoder = Refined_Transformer(TransLayer, self.num_layers) + + else: + raise ValueError('Please choose correct architecture') + + self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) + + if self.cond_mode != 'no_cond': + if 'text' in self.cond_mode: + self.embed_text = nn.Linear(self.clip_dim, self.latent_dim) + print('EMBED TEXT') + print('Loading CLIP...') + self.clip_version = clip_version + self.clip_model = self.load_and_freeze_clip(clip_version) + + if self.txt_tokens == 2: + if self.arch in ["refined_encoder", "trans_enc", "llama_encoder"]: + scale = 3 + elif self.arch in ["refined_decoder", "llama_decoder"]: + scale = 2 + encode_compress_layer = RefinedLayer(d_model=self.latent_dim * scale, + nhead=self.num_heads, + dim_feedforward=self.ff_size, + dropout=self.dropout, + activation=self.activation) + self.condition_compress = nn.Sequential( + Refined_Transformer(encode_compress_layer, num_layers=1), + nn.Linear(self.latent_dim * scale, self.latent_dim, ) + ) + + if self.encode_full != 0: #### [1, bs, 512] -> [seq, bs, 1024] -> [seq, bs, 512] + self.code_full = Encoder_Block(begin_channel=self.input_feats, latent_dim=self.latent_dim, num_layers=6, TN=1) + + if self.encode_full == 2: + encode_compress_layer = RefinedLayer(d_model=self.latent_dim * 2, 
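+ # encode_full == 2: per-frame features are concatenated with the repeated motion code (width 2 * latent_dim) and encode_compress maps them back down to latent_dim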
+ nhead=self.num_heads, + dim_feedforward=self.ff_size, + dropout=self.dropout, + activation=self.activation) + + self.encode_compress = nn.Sequential( + Refined_Transformer(encode_compress_layer, num_layers=1), + nn.Linear(self.latent_dim * 2, self.latent_dim, ) + ) + + print(" =========================", self.cond_mode, "===================================") + + def parameters_wo_clip(self): + return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')] + + def load_and_freeze_clip(self, clip_version): + clip_model, clip_preprocess = clip.load(clip_version, device='cpu', jit=False, download_root=self.json_dict["clip"]) # Must set jit=False for training + clip.model.convert_weights(clip_model) # Actually this line is unnecessary since clip by default already on float16 + # clip_model.float() + # Freeze CLIP weights + clip_model.eval() + for p in clip_model.parameters(): + p.requires_grad = False + + return clip_model + + def mask_cond(self, cond, force_mask=False): + bs = cond.shape[0] + if force_mask: + return torch.zeros_like(cond) + elif self.training and self.cond_mask_prob > 0.: + mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob) # 1-> use null_cond, 0-> use real cond + if len(cond.shape) == 3: + mask = mask.view(bs, 1, 1) + else: + mask = mask.view(bs, 1) + return cond * (1. - mask) + else: + return cond + + def mask_motion(self, motion): + # x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper + + if self.training and self.frame_mask > 0.: + pair_motion = torch.randperm(motion.shape[0]) + pair_motion = motion[pair_motion] + if len(motion.shape) == 4: + bs, njoints, nfeats, nframes = motion.shape + mask = torch.bernoulli(torch.ones([bs, 1, 1, nframes], device=motion.device) * self.frame_mask) # 1-> use null_cond, 0-> use real cond + mask = mask.repeat(1, njoints, nfeats, 1) + elif len(motion.shape) == 3: + seqlen, bs, latent_dim = motion.shape + mask = torch.bernoulli(torch.ones([seqlen, bs, 1], device=motion.device) * self.frame_mask) + mask = mask.repeat(1, 1, latent_dim) + return motion * (1. 
- mask) + pair_motion * mask + else: + return motion + + def clip_text_embedding(self, raw_text): + device = self.clip_model.ln_final.weight.device + default_context_length = self.condition_length + texts = clip.tokenize(raw_text, context_length=default_context_length, truncate=True).to(device) # [bs, context_length] # if n_tokens > context_length -> will truncate + if self.txt_tokens == 0: + clip_feature = self.clip_model.encode_text(texts) + else: + with torch.no_grad(): + x = self.clip_model.token_embedding(texts).type(self.clip_model.dtype) # [batch_size, n_ctx, d_model] + x = x + self.clip_model.positional_embedding.type(self.clip_model.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.clip_model.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.clip_model.ln_final(x).type(self.clip_model.dtype) + clip_feature = x[torch.arange(x.shape[0]), texts.argmax(dim=-1)] @ self.clip_model.text_projection + clip_feature = clip_feature.unsqueeze(1) + clip_feature = torch.cat([clip_feature, x], dim=1) #### [bs, T, 512] + return clip_feature + + def get_mask(self, sz1, sz2): + mask = (torch.triu(torch.ones(sz1, sz2)) == 1).transpose(0, 1) + mask = mask.float() + mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + mask.requires_grad = False + return mask + + def forward(self, x, timesteps, y=None): + """ + x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper + timesteps: [batch_size] (int) + """ + + results = {} + emb = self.embed_timestep(timesteps) # [1, bs, d] + x = x.to(emb.dtype) + + x = self.mask_motion(x) + + real_length = x.shape[-1] + if self.encode_full != 0 and x.shape[-1] < self.num_frames: + extension = torch.zeros([x.shape[0], x.shape[1], x.shape[2], self.num_frames - x.shape[-1]], device=x.device, dtype=x.dtype) + x = torch.cat([x, extension], dim=-1) + + if self.encode_full == 1: + latent = self.code_full(x) ### [seq, bs, 512] + current = self.input_process(x) + latent = latent.repeat(current.shape[0], 1, 1) + current = current + latent + elif self.encode_full == 2: + latent = self.code_full(x) ### [seq, bs, 512] + current = self.input_process(x) #### [seq, bs, 512] + latent = latent.repeat(current.shape[0], 1, 1) + current = torch.cat([current, latent], dim=2) + current = self.encode_compress(current) + else: + current = self.input_process(x) #### [seq, bs, 512] + + force_mask = y.get('uncond', False) + if 'text' in self.cond_mode: + enc_text = self.clip_text_embedding(y['text']).to(emb.dtype) ### MASK_COND 会按照一定的比例把 batch_size 中的一部分文本句整句换成 [0, 0, ... 
0] + txt_emb = self.embed_text(enc_text) + txt_emb = self.mask_cond(txt_emb, force_mask=force_mask) + + if len(txt_emb.shape) == 3: + txt_emb = txt_emb.permute(1, 0, 2) + else: + txt_emb = txt_emb.unsqueeze(0) + else: + txt_emb = None + + if txt_emb is not None: + all_emb = txt_emb + else: + all_emb = torch.zeros_like(emb) + + if self.arch in ["refined_encoder", "trans_enc", "llama_encoder"] and txt_emb is not None: + if self.txt_tokens == 1: + word_embedding = all_emb[1::, :, :] + global_embedding = all_emb[0:1, :, :].repeat(word_embedding.shape[0], 1, 1) + all_emb = word_embedding + global_embedding + emb = emb.repeat(all_emb.shape[0], 1, 1) + emb += all_emb + elif self.txt_tokens == 2: + word_embedding = all_emb[1::, :, :] + global_embedding = all_emb[0:1, :, :].repeat(word_embedding.shape[0], 1, 1) + emb = emb.repeat(word_embedding.shape[0], 1, 1) + concat_embedding = torch.cat([emb, global_embedding, word_embedding], dim=2) + emb = self.condition_compress(concat_embedding) + else: + emb += all_emb + elif txt_emb is not None: + if self.txt_tokens == 1: + emb = emb.repeat(all_emb.shape[0], 1, 1) + emb += all_emb + elif self.txt_tokens == 2: + emb = emb.repeat(all_emb.shape[0], 1, 1) + concat_embedding = torch.cat([emb, all_emb], dim=2) + emb = self.condition_compress(concat_embedding) + else: + emb += all_emb + else: + emb = emb.repeat(all_emb.shape[0], 1, 1) + emb += all_emb + + if self.arch in ["trans_enc", "refined_encoder", "llama_encoder"]: + real_token_length = emb.shape[0] ######### 用来截断输出,只保留真正的output + elif self.arch in ["refined_decoder", "llama_decoder"]: + real_token_length = 1 + + if self.arch in ["trans_enc", "refined_encoder", "llama_encoder"]: + xseq = torch.cat([emb, current], dim=0) + + if self.arch in ["trans_enc", "refined_encoder"] or self.position_type == "static": + xseq = self.sequence_pos_encoder(xseq) + + output = self.seqTransEncoder(xseq) + elif self.arch in ["refined_decoder", "llama_decoder"]: + xseq = torch.cat([emb[0:1], current], dim=0) + word_tokens = emb[1::] + + if self.arch in ["refined_decoder"] or self.position_type == "static": + xseq = self.sequence_pos_encoder(xseq) + # word_tokens = self.sequence_pos_encoder(word_tokens) + + output = self.seqTransEncoder(xseq, word_tokens=word_tokens) + + output = output[real_token_length:] + output = self.output_process(output) # [bs, njoints, nfeats, nframes] + output = output[:, :, :, :real_length] + results["output"] = output + return results + + def _apply(self, fn): + super()._apply(fn) + + def train(self, *args, **kwargs): + super().train(*args, **kwargs) + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) ###### max_len 是 T_steps 长度, d_model 是嵌入特征的维度 + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + + self.register_parameter('pe', nn.Parameter(pe, requires_grad=False)) + + def forward(self, x): + # not used in the final model + x = x + self.pe[:x.shape[0], :] + return self.dropout(x) + + +class TimestepEmbedder(nn.Module): + def __init__(self, latent_dim, sequence_pos_encoder): + super().__init__() + self.latent_dim = latent_dim + self.sequence_pos_encoder = sequence_pos_encoder + + 
time_embed_dim = self.latent_dim + self.time_embed = nn.Sequential( + nn.Linear(self.latent_dim, time_embed_dim, ), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim, ), + ) + + def forward(self, timesteps): #### timesteps 也是按照 position 的方式编码的 [times, 1, latent] -> [1, times, latent] ? + return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2) + + +class InputProcess(nn.Module): + def __init__(self, input_feats, latent_dim): + super().__init__() + self.input_feats = input_feats + self.latent_dim = latent_dim + self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim) + + def forward(self, x): + bs, njoints, nfeats, nframes = x.shape ### [B,263, nframes] -> [B, nframes, 263] + x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints*nfeats) + x = self.poseEmbedding(x) # [seqlen, bs, d] + return x + + +class OutputProcess(nn.Module): + def __init__(self, input_feats, latent_dim, njoints, nfeats): + super().__init__() + self.input_feats = input_feats + self.latent_dim = latent_dim + self.njoints = njoints + self.nfeats = nfeats + + self.poseFinal = nn.Linear(self.latent_dim, self.input_feats) + + + def forward(self, output): + nframes, bs, d = output.shape + output = self.poseFinal(output) # [seqlen, bs, 150] + output = output.reshape(nframes, bs, self.njoints, self.nfeats) + output = output.permute(1, 2, 3, 0) # [bs, njoints, nfeats, nframes] + + return output + + +class EmbedAction(nn.Module): + def __init__(self, num_actions, latent_dim): + super().__init__() + self.action_embedding = nn.Parameter(torch.randn(num_actions, latent_dim)) + + def forward(self, input): + idx = input[:, 0].to(torch.long) # an index array must be long + output = self.action_embedding[idx] + return output \ No newline at end of file diff --git a/motion/model_util.py b/motion/model_util.py new file mode 100644 index 0000000000000000000000000000000000000000..b11b0389236a05cbe1e36d8fad016e2e60ba5b87 --- /dev/null +++ b/motion/model_util.py @@ -0,0 +1,101 @@ +from motion.model.mdm import MDM +from motion.diffusion import gaussian_diffusion as gd +from motion.diffusion.respace import SpacedDiffusion, space_timesteps, InpaintingGaussianDiffusion + +def load_model_wo_clip(model, state_dict): + print("load model checkpoints without clip") + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + print(unexpected_keys) + assert all([k.startswith('clip_model.') for k in missing_keys]) + +def load_ft_model_wo_clip(model, state_dict): + print("load model checkpoints without clip") + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print(unexpected_keys) + + # for name, value in model.named_parameters(): + # if "seqTransEncoder" in name and "self_attn" in name: + # value.requires_grad = False + # if name.startswith("code_full") or name.startswith("encode_compress") or name.startswith("input_process"): + # value.requires_grad = False + + assert all([k.startswith('clip_pose_encoder.') for k in unexpected_keys]) + # assert all([k.startswith('clip_model.') or k.startswith('clip_pose_encoder.') or k.startswith('embed_text.') for k in missing_keys]) + +def create_model_and_diffusion(args, mode="text", json_dict=None): + model = MDM(**get_model_args(args), json_dict=json_dict) + diffusion = create_gaussian_diffusion(args, mode) + return model, diffusion + +def get_model_args(args): + # default args + clip_version = 'ViT-B/32' + if args.unconstrained: + cond_mode = 'no_cond' + elif args.dataset in ['kit', 'humanml']: + cond_mode = "text" 
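+ # pick the feed-forward activation to match the architecture: the refined_* blocks use SwiGLU, everything else uses GELU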
+ + if args.arch in ["refined_encoder", "refined_decoder"]: + activation = "swiglu" + else: + activation = "gelu" + + if args.dataset == 'humanml': + njoints = 263 + nfeats = 1 + elif args.dataset == 'kit': + njoints = 251 + nfeats = 1 + + if args.rep == "smr": + njoints += 6 + nfeats = 1 + + return {'njoints': njoints, 'nfeats': nfeats, 'latent_dim': args.latent_dim, 'ff_size': args.ff_size, 'num_layers': args.layers, 'num_heads': args.heads, + 'dropout': 0.1, 'activation': activation, 'cond_mode': cond_mode, 'cond_mask_prob': args.cond_mask_prob, 'arch': args.arch, + 'clip_version': clip_version, 'dataset': args.dataset, "local":args.local, "encode_full":args.encode_full, "txt_tokens":args.txt_tokens, + "num_frames":args.num_frames, "frame_mask":args.frame_mask} + + +def create_gaussian_diffusion(args, mode="text"): + # default params + predict_xstart = True # we always predict x_start (a.k.a. x0), that's our deal! + steps = 1000 + scale_beta = 1. # no scaling + timestep_respacing = '' # can be used for ddim sampling, we don't use it. + learn_sigma = False + rescale_timesteps = False + + betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta) + loss_type = gd.LossType.MSE + + if not timestep_respacing: + timestep_respacing = [steps] + + if mode is not None and (mode.startswith("finetune_control") or mode == "control_length"): + print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> inpainting diffusion model") + diffusion = InpaintingGaussianDiffusion + else: + print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> SpacedDiffusion") + diffusion = SpacedDiffusion + + return diffusion( + use_timesteps=space_timesteps(steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not args.sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + rescale_timesteps=rescale_timesteps, + rep=args.rep + ) \ No newline at end of file diff --git a/motion/path.json b/motion/path.json new file mode 100644 index 0000000000000000000000000000000000000000..49554664b64aa8ae9db39053f49ad88078927143 --- /dev/null +++ b/motion/path.json @@ -0,0 +1,11 @@ +{ + "smpl_path":"body_models", + "cadm":"weights/refined_decoder_norm.pth", + "cadm-augment":"weights/refined_decoder_augment.pth", + "mdm":"weights/given_mdm.pt", + "dataset_dir":"motion/dataset", + "joints2smpl":"SMPLX/visualize_joint2smpl/joints2smpl/smpl_models", + "clip":"weights/", + "config":"motion/sample_config.json", + "tada_base":"tada-extend" +} \ No newline at end of file diff --git a/motion/plot3d.py b/motion/plot3d.py new file mode 100644 index 0000000000000000000000000000000000000000..9d56c1d39d79afa4ac7acbca8f56549567fc0c75 --- /dev/null +++ b/motion/plot3d.py @@ -0,0 +1,118 @@ +import math +import numpy as np +import matplotlib +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d.art3d import Poly3DCollection +import mpl_toolkits.mplot3d.axes3d as p3 +from textwrap import wrap +from tqdm import tqdm + +def list_cut_average(ll, intervals): + if intervals == 1: + return ll + + bins = math.ceil(len(ll) * 1.0 / intervals) + ll_new = [] + for i in range(bins): + l_low = intervals * i + l_high = l_low + intervals + l_high = l_high if l_high < len(ll) else len(ll) + ll_new.append(np.mean(ll[l_low:l_high])) + return ll_new + + +def plot_3d_motion(kinematic_tree, joints, title, dataset="humanml", figsize=(10.24, 10.24), radius=3, + 
vis_mode='default', gt_frames=[]): + matplotlib.use('Agg') + title = '\n'.join(wrap(title, 40)) + + def init(): + ax.set_xlim3d([-radius / 2, radius / 2]) + ax.set_ylim3d([0, radius]) + ax.set_zlim3d([-radius / 3., radius * 2 / 3.]) + # print(title) + fig.suptitle(title, fontsize=20) + ax.grid(b=False) + + def plot_xzPlane(minx, maxx, miny, minz, maxz): + ## Plot a plane XZ + verts = [ + [minx, miny, minz], + [minx, miny, maxz], + [maxx, miny, maxz], + [maxx, miny, minz] + ] + xz_plane = Poly3DCollection([verts]) + xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) + ax.add_collection3d(xz_plane) + + # return ax + + # (seq_len, joints_num, 3) + data = joints.copy().reshape(len(joints), -1, 3) + + # preparation related to specific datasets + if dataset == 'kit': + data *= 0.003 # scale for visualization + elif dataset == 'humanml': + data *= 1.3 # scale for visualization + elif dataset in ['humanact12', 'uestc']: + data *= -1.5 # reverse axes, scale for visualization + + fig = plt.figure(figsize=figsize) + plt.tight_layout() + ax = p3.Axes3D(fig) + init() + MINS = data.min(axis=0).min(axis=0) + MAXS = data.max(axis=0).max(axis=0) + colors_blue = ["#4D84AA", "#5B9965", "#61CEB9", "#34C1E2", "#80B79A"] # GT color + colors_orange = ["#DD5A37", "#D69E00", "#B75A39", "#FF6D00", "#DDB50E"] # Generation color + colors = colors_orange + if vis_mode == 'upper_body': # lower body taken fixed to input motion + colors[0] = colors_blue[0] + colors[1] = colors_blue[1] + elif vis_mode == 'gt': + colors = colors_blue + + frame_number = data.shape[0] + # print(dataset.shape) + + height_offset = MINS[1] + data[:, :, 1] -= height_offset + trajec = data[:, 0, [0, 2]] + + data[..., 0] -= data[:, 0:1, 0] + data[..., 2] -= data[:, 0:1, 2] + + # print(trajec.shape) + + def update(index): + # print(index) + ax.lines = [] + ax.collections = [] + ax.view_init(elev=120, azim=-90) + ax.dist = 7.5 + # ax = + plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1], + MAXS[2] - trajec[index, 1]) + + used_colors = colors_blue if index in gt_frames else colors + for i, (chain, color) in enumerate(zip(kinematic_tree, used_colors)): + if i < 5: + linewidth = 4.0 + else: + linewidth = 2.0 + ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth, + color=color) + # print(trajec[:index, 0].shape) + + plt.axis('off') + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_zticklabels([]) + + for i in tqdm(range(frame_number)): + update(i) + plt.savefig("temp/%06d.png"%(i)) + + plt.close() \ No newline at end of file diff --git a/motion/sample.py b/motion/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd904951e343553d66cc8fce72afb96bbf0d389 --- /dev/null +++ b/motion/sample.py @@ -0,0 +1,131 @@ +from argparse import Namespace +import torch +from motion.dataset.recover_joints import recover_from_ric +from motion.model.cfg_sampler import ClassifierFreeSampleModel +from motion.model_util import create_model_and_diffusion, load_model_wo_clip +import os +import numpy as np +from motion.dataset.recover_smr import * +import json +from motion.double_take import double_take + +class Predictor(object): + def __init__(self, **kargs): + self.path = kargs["path"] + self.handshake_size = 20 + self.blend_size = 10 + + args = Namespace() + with open(self.path["config"], 'r') as f: + params1 = json.load(f) + for key, value in params1.items(): + setattr(args, key, value) + + + mode = kargs.get("mode", "cadm") + if mode == "cadm": + 
args.arch = "refined_decoder" + args.encode_full = 2 + args.txt_tokens = 1 + args.model_path = self.path["cadm"] + args.rep = "smr" + elif mode == "cadm-augment": + args.arch = "refined_decoder" + args.encode_full = 2 + args.txt_tokens = 1 + args.model_path = self.path["cadm-augment"] + args.rep = "smr" + elif mode == "mdm": + args.arch = "trans_enc" + args.encode_full = 0 + args.txt_tokens = 0 + args.model_path = self.path["mdm"] + args.rep = "t2m" + + self.skip_steps = kargs.get("skip_steps", 0) + self.device = kargs.get("device", "cpu") + self.args = args + self.rep = args.rep + self.num_frames = args.num_frames + self.condition = kargs.get("condition", "text") + if self.condition == "uncond": + self.args.guidance_param = 0 + + if self.rep == "t2m": + extension = "" + elif self.rep == "smr": + extension = "_smr" + + self.mean = torch.from_numpy(np.load(os.path.join(self.path["dataset_dir"], 'Mean{}.npy'.format(extension)))).to(self.device) + self.std = torch.from_numpy(np.load(os.path.join(self.path["dataset_dir"], 'Std{}.npy'.format(extension)))).to(self.device) + + print(f"Loading checkpoints from...") + self.model, self.diffusion = create_model_and_diffusion(args, args.control_signal, self.path) + state_dict = torch.load(self.args.model_path, map_location='cpu') + try: + if self.args.ema: + print("EMA Checkpoints Loading.") + load_model_wo_clip(self.model, state_dict["ema"]) + else: + print("Normal Checkpoints Loading.") + load_model_wo_clip(self.model, state_dict["model"]) + except: + load_model_wo_clip(self.model, state_dict) + + if self.args.guidance_param != 1 and not self.args.unconstrained: + self.model = ClassifierFreeSampleModel(self.model) # wrapping model with the classifier-free sampler + self.model.to(self.device) + self.model.eval() # disable random masking + + def predict(self,prompt, num_repetitions=1, path=None): + double_split = prompt.split("|") + if len(double_split) > 1: + print("sample mode - double_take long motion") + sample, step_sizes = double_take(prompt, path, num_repetitions, self.model, self.diffusion, self.handshake_size, + self.blend_size, self.num_frames, self.args.guidance_param, self.device) + + sample = sample.permute(0, 2, 3, 1).float() + sample = sample * self.std + self.mean + if self.rep == "t2m": + sample = recover_from_ric(sample, 22) + sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1) + elif self.rep == "smr": + sample = sample.permute(0, 2, 3, 1) + else: + nframes = prompt.split(",")[0] + try: + nframes = int(nframes) + prompt = prompt.split(",")[1::] + prompt = ",".join(prompt) + except: + nframes = self.num_frames + + model_kwargs = {'y':{'text': str(prompt), 'lengths':nframes}} + if self.args.guidance_param != 1: + model_kwargs['y']['scale'] = torch.ones(num_repetitions, device=self.device) * self.args.guidance_param + + sample_fn = self.diffusion.p_sample_loop + sample = sample_fn( + self.model, + (num_repetitions, self.model.njoints, self.model.nfeats, nframes), + clip_denoised=False, + model_kwargs=model_kwargs, + skip_timesteps=self.skip_steps, # 0 is the default value - i.e. 
don't skip any step + init_image=None, + progress=True, + dump_steps=None, + noise=None, + const_noise=False + ) + sample = sample["output"] + sample = sample.permute(0, 2, 3, 1).float() + sample = sample * self.std + self.mean + + if self.rep == "t2m": + sample = recover_from_ric(sample, 22) + sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1) + elif self.rep == "smr": + sample = sample.permute(0, 2, 3, 1) + + all_motions = sample.permute(0, 3, 1, 2) + return all_motions \ No newline at end of file diff --git a/motion/sample_config.json b/motion/sample_config.json new file mode 100644 index 0000000000000000000000000000000000000000..cee0eb7ec962ba8ccc80f48fafa8795e8cf45500 --- /dev/null +++ b/motion/sample_config.json @@ -0,0 +1,23 @@ +{ + "local":false, + "num_frames": 196, + "ema": true, + "rep": "smr", + "dataset": "humanml", + "model_path": "/apdcephfs/private_kleinhe/mdm/llama_decoder_static_mask25/S100000_F0.2312_T0.5190.pth", + "control_signal": null, + "noise_schedule":"cosine", + "diffusion_steps":1000, + "sigma_small":true, + "layers":8, + "heads":4, + "ff_size":1024, + "latent_dim":512, + "cond_mask_prob":0.1, + "unconstrained":false, + "arch":"refined_decoder", + "encode_full":2, + "txt_tokens":1, + "guidance_param":2.5, + "frame_mask":0 + } \ No newline at end of file diff --git a/motion/visual_api.py b/motion/visual_api.py new file mode 100644 index 0000000000000000000000000000000000000000..d9f0c6e400c0c1b39f78487b2b933af058ff7c0c --- /dev/null +++ b/motion/visual_api.py @@ -0,0 +1,269 @@ +from torch import nn +import torch +import numpy as np +from SMPLX.visualize_joint2smpl.simplify_loc2rot import joints2smpl +from motion.hybrik_loc2rot import HybrIKJointsToRotmat +from SMPLX import smplx +from SMPLX.read_from_npy import npy2info, info2dict +from SMPLX.rotation_conversions import * +from motion.dataset.recover_smr import * +from motion.dataset.recover_joints import recover_from_ric as recover_joints +from motion.dataset.paramUtil import t2m_kinematic_chain +from motion.plot3d import plot_3d_motion +import os +import subprocess +import platform +from PIL import Image +from motion.sample import Predictor as mdm_predictor +from TADA.anime import Animation + +class Visualize(nn.Module): + def __init__(self, **kargs): + super(Visualize, self).__init__() + self.mode = kargs.get("mode", "cadm") + if self.mode in ["mdm", "cadm", "cadm-augment"]: + self.predictor = mdm_predictor(**kargs) + self.rep = self.predictor.rep + self.smpl_path = kargs.get("smpl_path") + self.device = kargs.get("device", "cpu") + self.rotate = kargs.get("rotate", 0) + self.pose_generator = HybrIKJointsToRotmat() + self.path = kargs["path"] + + self.tada_base = kargs.get("tada_base", None) + self.tada_role = kargs.get("tada_role", None) + + if self.tada_base is not None and self.tada_role is not None: + self.anime = Animation(self.tada_role, self.tada_base, self.device) + self.face = None + else: + self.face = np.load(os.path.join(self.path["dataset_dir"], "smplh.faces")) + self.anime = None + + def fit2smpl(self, motion, mode="fast"): + print(">>>>>>>>>>>>>>> fit joints to smpl >>>>>>>>>>>>>>>>>>>>") + if mode == "slow": + frames = motion.shape[0] + j2s = joints2smpl(num_frames=frames, device=self.device, model_path=self.smpl_path, json_dict=self.path) + motion_tensor, translation = j2s.joint2smpl(motion) + else: + translation = motion[:, 0:1, :] - motion[0, 0:1, :] + motion = self.pose_generator(motion) + motion = torch.from_numpy(motion) + hand = 
+            motion = torch.cat([motion, hand], dim=1)
+            motion_tensor = matrix_to_axis_angle(motion)
+            motion_tensor = motion_tensor.numpy()
+
+        return motion_tensor, translation
+
+    def predict(self, sentence, path, render_mode="pyrender", joint_path=None, smpl_path=None):
+        if self.mode == "pose":
+            motion_tensor = np.load(path)
+            if render_mode == "joints":
+                _, joints = self.get_mesh(motion_tensor)
+                motion_tensor = joints
+
+        elif self.mode == "joints":
+            joints = np.load(path)
+            if render_mode == "joints":
+                motion_tensor = joints
+            else:
+                motion_tensor, translation = self.fit2smpl(joints, render_mode.split("_")[-1])
+                motion_tensor = np.concatenate([motion_tensor, translation], axis=1)
+                motion_tensor = motion_tensor.reshape(motion_tensor.shape[0], -1)
+        elif self.mode in ["mdm", "cadm", "cadm-augment"]:
+            motion_tensor = self.predictor.predict(sentence, 1, path)
+            if self.rep == "t2m":
+                motion_tensor = motion_tensor[0].detach().cpu().numpy()  #### [nframes, 263]
+
+                if joint_path is not None:
+                    np.save(joint_path, motion_tensor)
+
+                if render_mode == "joints":
+                    motion_tensor = motion_tensor
+                else:
+                    motion_tensor, translation = self.fit2smpl(motion_tensor, render_mode.split("_")[-1])
+                    motion_tensor = np.concatenate([motion_tensor, translation], axis=1)
+                    motion_tensor = motion_tensor.reshape(motion_tensor.shape[0], -1)
+
+                    if smpl_path is not None:
+                        np.save(smpl_path, motion_tensor)
+
+            elif self.rep == "smr":
+                motion_tensor = motion_tensor[0][0].detach().cpu().numpy()
+                joints = recover_from_ric(motion_tensor, 22)
+
+                if joint_path is not None:
+                    np.save(joint_path, joints)
+
+                if render_mode == "joints":
+                    motion_tensor = joints
+                else:
+                    pose = recover_pose_from_smr(motion_tensor, 22)
+                    pose = pose.reshape(pose.shape[0], -1, 3)
+                    motion_tensor, translation = self.fit2smpl(joints, render_mode.split("_")[-1])
+                    motion_tensor = np.concatenate([motion_tensor, translation], axis=1)
+                    motion_tensor = motion_tensor.reshape(motion_tensor.shape[0], -1, 3)
+                    # overwrite neck, head and wrist joints (SMPL indices 12, 15, 20, 21) with rotations recovered from the smr representation
+                    replace = [12, 15, 20, 21]
+                    motion_tensor[:, replace, :] = pose[:, replace, :]
+                    motion_tensor = motion_tensor.reshape(motion_tensor.shape[0], -1)
+
+                    if smpl_path is not None:
+                        np.save(smpl_path, motion_tensor)
+
+        return motion_tensor.astype(np.float32)
+
+    def joints_process(self, joints, text, width=1024, height=1024):
+        os.makedirs("temp", exist_ok=True)
+        plot_3d_motion(t2m_kinematic_chain, joints, text, figsize=(width/100, height/100))
+        files = os.listdir("temp")
+        files = sorted(files)
+        pics = []
+        for i in range(len(files)):
+            pic = Image.open(os.path.join("temp", files[i]))
+            pic = np.asarray(pic)
+            pics.append(pic.copy())
+
+        cmd = "rm -r temp"
+        subprocess.call(cmd, shell=platform.system() != 'Windows')
+        pics = np.stack(pics, axis=0)
+        return pics
+
+    def pyrender_process(self, vertices, height=1024, weight=1024):
+        import trimesh
+        from trimesh import Trimesh
+        import pyrender
+        from pyrender.constants import RenderFlags
+        import os
+        os.environ['PYOPENGL_PLATFORM'] = "egl"
+        from shapely import geometry
+        from tqdm import tqdm
+
+        faces = self.face
+
+        vertices = vertices.astype(np.float32)
+        MINS = np.min(np.min(vertices, axis=0), axis=0)
+        MAXS = np.max(np.max(vertices, axis=0), axis=0)
+
+        #################### position initial at zero point
+        vertices[:, :, 0] -= (MAXS + MINS)[0] / 2
+        vertices[:, :, 2] -= (MAXS + MINS)[2] / 2
+
+        MINS = np.min(np.min(vertices, axis=0), axis=0)
+        MAXS = np.max(np.max(vertices, axis=0), axis=0)
+
+        pics = []
+
+        ############### ground initial ###########
+        minx = MINS[0] - 0.5
+        maxx = MAXS[0] + 0.5
+        minz = MINS[2] - 0.5
+        maxz = MAXS[2] + 0.5
+        polygon = geometry.Polygon([[minx, minz], [minx, maxz], [maxx, maxz], [maxx, minz]])
+        polygon_mesh = trimesh.creation.extrude_polygon(polygon, 1e-5)
+        polygon_mesh.visual.face_colors = [0, 0, 0, 0.21]
+        polygon_render = pyrender.Mesh.from_trimesh(polygon_mesh, smooth=False)
+
+        r = pyrender.OffscreenRenderer(weight, height)
+
+        for i in tqdm(range(vertices.shape[0])):
+            end_color = np.array([30, 128, 255]) / 255.0
+
+            bg_color = [1, 1, 1, 0.8]
+            scene = pyrender.Scene(bg_color=bg_color, ambient_light=(0.4, 0.4, 0.4))
+
+            if self.anime is None:
+                mesh = Trimesh(vertices=vertices[i, :, :].tolist(), faces=faces)
+                base_color = end_color.tolist()
+                material = pyrender.MetallicRoughnessMaterial(
+                    metallicFactor=0.7, roughnessFactor=0.7,
+                    alphaMode='OPAQUE',
+                    baseColorFactor=base_color
+                )
+                mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
+            else:
+                mesh = Trimesh(vertices=vertices[i, :, :].tolist(), faces=faces, visual=self.anime.trimesh_visual, process=False)
+                mesh = pyrender.Mesh.from_trimesh(mesh, smooth=True, material=None)
+
+            scene.add(mesh)
+
+            ########################### ground ##################
+            c = np.pi / 2
+            scene.add(polygon_render, pose=np.array([[ 1, 0, 0, 0],
+                                                     [ 0, np.cos(c), -np.sin(c), MINS[1]],
+                                                     [ 0, np.sin(c), np.cos(c), 0],
+                                                     [ 0, 0, 0, 1]]))
+
+            ################ light ############
+            light = pyrender.DirectionalLight(color=[1,1,1], intensity=300)
+            light_pose = np.eye(4)
+            light_pose[:3, 3] = [0, -1, 1]
+            scene.add(light, pose=light_pose.copy())
+            light_pose[:3, 3] = [0, 1, 1]
+            scene.add(light, pose=light_pose.copy())
+            light_pose[:3, 3] = [1, 1, 2]
+            scene.add(light, pose=light_pose.copy())
+
+            ################ camera ##############
+            camera = pyrender.PerspectiveCamera(yfov=(np.pi / 3.0))
+            c = -np.pi / 6
+            scene.add(camera, pose=[[ 1, 0, 0, (minx+maxx)],
+                                    [ 0, np.cos(c), -np.sin(c), 2.5],
+                                    [ 0, np.sin(c), np.cos(c), max(4, minz+(1.5-MINS[1])*2, (maxx-minx))],
+                                    [ 0, 0, 0, 1]
+                                    ])
+
+            pic, _ = r.render(scene, flags=RenderFlags.RGBA)
+            pics.append(pic)
+
+        pics = np.stack(pics, axis=0)
+        return pics
+
+    @torch.no_grad()
+    def get_mesh(self, motions):
+        if self.anime is not None:
+            vertices, faces = self.anime.forward_mdm(motions)
+            joints = vertices
+            self.face = faces
+        else:
+            motions, trans, gender, betas = npy2info(motions, 10)
+
+            betas = None
+            gender = "neutral"
+
+            if motions.shape[1] == 72:
+                mode = "smpl"
+            elif motions.shape[1] == 156:
+                mode = "smplh"
+            elif motions.shape[1] == 165:
+                motions = np.concatenate([motions[:, :66], motions[:, 75::]], axis=1)
+                mode = "smplh"
+
+            if self.rotate != 0:
+                motions = motions.reshape(motions.shape[0], -1, 3)
+                motions = torch.from_numpy(motions).float()
+                first_frame_root_pose_matrix = axis_angle_to_matrix(motions[0][0])
+                all_root_poses_matrix = axis_angle_to_matrix(motions[:, 0, :])
+                aligned_root_poses_matrix = torch.matmul(torch.transpose(first_frame_root_pose_matrix, 0, 1),
+                                                         all_root_poses_matrix)
+                motions[:, 0, :] = matrix_to_axis_angle(aligned_root_poses_matrix)
+                motions = motions.reshape(motions.shape[0], -1)
+                motions = motions.numpy()
+
+            print("Visualize Mode -> ", mode)
+            model = smplx.create(self.smpl_path, model_type=mode,
+                                 gender=gender, use_face_contour=True,
+                                 num_betas=10,
+                                 num_expression_coeffs=10,
+                                 ext="npz", use_pca=False, batch_size=motions.shape[0])
+            model = model.eval().to(self.device)
+
+            inputs = info2dict(motions, trans, betas, mode, self.device)
+
+            output = model(**inputs)
+
+            vertices = output.vertices.cpu().numpy()
+            joints = output.joints.cpu().numpy()
+
+        return vertices, joints
\ No newline at end of file
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..88e73d28ce9e8393c3e5f83bc927f7b660baee59
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1 @@
+freeglut3-dev
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5820aed1c41d439b28ee85be7995ebb16dd578be
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,75 @@
+scipy==1.11.3
+chumpy==0.70
+pyrender==0.1.45
+aiofiles==23.2.1
+altair==5.1.2
+annotated-types==0.6.0
+anyio==3.7.1
+attrs==23.1.0
+certifi==2023.7.22
+charset-normalizer==3.3.0
+click==8.1.7
+cycler==0.12.1
+decorator==4.4.2
+fastapi==0.103.2
+ffmpy==0.3.1
+filelock==3.12.4
+fsspec==2023.9.2
+ftfy==6.1.1
+gradio==3.47.1
+gradio_client==0.6.0
+h11==0.14.0
+h5py==3.10.0
+httpcore==0.18.0
+httpx==0.25.0
+huggingface-hub==0.18.0
+idna==3.4
+imageio==2.31.5
+imageio-ffmpeg==0.4.9
+importlib-resources==6.1.0
+Jinja2==3.1.2
+jsonschema==4.19.1
+jsonschema-specifications==2023.7.1
+kiwisolver==1.4.5
+MarkupSafe==2.1.3
+matplotlib==3.1.3
+moviepy==1.0.3
+mpmath==1.3.0
+networkx==3.1
+orjson==3.9.9
+packaging==23.2
+pandas==2.1.1
+Pillow==10.0.1
+proglog==0.1.10
+pydantic==2.4.2
+pydantic_core==2.10.1
+pydub==0.25.1
+pyparsing==3.1.1
+python-dateutil==2.8.2
+python-multipart==0.0.6
+pytz==2023.3.post1
+PyYAML==6.0.1
+referencing==0.30.2
+regex==2023.10.3
+requests==2.31.0
+rpds-py==0.10.6
+semantic-version==2.10.0
+six==1.16.0
+sniffio==1.3.0
+starlette==0.27.0
+sympy==1.12
+toolz==0.12.0
+torch==2.1.0
+torchvision==0.16.0
+tqdm==4.66.1
+triton==2.1.0
+typing_extensions==4.8.0
+tzdata==2023.3
+urllib3==2.0.6
+uvicorn==0.23.2
+wcwidth==0.2.8
+websockets==11.0.3
+numpy==1.23.0
+shapely==2.0.2
+mapbox_earcut==1.0.1
+gdown==4.7.1
\ No newline at end of file
diff --git a/scripts/prepare.py b/scripts/prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bfb8d04d1580ea5e7420c99cefdedee018bba58
--- /dev/null
+++ b/scripts/prepare.py
@@ -0,0 +1,7 @@
+from huggingface_hub import snapshot_download
+def prepare():
+    REPO_ID = 'Kleinhe/CAMD'
+    snapshot_download(repo_id=REPO_ID, local_dir='./', local_dir_use_symlinks=False)
+
+if __name__ == "__main__":
+    prepare()
\ No newline at end of file
diff --git a/scripts/prepare.sh b/scripts/prepare.sh
new file mode 100644
index 0000000000000000000000000000000000000000..17451255afbbd5aa9e1aad79ff6961ba081e4bf0
--- /dev/null
+++ b/scripts/prepare.sh
@@ -0,0 +1,3 @@
+pip install huggingface_hub
+git lfs install
+python scripts/prepare.py
\ No newline at end of file
diff --git a/scripts/tada_goole.sh b/scripts/tada_goole.sh
new file mode 100644
index 0000000000000000000000000000000000000000..17b51ecbde6ad6f933ee1c0a2b44048c6e0ce688
--- /dev/null
+++ b/scripts/tada_goole.sh
@@ -0,0 +1,2 @@
+gdown https://drive.google.com/uc?id=1O-2pfMz-6b5fFk2Ju1GsLzsg7HZkTTeQ
+unzip tada-extend.zip
diff --git a/scripts/tada_process.sh b/scripts/tada_process.sh
new file mode 100644
index 0000000000000000000000000000000000000000..af83de33056c73ec1f0155165a25c8fc788d937f
--- /dev/null
+++ b/scripts/tada_process.sh
@@ -0,0 +1,7 @@
+unzip Avatar-100.zip
+unzip tada_extra_data.zip
+rm Avatar-100.zip
+rm tada_extra_data.zip
+mv Avatar-100/MESH data/
+rm -r Avatar-100
+mv data tada-extend
\ No newline at end of file