# Animation utilities for TADA avatars: load an optimized SMPL-X checkpoint
# (shape, expression, per-vertex offsets, albedo texture) and pose it with
# MDM-style motion sequences to produce textured, posed meshes.
import os

# PyOpenGL chooses its windowing backend the moment it is first imported, and
# pyrender imports OpenGL at import time — so the platform MUST be selected
# before `import pyrender` below.  In the original ordering this assignment
# came after the import and therefore had no effect.
os.environ['PYOPENGL_PLATFORM'] = "egl"

import argparse
import json
import pickle as pkl
import random

import imageio
import moviepy.editor as mpy
import numpy as np
import pyrender
import torch
import trimesh
from PIL import Image
from shapely import geometry
from tqdm import tqdm

from TADA import smplx
from TADA.lib.common.lbs import warp_points
from TADA.lib.common.obj import compute_normal
from TADA.lib.common.remesh import subdivide_inorder
from TADA.lib.common.utils import SMPLXSeg
def build_new_mesh(v, f, vt, ft):
    """Re-index mesh vertices onto the texture-map (UV) topology.

    A geometric vertex that appears with several UV coordinates (a texture
    seam) is duplicated so that every texture vertex carries its own 3D
    position.  Texture vertices never referenced by ``ft`` keep position
    ``(0, 0, 0)``.

    Args:
        v: vertex positions, shape ``(batch, n_verts, 3)`` (batch is e.g.
            animation frames).
        f: triangle faces indexing into ``v``, shape ``(n_faces, 3)``.
        vt: UV coordinates, shape ``(n_uv, 2)`` — only its length is used.
        ft: triangle faces indexing into ``vt``, same shape as ``f``.

    Returns:
        Tuple ``(new_v, f_new)``: positions of shape ``(batch, n_uv, 3)``
        and the texture faces ``ft``.
    """
    # Map each geometric-vertex index to the set of UV indices it is paired
    # with across all faces.  Using a set instead of a list avoids the O(k)
    # membership test per element that made the original loop quadratic on
    # seam-heavy meshes.
    correspondences = {}
    for geo_idx, uv_idx in zip(f.flatten(), ft.flatten()):
        correspondences.setdefault(geo_idx, set()).add(uv_idx)
    # Scatter each vertex position into all of its texture-vertex slots.
    new_v = np.zeros((v.shape[0], vt.shape[0], 3))
    for old_index, new_indices in correspondences.items():
        for new_index in new_indices:
            new_v[:, new_index] = v[:, old_index]
    # The texture faces already index the new vertex array.
    return new_v, ft
