# AiOS — detrsmpl/data/datasets/human_image_smplx_dataset.py
# (last update by ttxskk, commit d7e58f0)
import os
import os.path
import pickle
from collections import OrderedDict
from typing import List, Optional, Union
import numpy as np
import torch
from detrsmpl.core.conventions.keypoints_mapping import (
get_keypoint_idx,
get_keypoint_idxs_by_part,
)
from detrsmpl.core.evaluation import fg_vertices_to_mesh_distance
from detrsmpl.utils.transforms import aa_to_rotmat
from .builder import DATASETS
from .human_image_dataset import HumanImageDataset
@DATASETS.register_module()
class HumanImageSMPLXDataset(HumanImageDataset):
    """Image dataset with SMPL-X (expressive body model) annotations.

    Extends :class:`HumanImageDataset` by:

    * loading per-sample SMPL-X parameters (global orient, body / hand /
      jaw pose, betas, expression) in :meth:`prepare_raw_data`, together
      with ``has_smplx_*`` 0/1 flags marking which parameters exist;
    * optionally loading face / hand vertex-id subsets so vertex errors
      can be reported per body part;
    * SMPL-X-aware evaluation, including the face-scan ``3DRMSE`` metric.
    """
    # Metric names accepted by ``evaluate``.
    ALLOWED_METRICS = {
        'mpjpe', 'pa-mpjpe', 'pve', '3dpck', 'pa-3dpck', '3dauc', 'pa-3dauc',
        '3DRMSE', 'pa-pve'
    }
    def __init__(
        self,
        data_prefix: str,
        pipeline: list,
        dataset_name: str,
        body_model: Optional[Union[dict, None]] = None,
        ann_file: Optional[Union[str, None]] = None,
        convention: Optional[str] = 'human_data',
        cache_data_path: Optional[Union[str, None]] = None,
        test_mode: Optional[bool] = False,
        num_betas: Optional[int] = 10,
        num_expression: Optional[int] = 10,
        face_vertex_ids_path: Optional[str] = None,
        hand_vertex_ids_path: Optional[str] = None,
    ):
        """
        Args:
            data_prefix: Root path of the dataset files (base class).
            pipeline: Data-transform pipeline config (base class).
            dataset_name: Dataset name (e.g. ``'EHF'``, ``'pw3d'``,
                ``'stirling'``); several evaluation branches key off it.
            body_model: Config of the body model used at evaluation time
                to regress meshes/joints (handled by the base class).
            ann_file: Annotation file name (base class).
            convention: Keypoint convention of the stored annotations.
            cache_data_path: Optional cache file path (base class).
            test_mode: Whether the dataset is used for testing.
            num_betas: Number of SMPL-X shape coefficients; also the size
                of the zero placeholder when ``betas`` is missing.
            num_expression: Number of expression coefficients; also the
                size of the zero placeholder when ``expression`` is missing.
            face_vertex_ids_path: Optional ``.npy`` file with the vertex
                ids of the face region (enables ``body_part='face'``).
            hand_vertex_ids_path: Optional pickle file containing
                ``{'left_hand': ids, 'right_hand': ids}`` (enables the
                per-hand vertex metrics).
        """
        super().__init__(data_prefix, pipeline, dataset_name, body_model,
                         ann_file, convention, cache_data_path, test_mode)
        self.num_betas = num_betas
        self.num_expression = num_expression
        # The vertex-id subsets are loaded only when the files exist;
        # otherwise the attributes are simply absent, and requesting a
        # per-part metric later would raise AttributeError.
        if face_vertex_ids_path is not None:
            if os.path.exists(face_vertex_ids_path):
                self.face_vertex_ids = np.load(face_vertex_ids_path).astype(
                    np.int32)
        if hand_vertex_ids_path is not None:
            if os.path.exists(hand_vertex_ids_path):
                # NOTE(review): pickle.load is unsafe on untrusted input;
                # this path is assumed to point at a trusted asset file.
                with open(hand_vertex_ids_path, 'rb') as f:
                    vertex_idxs_data = pickle.load(f, encoding='latin1')
                self.left_hand_vertex_ids = vertex_idxs_data['left_hand']
                self.right_hand_vertex_ids = vertex_idxs_data['right_hand']
    def prepare_raw_data(self, idx: int):
        """Get item from self.human_data and attach SMPL-X parameters.

        For each SMPL-X field, writes ``info['smplx_<field>']`` plus an
        ``info['has_smplx_<field>']`` 0/1 flag. Missing fields receive a
        zero placeholder of the canonical axis-angle shape (body pose
        21x3, each hand 15x3, jaw/global orient 3) or coefficient length
        (``num_betas`` / ``num_expression``).
        """
        info = super().prepare_raw_data(idx)
        if self.cache_reader is not None:
            # When reading through the cache, load the slice containing
            # this sample and convert the global index to a slice-local one.
            self.human_data = self.cache_reader.get_item(idx)
            idx = idx % self.cache_reader.slice_size
        if 'smplx' in self.human_data:
            smplx_dict = self.human_data['smplx']
            info['has_smplx'] = 1
        else:
            smplx_dict = {}
            info['has_smplx'] = 0
        if 'global_orient' in smplx_dict:
            info['smplx_global_orient'] = smplx_dict['global_orient'][idx]
            info['has_smplx_global_orient'] = 1
        else:
            info['smplx_global_orient'] = np.zeros((3), dtype=np.float32)
            info['has_smplx_global_orient'] = 0
        if 'body_pose' in smplx_dict:
            info['smplx_body_pose'] = smplx_dict['body_pose'][idx]
            info['has_smplx_body_pose'] = 1
        else:
            info['smplx_body_pose'] = np.zeros((21, 3), dtype=np.float32)
            info['has_smplx_body_pose'] = 0
        if 'right_hand_pose' in smplx_dict:
            info['smplx_right_hand_pose'] = smplx_dict['right_hand_pose'][idx]
            info['has_smplx_right_hand_pose'] = 1
        else:
            info['smplx_right_hand_pose'] = np.zeros((15, 3), dtype=np.float32)
            info['has_smplx_right_hand_pose'] = 0
        if 'left_hand_pose' in smplx_dict:
            info['smplx_left_hand_pose'] = smplx_dict['left_hand_pose'][idx]
            info['has_smplx_left_hand_pose'] = 1
        else:
            info['smplx_left_hand_pose'] = np.zeros((15, 3), dtype=np.float32)
            info['has_smplx_left_hand_pose'] = 0
        if 'jaw_pose' in smplx_dict:
            info['smplx_jaw_pose'] = smplx_dict['jaw_pose'][idx]
            info['has_smplx_jaw_pose'] = 1
        else:
            info['smplx_jaw_pose'] = np.zeros((3), dtype=np.float32)
            info['has_smplx_jaw_pose'] = 0
        if 'betas' in smplx_dict:
            info['smplx_betas'] = smplx_dict['betas'][idx]
            info['has_smplx_betas'] = 1
        else:
            info['smplx_betas'] = np.zeros((self.num_betas), dtype=np.float32)
            info['has_smplx_betas'] = 0
        if 'expression' in smplx_dict:
            info['smplx_expression'] = smplx_dict['expression'][idx]
            info['has_smplx_expression'] = 1
        else:
            info['smplx_expression'] = np.zeros((self.num_expression),
                                                dtype=np.float32)
            info['has_smplx_expression'] = 0
        return info
    def _parse_result(self, res, mode='keypoint', body_part=''):
        """Align predictions with ground truth for metric computation.

        Args:
            res: Result dict; uses ``res['vertices']`` and/or
                ``res['keypoints']`` depending on ``mode``.
            mode: ``'vertice'`` for mesh-vertex metrics (PVE etc.),
                ``'keypoint'`` for joint metrics (MPJPE etc.).
            body_part: Optional sub-part selector (``'right_hand'``,
                ``'left_hand'``, ``'face'``, ``'body'``, ``'J14'``).

        Returns:
            Tuple ``(pred, gt, gt_mask)`` of aligned arrays; positions
            are converted to millimeters where applicable.
        """
        if mode == 'vertice':
            # pred
            pred_vertices = res['vertices'] * 1000.
            # gt
            if 'vertices' in self.human_data:  # stirling or ehf
                gt_vertices = self.human_data['vertices'].copy()
                if self.dataset_name == 'EHF':
                    # EHF GT scans are stored in meters -> convert to mm.
                    gt_vertices = gt_vertices * 1000.
            else:
                # No stored GT mesh: regress it from the SMPL-X parameters.
                gt_param_dict = self.human_data['smplx'].copy()
                for key, value in gt_param_dict.items():
                    new_value = torch.FloatTensor(value)
                    # Pose params whose second-to-last dim is not 3 are
                    # presumably axis-angle; convert them to rotation
                    # matrices for the body model. TODO(review): confirm
                    # this shape heuristic against the stored format.
                    if ('pose' in key or key
                            == 'global_orient') and value.shape[-2] != 3:
                        new_value = aa_to_rotmat(new_value)
                    gt_param_dict[key] = new_value
                gt_output = self.body_model(**gt_param_dict)
                gt_vertices = gt_output['vertices'].detach().cpu().numpy(
                ) * 1000.
            # Optionally restrict both meshes to a body-part vertex subset.
            if body_part == 'right_hand':
                pred_vertices = pred_vertices[:, self.right_hand_vertex_ids]
                gt_vertices = gt_vertices[:, self.right_hand_vertex_ids]
            elif body_part == 'left_hand':
                pred_vertices = pred_vertices[:, self.left_hand_vertex_ids]
                gt_vertices = gt_vertices[:, self.left_hand_vertex_ids]
            elif body_part == 'face':
                pred_vertices = pred_vertices[:, self.face_vertex_ids]
                gt_vertices = gt_vertices[:, self.face_vertex_ids]
            # All GT vertices are considered valid.
            gt_mask = np.ones(gt_vertices.shape[:-1])
            assert len(pred_vertices) == self.num_data
            return pred_vertices, gt_vertices, gt_mask
        elif mode == 'keypoint':
            pred_keypoints3d = res['keypoints']
            assert len(pred_keypoints3d) == self.num_data
            if self.dataset_name in {'pw3d', '3DPW', '3dpw'}:
                # 3DPW: regress GT joints from per-sample SMPL parameters,
                # using gendered body models (0 = male, 1 = otherwise).
                betas = []
                body_pose = []
                global_orient = []
                gender = []
                smpl_dict = self.human_data['smpl']
                for idx in range(self.num_data):
                    betas.append(smpl_dict['betas'][idx])
                    body_pose.append(smpl_dict['body_pose'][idx])
                    global_orient.append(smpl_dict['global_orient'][idx])
                    if self.human_data['meta']['gender'][idx] == 'm':
                        gender.append(0)
                    else:
                        gender.append(1)
                betas = torch.FloatTensor(betas)
                # SMPL body pose: 23 joints x 3 axis-angle = 69 values.
                body_pose = torch.FloatTensor(body_pose).view(-1, 69)
                global_orient = torch.FloatTensor(global_orient)
                gender = torch.Tensor(gender)
                gt_output = self.body_model(betas=betas,
                                            body_pose=body_pose,
                                            global_orient=global_orient,
                                            gender=gender)
                gt_keypoints3d = gt_output['joints'].detach().cpu().numpy()
                gt_keypoints3d_mask = np.ones(
                    (len(pred_keypoints3d), gt_keypoints3d.shape[1]))
            elif self.dataset_name == 'EHF':
                # EHF: regress joints directly from the GT scan vertices.
                gt_vertices = self.human_data['vertices'].copy()
                if body_part == 'J14':
                    # 14-joint evaluation via the model's joints_regressor.
                    gt_keypoints3d = torch.einsum('bik,ji->bjk', [
                        torch.from_numpy(gt_vertices).float(),
                        self.body_model.joints_regressor
                    ]).numpy()
                    pred_vertices = res['vertices']
                    pred_keypoints3d = torch.einsum('bik,ji->bjk', [
                        torch.from_numpy(pred_vertices).float(),
                        self.body_model.joints_regressor
                    ]).numpy()
                    gt_keypoints3d_mask = np.ones(
                        (len(pred_keypoints3d), gt_keypoints3d.shape[1]))
                else:
                    # Full-convention evaluation via the raw J_regressor,
                    # plus extra landmark joints taken straight from mesh
                    # vertices. NOTE(review): the hard-coded indices below
                    # are presumably SMPL-X vertex ids for face/hand
                    # landmarks — confirm against the model's convention.
                    gt_keypoints3d = torch.einsum('bik,ji->bjk', [
                        torch.from_numpy(gt_vertices).float(),
                        self.body_model.J_regressor
                    ]).numpy()
                    extra_joints_idxs = np.array([
                        9120, 9929, 9448, 616, 6, 5770, 5780, 8846, 8463, 8474,
                        8635, 5361, 4933, 5058, 5169, 5286, 8079, 7669, 7794,
                        7905, 8022
                    ])
                    gt_keypoints3d = np.concatenate(
                        (gt_keypoints3d, gt_vertices[:, extra_joints_idxs]),
                        axis=1)
                    pred_vertices = res['vertices']
                    pred_keypoints3d = torch.einsum('bik,ji->bjk', [
                        torch.from_numpy(pred_vertices).float(),
                        self.body_model.J_regressor
                    ]).numpy()
                    pred_keypoints3d = np.concatenate(
                        (pred_keypoints3d, pred_vertices[:,
                                                         extra_joints_idxs]),
                        axis=1)
                    # Optionally keep only a body part's keypoints (hands
                    # include their wrist so they can be root-aligned).
                    idxs = list(range(0, gt_keypoints3d.shape[1]))
                    if body_part == 'right_hand':
                        idxs = get_keypoint_idxs_by_part(
                            'right_hand', self.convention)
                        idxs.append(
                            get_keypoint_idx('right_wrist', self.convention))
                    elif body_part == 'left_hand':
                        idxs = get_keypoint_idxs_by_part(
                            'left_hand', self.convention)
                        idxs.append(
                            get_keypoint_idx('left_wrist', self.convention))
                    elif body_part == 'body':
                        idxs = get_keypoint_idxs_by_part(
                            'body', self.convention)
                    gt_keypoints3d = gt_keypoints3d[:, idxs]
                    pred_keypoints3d = pred_keypoints3d[:, idxs]
                    gt_keypoints3d_mask = np.ones(
                        (len(pred_keypoints3d), gt_keypoints3d.shape[1]))
            else:
                # Generic case: GT 3D keypoints are stored directly.
                gt_keypoints3d = self.human_data['keypoints3d'][:, :, :3]
                gt_keypoints3d_mask = np.ones(
                    (len(pred_keypoints3d), gt_keypoints3d.shape[1]))
            if gt_keypoints3d.shape[1] == 17:
                # SMPLX_to_J14
                assert pred_keypoints3d.shape[1] == 14
                # Map H36M's 17-joint order to the LSP-style J14 subset.
                H36M_TO_J17 = [
                    6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9
                ]
                H36M_TO_J14 = H36M_TO_J17[:14]
                joint_mapper = H36M_TO_J14
                gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
                # Root-align both sets at the pelvis, taken as the mean of
                # joints 2 and 3 (assumed right/left hip — TODO confirm).
                pred_pelvis = pred_keypoints3d[:,
                                               [2, 3], :].mean(axis=1,
                                                               keepdims=True)
                gt_pelvis = gt_keypoints3d[:, [2, 3], :].mean(axis=1,
                                                              keepdims=True)
                gt_keypoints3d_mask = gt_keypoints3d_mask[:, joint_mapper]
                pred_keypoints3d = pred_keypoints3d - pred_pelvis
                gt_keypoints3d = gt_keypoints3d - gt_pelvis
            elif gt_keypoints3d.shape[1] == 14:
                assert pred_keypoints3d.shape[1] == 14
                # Same pelvis alignment as above for 14-joint GT.
                pred_pelvis = pred_keypoints3d[:,
                                               [2, 3], :].mean(axis=1,
                                                               keepdims=True)
                gt_pelvis = gt_keypoints3d[:, [2, 3], :].mean(axis=1,
                                                              keepdims=True)
                pred_keypoints3d = pred_keypoints3d - pred_pelvis
                gt_keypoints3d = gt_keypoints3d - gt_pelvis
            elif gt_keypoints3d.shape[1] == 21:
                # 21 keypoints (hand evaluation): root-align at the first
                # joint (presumably the wrist — TODO confirm convention).
                pred_pelvis = pred_keypoints3d[:, :1, :]
                gt_pelvis = gt_keypoints3d[:, :1, :]
                pred_keypoints3d = pred_keypoints3d - pred_pelvis
                gt_keypoints3d = gt_keypoints3d - gt_pelvis
            else:
                # Other joint counts: no extra alignment is applied.
                pass
            # Convert to millimeters; stirling GT is presumably already in
            # mm (no conversion) — verify against that dataset's units.
            pred_keypoints3d = pred_keypoints3d * 1000
            if self.dataset_name != 'stirling':
                gt_keypoints3d = gt_keypoints3d * 1000
            gt_keypoints3d_mask = gt_keypoints3d_mask > 0
            return pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask
    def _report_3d_rmse(self, res_file):
        """compute the 3DRMSE between a predicted 3D face shape and the 3D
        ground truth scan."""
        pred_vertices, gt_vertices, _ = self._parse_result(res_file,
                                                           mode='vertice')
        pred_keypoints3d, gt_keypoints3d, _ = self._parse_result(
            res_file, mode='keypoint')
        errors = []
        # Per-sample scan-to-mesh distance, keypoint-guided; the reported
        # value is the mean over the dataset.
        for pred_vertice, gt_vertice, pred_points, gt_points in zip(
                pred_vertices, gt_vertices, pred_keypoints3d, gt_keypoints3d):
            error = fg_vertices_to_mesh_distance(gt_vertice, gt_points,
                                                 pred_vertice,
                                                 self.body_model.faces,
                                                 pred_points)
            errors.append(error)
        error = np.array(errors).mean()
        name_value_tuples = [('3DRMSE', error)]
        return name_value_tuples
    def evaluate(self,
                 outputs: list,
                 res_folder: str,
                 metric: Optional[Union[str, List[str]]] = 'pa-mpjpe',
                 **kwargs: dict):
        """Evaluate 3D keypoint results.

        Args:
            outputs (list): results from model inference.
            res_folder (str): path to store results.
            metric (Optional[Union[str, List(str)]]):
                the type of metric. Default: 'pa-mpjpe'
            kwargs (dict): other arguments. If ``kwargs['body_part']`` is
                given, it is indexed per metric and each listed part is
                evaluated separately (only pa-mpjpe / pa-pve supported).

        Returns:
            dict:
                A dict of all evaluation results.
        """
        metrics = metric if isinstance(metric, list) else [metric]
        for metric in metrics:
            if metric not in self.ALLOWED_METRICS:
                raise KeyError(f'metric {metric} is not supported')
        # for keeping correctness during multi-gpu test, we sort all results
        res_dict = {}
        for out in outputs:
            target_id = out['image_idx']
            batch_size = len(out['keypoints_3d'])
            for i in range(batch_size):
                # Later duplicates of the same image index overwrite
                # earlier ones.
                res_dict[int(target_id[i])] = dict(
                    keypoints=out['keypoints_3d'][i],
                    vertices=out['vertices'][i],
                )
        # Re-assemble results in dataset order, 0..num_data-1.
        keypoints, vertices = [], []
        for i in range(self.num_data):
            keypoints.append(res_dict[i]['keypoints'])
            vertices.append(res_dict[i]['vertices'])
        keypoints = np.stack(keypoints)
        vertices = np.stack(vertices)
        res = dict(keypoints=keypoints, vertices=vertices)
        name_value_tuples = []
        for index, _metric in enumerate(metrics):
            if 'body_part' in kwargs:
                # Per-body-part evaluation: body_part is a list aligned
                # with `metrics`; each entry lists the parts to report.
                body_parts = kwargs['body_part'][index]
                for body_part in body_parts:
                    if _metric == 'pa-mpjpe':
                        _nv_tuples = self._report_mpjpe(res,
                                                        metric='pa-mpjpe',
                                                        body_part=body_part)
                    elif _metric == 'pa-pve':
                        _nv_tuples = self._report_pve(res,
                                                      metric='pa-pve',
                                                      body_part=body_part)
                    else:
                        raise NotImplementedError
                    name_value_tuples.extend(_nv_tuples)
            else:
                if _metric == 'mpjpe':
                    _nv_tuples = self._report_mpjpe(res)
                elif _metric == 'pa-mpjpe':
                    _nv_tuples = self._report_mpjpe(res, metric='pa-mpjpe')
                elif _metric == '3dpck':
                    _nv_tuples = self._report_3d_pck(res)
                elif _metric == 'pa-3dpck':
                    _nv_tuples = self._report_3d_pck(res, metric='pa-3dpck')
                elif _metric == '3dauc':
                    _nv_tuples = self._report_3d_auc(res)
                elif _metric == 'pa-3dauc':
                    _nv_tuples = self._report_3d_auc(res, metric='pa-3dauc')
                elif _metric == 'pve':
                    _nv_tuples = self._report_pve(res)
                elif _metric == 'pa-pve':
                    _nv_tuples = self._report_pve(res, metric='pa-pve')
                elif _metric == '3DRMSE':
                    _nv_tuples = self._report_3d_rmse(res)
                else:
                    raise NotImplementedError
                name_value_tuples.extend(_nv_tuples)
        name_value = OrderedDict(name_value_tuples)
        return name_value