class Animation:
    """Pose a TADA-optimized SMPL-X avatar with motion (e.g. MDM) sequences.

    Loads the precomputed subdivided ("dense") body topology, the optimized
    avatar parameters from a checkpoint, and the albedo texture, then exposes
    :meth:`forward_mdm` to turn a motion sequence into per-frame posed
    vertices in texture-map topology.
    """

    def __init__(self, ckpt_path, workspace_dir, device="cuda"):
        """
        Args:
            ckpt_path: subdirectory name under ``<workspace_dir>/MESH`` that
                contains ``params.pt`` and ``mesh_albedo.png``.
            workspace_dir: root directory holding ``init_body/data.npz``,
                the SMPL-X model file, and checkpoint outputs.
            device: torch device string for the body model and all tensors.
        """
        self.device = device
        self.SMPLXSeg = SMPLXSeg(workspace_dir)
        # Precomputed dense (subdivided) topology: faces, LBS skinning
        # weights, subdivision bookkeeping ('unique'), and UV data (vt/ft).
        init_data = np.load(os.path.join(workspace_dir, "init_body/data.npz"))
        self.dense_faces = torch.as_tensor(init_data['dense_faces'], device=self.device)
        self.dense_lbs_weights = torch.as_tensor(init_data['dense_lbs_weights'], device=self.device)
        self.unique = init_data['unique']
        self.vt = init_data['vt']
        self.ft = init_data['ft']
        model_params = dict(
            model_path=os.path.join(workspace_dir, "smplx/SMPLX_NEUTRAL_2020.npz"),
            model_type='smplx',
            create_global_orient=True,
            create_body_pose=True,
            create_betas=True,
            create_left_hand_pose=True,
            create_right_hand_pose=True,
            create_jaw_pose=True,
            create_leye_pose=True,
            create_reye_pose=True,
            create_expression=True,
            create_transl=False,
            use_pca=False,
            flat_hand_mean=False,
            num_betas=300,
            num_expression_coeffs=100,
            num_pca_comps=12,
            dtype=torch.float32,
            batch_size=1,
        )
        self.body_model = smplx.create(**model_params).to(device=self.device)
        self.smplx_face = self.body_model.faces.astype(np.int32)
        ckpt_file = os.path.join(workspace_dir, "MESH", ckpt_path, "params.pt")
        albedo_path = os.path.join(workspace_dir, "MESH", ckpt_path, "mesh_albedo.png")
        self.load_ckpt_data(ckpt_file, albedo_path)

    def load_ckpt_data(self, ckpt_file, albedo_path):
        """Load optimized avatar parameters and the albedo texture.

        Args:
            ckpt_file: path to a ``torch.load``-able checkpoint with keys
                ``betas``, ``v_offsets`` and optionally ``expression`` /
                ``jaw_pose``.  NOTE: ``torch.load`` unpickles — only load
                trusted checkpoints.
            albedo_path: path to the RGB albedo texture image.
        """
        model_data = torch.load(ckpt_file, map_location=self.device)
        self.expression = model_data["expression"] if "expression" in model_data else None
        self.jaw_pose = model_data["jaw_pose"] if "jaw_pose" in model_data else None
        self.betas = model_data['betas']
        self.v_offsets = model_data['v_offsets']
        # Eyeballs and hands keep the template geometry: zero their offsets.
        self.v_offsets[self.SMPLXSeg.eyeball_ids] = 0.
        self.v_offsets[self.SMPLXSeg.hands_ids] = 0.
        # Convert UVs to trimesh's texture convention (flip the V axis).
        vt = self.vt.copy()
        vt[:, 1] = 1 - vt[:, 1]
        albedo = Image.open(albedo_path)
        # Keep a CHW float copy of the albedo in [0, 1] alongside the
        # trimesh visual (which wants the PIL image directly).
        self.raw_albedo = torch.from_numpy(np.array(albedo))
        self.raw_albedo = self.raw_albedo / 255.0
        self.raw_albedo = self.raw_albedo.permute(2, 0, 1)
        self.trimesh_visual = trimesh.visual.TextureVisuals(
            uv=vt,
            image=albedo,
            material=trimesh.visual.texture.SimpleMaterial(
                image=albedo,
                diffuse=[255, 255, 255, 255],
                ambient=[255, 255, 255, 255],
                specular=[0, 0, 0, 255],
                glossiness=0)
        )

    def forward_mdm(self, motion):
        """Pose the avatar for every frame of a motion sequence.

        Args:
            motion: either a mapping with ``"poses"`` (per-frame axis-angle
                joint rotations) and ``"trans"`` (per-frame root translation)
                as numpy arrays, or a flat array of shape ``(frames, D)``
                whose last 3 columns are the translation and the rest the
                flattened poses.

        Returns:
            Tuple ``(new_scan_v_posed, new_face)``: float32 per-frame posed
            vertices in texture-map topology and the texture faces.
        """
        try:
            mdm_body_pose = motion["poses"]
            translate = torch.from_numpy(motion["trans"])
        except (KeyError, IndexError, TypeError):
            # Not a mapping — fall back to the flat array layout.  The
            # original bare `except:` also swallowed KeyboardInterrupt etc.
            translate = torch.from_numpy(motion[:, -3:])
            mdm_body_pose = motion[:, :-3]
        mdm_body_pose = mdm_body_pose.reshape(mdm_body_pose.shape[0], -1, 3)
        translate = translate.to(self.device)
        scan_v_posed = []
        for pose, t in tqdm(zip(mdm_body_pose, translate)):
            # Joint 0 is the global orientation; joints 1..21 the body pose.
            body_pose = torch.as_tensor(pose[None, 1:22, :], device=self.device)
            global_orient = torch.as_tensor(pose[None, :1, :], device=self.device)
            output = self.body_model(
                betas=self.betas,
                global_orient=global_orient,
                jaw_pose=self.jaw_pose,
                body_pose=body_pose,
                expression=self.expression,
                return_verts=True
            )
            v_cano = output.v_posed[0]
            # Re-mesh: subdivide the canonical body to the dense topology.
            v_cano_dense = subdivide_inorder(v_cano, self.smplx_face[self.SMPLXSeg.remesh_mask], self.unique).squeeze(0)
            # Apply the learned per-vertex offsets along vertex normals.
            vn = compute_normal(v_cano_dense, self.dense_faces)[0]
            v_cano_dense += self.v_offsets * vn
            # Linear blend skinning with the dense skinning weights.
            v_posed_dense = warp_points(v_cano_dense, self.dense_lbs_weights, output.joints_transform[:, :55])
            # Translate relative to the first frame so the sequence starts
            # at the origin.
            v_posed_dense += t - translate[0]
            scan_v_posed.append(v_posed_dense)
        scan_v_posed = torch.cat(scan_v_posed).detach().cpu().numpy()
        # Re-index onto the texture-map topology for textured rendering.
        new_scan_v_posed, new_face = build_new_mesh(scan_v_posed, self.dense_faces, self.vt, self.ft)
        new_scan_v_posed = new_scan_v_posed.astype(np.float32)
        return new_scan_v_posed, new_face
|