# Paste residue from the original source viewer (kept as comments so the
# file parses):
# Spaces:
# Runtime error
# Runtime error
import copy
import math
import os
from typing import List

import numpy as np
import torch
import torch.nn.functional as F
from pycocotools.coco import COCO
from scipy.optimize import linear_sum_assignment
from torch import Tensor, nn
from torchvision.ops.boxes import nms

from detrsmpl.core.conventions.keypoints_mapping import (convert_kps,
                                                         get_keypoint_idx)
from detrsmpl.models.body_models.builder import build_body_model
from detrsmpl.utils.demo_utils import (convert_verts_to_cam_coord, xywh2xyxy,
                                       xyxy2xywh)
from detrsmpl.utils.geometry import (batch_rodrigues, project_points,
                                     project_points_new,
                                     weak_perspective_projection)
from util import box_ops
from util.human_models import smpl_x
from util.misc import (NestedTensor, nested_tensor_from_tensor_list, accuracy,
                       get_world_size, interpolate,
                       is_dist_avail_and_initialized, inverse_sigmoid)
class PostProcess(nn.Module):
    """This module converts the model's output into the format expected by the
    coco api.

    Top-k detections are selected from the classification logits, the
    matching boxes / 2D keypoints / SMPL parameters are gathered, predictions
    are matched to each image's ground truth with the Hungarian algorithm,
    and both sides are packed into numpy arrays for downstream evaluation.
    """
    def __init__(self,
                 num_select=100,
                 nms_iou_threshold=-1,
                 num_body_points=17,
                 body_model=None) -> None:
        """
        Args:
            num_select: number of top-scoring predictions kept per image.
            nms_iou_threshold: IoU threshold for NMS; <= 0 disables it (the
                NMS branch in forward() is unimplemented anyway).
            num_body_points: number of 2D body keypoints per person.
            body_model: unused.  NOTE(review): a GenderedSMPL model is always
                built from the hard-coded config below — confirm whether this
                argument should be honoured instead.
        """
        super().__init__()
        self.num_select = num_select
        self.nms_iou_threshold = nms_iou_threshold
        self.num_body_points = num_body_points
        # SMPL body model used to regenerate GT joints in H36M convention.
        self.body_model = build_body_model(
            dict(type='GenderedSMPL',
                 keypoint_src='h36m',
                 keypoint_dst='h36m',
                 model_path='data/body_models/smpl',
                 keypoint_approximate=True,
                 joints_regressor=
                 'data/body_models/J_regressor_h36m.npy'))

    def forward(self,
                outputs,
                target_sizes,
                targets,
                data_batch_nc,
                device,
                not_to_xyxy=False,
                test=False):
        """Match top-k predictions to ground truth and pack both for eval.

        Args:
            outputs: dict of batched model tensors ('pred_logits',
                'pred_boxes', 'pred_keypoints', 'pred_smpl_*').
            target_sizes: (batch, 2) tensor of (height, width) per image.
            targets: per-image target dicts; only 'image_id' is read here.
            data_batch_nc: ground-truth batch (SMPL params, 2D/3D keypoints,
                'bbox_xywh', ...).  NOTE(review): 'bbox_xywh' is converted to
                xyxy IN PLACE inside the loop below, so the caller's data is
                mutated.
            device: device the body model is moved to.
            not_to_xyxy: keep predicted boxes in cxcywh instead of xyxy.
            test: additionally convert boxes xyxy -> xywh (in place).

        Returns:
            list with one dict appended per image (see NOTE at the append).
        """
        # import pdb; pdb.set_trace()
        num_select = self.num_select
        self.body_model.to(device)
        out_logits, out_bbox, out_keypoints= \
            outputs['pred_logits'], outputs['pred_boxes'], \
            outputs['pred_keypoints']
        out_smpl_pose, out_smpl_beta, out_smpl_cam, out_smpl_kp3d = \
            outputs['pred_smpl_pose'], outputs['pred_smpl_beta'], \
            outputs['pred_smpl_cam'], outputs['pred_smpl_kp3d']
        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2
        # Top-k over the flattened (query, class) score matrix of each image.
        prob = out_logits.sigmoid()
        topk_values, topk_indexes = \
            torch.topk(prob.view(out_logits.shape[0], -1), num_select, dim=1)
        scores = topk_values
        # bbox
        topk_boxes = topk_indexes // out_logits.shape[2]  # query index
        labels = topk_indexes % out_logits.shape[2]  # class label
        if not_to_xyxy:
            boxes = out_bbox
        else:
            boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        if test:
            assert not not_to_xyxy
            boxes[:, :, 2:] = boxes[:, :, 2:] - boxes[:, :, :2]  # xyxy->xywh
        boxes_norm = torch.gather(boxes, 1,
                                  topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
        target_sizes = target_sizes.type_as(boxes)
        # from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes_norm * scale_fct[:, None, :]
        # keypoints
        topk_keypoints = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        keypoints = torch.gather(
            out_keypoints, 1,
            topk_keypoints.unsqueeze(-1).repeat(1, 1,
                                                self.num_body_points * 3))
        # Z_pred: interleaved normalised (x, y); V_pred: per-joint confidence.
        Z_pred = keypoints[:, :, :(self.num_body_points * 2)]
        V_pred = keypoints[:, :, (self.num_body_points * 2):]
        img_h, img_w = target_sizes.unbind(1)
        Z_pred = Z_pred * torch.stack([img_w, img_h], dim=1).repeat(
            1, self.num_body_points)[:, None, :]
        # Re-pack as (x, y, conf) triplets per joint.
        keypoints_res = torch.zeros_like(keypoints)
        keypoints_res[..., 0::3] = Z_pred[..., 0::2]
        keypoints_res[..., 1::3] = Z_pred[..., 1::2]
        keypoints_res[..., 2::3] = V_pred[..., 0::1]
        # smpl out_smpl_pose, out_smpl_beta, out_smpl_cam, out_smpl_kp3d
        topk_smpl = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        # Pose is gathered as per-joint rotation matrices (24 x 3 x 3).
        smpl_pose = torch.gather(
            out_smpl_pose, 1, topk_smpl[:, :, None, None,
                                        None].repeat(1, 1, 24, 3, 3))
        smpl_beta = torch.gather(out_smpl_beta, 1,
                                 topk_smpl[:, :, None].repeat(1, 1, 10))
        smpl_cam = torch.gather(out_smpl_cam, 1,
                                topk_smpl[:, :, None].repeat(1, 1, 3))
        smpl_kp3d = torch.gather(
            out_smpl_kp3d, 1,
            topk_smpl[:, :, None, None].repeat(1, 1, out_smpl_kp3d.shape[-2],
                                               3))
        if False:  # debug visualisation, intentionally disabled
            import cv2
            import mmcv
            img = cv2.imread(data_batch_nc['img_metas'][0]['image_path'])
            render_img = mmcv.imshow_bboxes(img.copy(),
                                            boxes[0][:3].cpu().numpy(),
                                            show=False)
            cv2.imwrite('r_bbox.png', render_img)
            gt_bbox_xyxy = xywh2xyxy(
                data_batch_nc['bbox_xywh'][0].cpu().numpy())
            render_img = mmcv.imshow_bboxes(img.copy(),
                                            gt_bbox_xyxy,
                                            show=False)
            cv2.imwrite('r_bbox.png', render_img)
            from detrsmpl.core.visualization.visualize_keypoints3d import visualize_kp3d
            visualize_kp3d(smpl_kp3d[0][[0]].cpu().numpy(),
                           output_path='.',
                           data_source='smpl_54')
            from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d
            visualize_kp2d(keypoints_res[0].reshape(-1, 17,
                                                    3)[[0]].cpu().numpy(),
                           output_path='.',
                           image_array=img.copy()[None],
                           data_source='coco',
                           overwrite=True)
        # ---- ground truth ----
        tgt_smpl_kp3d = data_batch_nc['keypoints3d_smpl']
        # Full axis-angle pose per GT person: [global_orient | body_pose].
        tgt_smpl_pose = [
            torch.concat([
                data_batch_nc['smpl_global_orient'][i][:, None],
                data_batch_nc['smpl_body_pose'][i]
            ],
                         dim=-2)
            for i in range(len(data_batch_nc['smpl_body_pose']))
        ]
        tgt_smpl_beta = data_batch_nc['smpl_betas']
        tgt_keypoints = data_batch_nc['keypoints2d_ori']
        tgt_bbox = data_batch_nc['bbox_xywh']
        indices = []
        # pred accumulators (grow across the per-image loop below)
        pred_smpl_kp3d = []
        pred_smpl_pose = []
        pred_smpl_beta = []
        pred_scores = []
        pred_labels = []
        pred_boxes = []
        pred_keypoints = []
        pred_smpl_cam = []
        # gt accumulators
        gt_smpl_kp3d = []
        gt_smpl_pose = []
        gt_smpl_beta = []
        gt_boxes = []
        gt_keypoints = []
        image_idx = []
        results = []
        for i, kp3d in enumerate(tgt_smpl_kp3d):
            # kp3d
            conf = tgt_smpl_kp3d[i][..., [3]]  # annotated validity (unused)
            gt_kp3d = tgt_smpl_kp3d[i][..., :3]
            pred_kp3d = smpl_kp3d[i]
            # Regenerate GT joints with the gendered SMPL body model.
            # NOTE(review): gender is passed as all zeros — confirm this is
            # the intended gender code for every sample.
            gt_output = self.body_model(
                betas=tgt_smpl_beta[i].float(),
                body_pose=tgt_smpl_pose[i][:, 1:].float().reshape(-1, 69),
                global_orient=tgt_smpl_pose[i][:, [0]].float().reshape(-1, 3),
                gender=torch.zeros(tgt_smpl_beta[i].shape[0]),
                pose2rot=True)
            gt_kp3d = gt_output['joints']
            # gt_kp3d,_ = convert_kps(
            #     gt_kp3d,
            #     src='smpl_54',
            #     dst='h36m',
            # )
            assert gt_kp3d.shape[-2] == 17
            # H36M -> LSP-style 17/14 joint reordering.
            H36M_TO_J17 = [
                6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9
            ]
            H36M_TO_J14 = H36M_TO_J17[:14]
            joint_mapper = H36M_TO_J14
            pred_pelvis = pred_kp3d[:, 0]
            gt_pelvis = gt_kp3d[:, 0]
            gt_keypoints3d = gt_kp3d[:, joint_mapper, :]
            pred_keypoints3d = pred_kp3d[:, joint_mapper, :]
            # Pelvis-align both sides and convert to millimetres.
            pred_keypoints3d = (pred_keypoints3d -
                                pred_pelvis[:, None, :]) * 1000
            gt_keypoints3d = (gt_keypoints3d - gt_pelvis[:, None, :]) * 1000
            # Pairwise L1 keypoint cost.  NOTE(review): computed but unused —
            # the assignment below is driven by GIoU only.
            cost_kp3d = torch.abs((pred_keypoints3d[:, None] -
                                   gt_keypoints3d[None])).sum([-2, -1])
            # xywh -> xyxy, IN PLACE on data_batch_nc['bbox_xywh'][i].
            tgt_bbox[i][..., 2] = tgt_bbox[i][..., 0] + tgt_bbox[i][..., 2]
            tgt_bbox[i][..., 3] = tgt_bbox[i][..., 1] + tgt_bbox[i][..., 3]
            gt_bbox = tgt_bbox[i][..., :4].float()
            pred_bbox = boxes[i]
            # box_iou = box_ops.box_iou(pred_bbox,gt_bbox)[0]
            cost_giou = -box_ops.generalized_box_iou(pred_bbox, gt_bbox)
            # Hungarian matching on the negated GIoU cost.
            indice = linear_sum_assignment(cost_giou.cpu())
            pred_ind, gt_ind = indice
            indices.append(indice)
            # bbox
            # cost_bbox = torch.cdist(pred_bbox, gt_bbox, p=1)
            # indice = linear_sum_assignment(cost_giou.cpu())
            # pred_ind, gt_ind = indice
            # indices.append(indice)
            # pred
            pred_scores.append(scores[i][pred_ind].detach().cpu().numpy())
            pred_labels.append(labels[i][pred_ind].detach().cpu().numpy())
            pred_boxes.append(boxes[i][pred_ind].detach().cpu().numpy())
            pred_keypoints.append(
                keypoints_res[i][pred_ind].detach().cpu().numpy())
            pred_smpl_kp3d.append(
                smpl_kp3d[i][pred_ind].detach().cpu().numpy())
            pred_smpl_pose.append(
                smpl_pose[i][pred_ind].detach().cpu().numpy())
            pred_smpl_beta.append(
                smpl_beta[i][pred_ind].detach().cpu().numpy())
            pred_smpl_cam.append(smpl_cam[i][pred_ind].detach().cpu().numpy())
            # gt
            gt_smpl_kp3d.append(
                tgt_smpl_kp3d[i][gt_ind].detach().cpu().numpy())
            gt_smpl_pose.append(
                tgt_smpl_pose[i][gt_ind].detach().cpu().numpy())
            gt_smpl_beta.append(
                tgt_smpl_beta[i][gt_ind].detach().cpu().numpy())
            gt_boxes.append(tgt_bbox[i][gt_ind].detach().cpu().numpy())
            gt_keypoints.append(
                tgt_keypoints[i][gt_ind].detach().cpu().numpy())
            image_idx.append(targets[i]['image_id'].detach().cpu().numpy())
            # gt_output = self.body_model(
            #     betas=tgt_smpl_beta[i].float(),
            #     body_pose=tgt_smpl_pose[i][:,1:].float().reshape(-1, 69),
            #     global_orient=tgt_smpl_pose[i][:,[0]].float().reshape(-1, 3),
            #     pose2rot=True
            # )
            # NOTE(review): the accumulator lists grow across images, so the
            # dict appended for image i contains data for images 0..i, and
            # 'gt_boxes' holds the loop-local, single-image gt_bbox tensor —
            # confirm this is what the evaluator expects.
            results.append({
                'scores': pred_scores,
                'labels': pred_labels,
                'boxes': pred_boxes,
                'keypoints': pred_keypoints,
                'pred_smpl_pose': pred_smpl_pose,
                'pred_smpl_beta': pred_smpl_beta,
                'pred_smpl_cam': pred_smpl_cam,
                'pred_smpl_kp3d': pred_smpl_kp3d,
                'gt_smpl_pose': gt_smpl_pose,
                'gt_smpl_beta': gt_smpl_beta,
                'gt_smpl_kp3d': gt_smpl_kp3d,
                'gt_boxes': gt_bbox,
                'gt_keypoints': gt_keypoints,
                'image_idx': image_idx,
            })
            # results.append({
            #     'scores': scores[i][pred_ind],
            #     'labels': labels[i][pred_ind],
            #     'boxes': boxes[i][pred_ind],
            #     'keypoints': keypoints_res[i][pred_ind],
            #     'pred_smpl_pose': smpl_pose[i][pred_ind],
            #     'pred_smpl_beta': tgt_smpl_beta[i][gt_ind],
            #     'pred_smpl_cam': smpl_cam[i][pred_ind],
            #     'pred_smpl_kp3d': smpl_kp3d[i][pred_ind],
            #     'gt_smpl_pose': tgt_smpl_pose[i][gt_ind],
            #     'gt_smpl_beta': tgt_smpl_beta[i][gt_ind],
            #     'gt_smpl_kp3d': tgt_smpl_kp3d[i][gt_ind],
            #     'gt_boxes': tgt_bbox[i][gt_ind],
            #     'gt_keypoints': tgt_keypoints[i][gt_ind],
            #     'image_idx': targets[i]['image_id'],
            #     }
            # )
        if self.nms_iou_threshold > 0:
            # NMS path is unimplemented; everything after the raise is dead.
            raise NotImplementedError
            item_indices = [
                nms(b, s, iou_threshold=self.nms_iou_threshold)
                for b, s in zip(boxes, scores)
            ]
            # import pdb; pdb.set_trace()
            results = [{
                'scores': s[i],
                'labels': l[i],
                'boxes': b[i]
            } for s, l, b, i in zip(scores, labels, boxes, item_indices)]
        else:
            results = results
        return results
| class PostProcess_aios(nn.Module): | |
| """This module converts the model's output into the format expected by the | |
| coco api.""" | |
| def __init__(self, | |
| num_select=100, | |
| nms_iou_threshold=-1, | |
| num_body_points=17) -> None: | |
| super().__init__() | |
| self.num_select = num_select | |
| self.nms_iou_threshold = nms_iou_threshold | |
| self.num_body_points = num_body_points | |
| def forward(self, outputs, target_sizes, not_to_xyxy=False, test=False): | |
| num_select = self.num_select | |
| out_logits, out_bbox, out_keypoints = outputs['pred_logits'], outputs[ | |
| 'pred_boxes'], outputs['pred_keypoints'] | |
| assert len(out_logits) == len(target_sizes) | |
| assert target_sizes.shape[1] == 2 | |
| prob = out_logits.sigmoid() | |
| topk_values, topk_indexes = torch.topk(prob.view( | |
| out_logits.shape[0], -1), | |
| num_select, | |
| dim=1) | |
| scores = topk_values | |
| # bbox | |
| topk_boxes = topk_indexes // out_logits.shape[2] | |
| labels = topk_indexes % out_logits.shape[2] | |
| if not_to_xyxy: | |
| boxes = out_bbox | |
| else: | |
| boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) | |
| if test: | |
| assert not not_to_xyxy | |
| boxes[:, :, 2:] = boxes[:, :, 2:] - boxes[:, :, :2] | |
| boxes = torch.gather(boxes, 1, | |
| topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) | |
| # from relative [0, 1] to absolute [0, height] coordinates | |
| img_h, img_w = target_sizes.unbind(1) | |
| scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) | |
| boxes = boxes * scale_fct[:, None, :] | |
| # keypoints | |
| topk_keypoints = topk_indexes // out_logits.shape[2] | |
| labels = topk_indexes % out_logits.shape[2] | |
| keypoints = torch.gather( | |
| out_keypoints, 1, | |
| topk_keypoints.unsqueeze(-1).repeat(1, 1, | |
| self.num_body_points * 3)) | |
| Z_pred = keypoints[:, :, :(self.num_body_points * 2)] | |
| V_pred = keypoints[:, :, (self.num_body_points * 2):] | |
| img_h, img_w = target_sizes.unbind(1) | |
| Z_pred = Z_pred * torch.stack([img_w, img_h], dim=1).repeat( | |
| 1, self.num_body_points)[:, None, :] | |
| keypoints_res = torch.zeros_like(keypoints) | |
| keypoints_res[..., 0::3] = Z_pred[..., 0::2] | |
| keypoints_res[..., 1::3] = Z_pred[..., 1::2] | |
| keypoints_res[..., 2::3] = V_pred[..., 0::1] | |
| if self.nms_iou_threshold > 0: | |
| raise NotImplementedError | |
| item_indices = [ | |
| nms(b, s, iou_threshold=self.nms_iou_threshold) | |
| for b, s in zip(boxes, scores) | |
| ] | |
| # import ipdb; ipdb.set_trace() | |
| results = [{ | |
| 'scores': s[i], | |
| 'labels': l[i], | |
| 'boxes': b[i] | |
| } for s, l, b, i in zip(scores, labels, boxes, item_indices)] | |
| else: | |
| results = [{ | |
| 'scores': s, | |
| 'labels': l, | |
| 'boxes': b, | |
| 'keypoints': k | |
| } for s, l, b, k in zip(scores, labels, boxes, keypoints_res)] | |
| return results | |
class PostProcess_SMPLX(nn.Module):
    """ This module converts the model's output into the format expected by the coco api

    SMPL-X variant: selects top-k predictions, matches them to the (single)
    ground-truth person per image via a 3D-keypoint Hungarian cost, and packs
    SMPL-X parameters, projected joints and GT data for evaluation.
    """
    def __init__(
            self,
            num_select=100,
            nms_iou_threshold=-1,
            num_body_points=17,
            body_model= dict(
                type='smplx',
                keypoint_src='smplx',
                num_expression_coeffs=10,
                keypoint_dst='smplx_137',
                model_path='data/body_models/smplx',
                use_pca=False,
                use_face_contour=True)
    ) -> None:
        """
        Args:
            num_select: number of top-scoring predictions kept per image.
            nms_iou_threshold: IoU threshold for NMS; <= 0 disables it (the
                NMS branch in forward() is unimplemented anyway).
            num_body_points: number of 2D body keypoints per person.
            body_model: config dict used to build the SMPL-X body model.
                NOTE(review): mutable default argument — safe only as long
                as it is never mutated.
        """
        super().__init__()
        self.num_select = num_select
        self.nms_iou_threshold = nms_iou_threshold
        self.num_body_points=num_body_points
        self.body_model = build_body_model(body_model)

    def forward(self, outputs, target_sizes, targets, data_batch_nc, not_to_xyxy=False, test=False):
        """Select top-k SMPL-X predictions, match to GT, pack per-image dicts.

        Args:
            outputs: dict of batched model tensors ('pred_logits',
                'pred_boxes', 'pred_keypoints', 'pred_smpl_fullpose',
                'pred_smpl_beta', 'pred_smpl_expr', 'pred_smpl_cam',
                'pred_smpl_kp3d', 'pred_smpl_verts').
            target_sizes: (batch, 2) tensor of (height, width) per image.
            targets: per-image target dicts; only 'image_id' is read here.
            data_batch_nc: ground-truth batch (SMPL-X params, joints, body
                bbox centre/size, image shapes, ...).
            not_to_xyxy: keep predicted boxes in cxcywh instead of xyxy.
            test: additionally convert boxes xyxy -> xywh (in place).

        Returns:
            list with one result dict per image.
        """
        # import pdb; pdb.set_trace()
        num_select = self.num_select
        out_logits, out_bbox, out_keypoints= \
            outputs['pred_logits'], outputs['pred_boxes'], \
            outputs['pred_keypoints']
        out_smpl_pose, out_smpl_beta, out_smpl_expr, out_smpl_cam, out_smpl_kp3d, out_smpl_verts = \
            outputs['pred_smpl_fullpose'], outputs['pred_smpl_beta'], outputs['pred_smpl_expr'], \
            outputs['pred_smpl_cam'], outputs['pred_smpl_kp3d'], outputs['pred_smpl_verts']
        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2
        # Top-k over the flattened (query, class) score matrix of each image.
        prob = out_logits.sigmoid()
        topk_values, topk_indexes = \
            torch.topk(prob.view(out_logits.shape[0], -1), num_select, dim=1)
        scores = topk_values
        # bbox
        topk_boxes = topk_indexes // out_logits.shape[2]  # query index
        labels = topk_indexes % out_logits.shape[2]  # class label
        if not_to_xyxy:
            boxes = out_bbox
        else:
            boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        if test:
            assert not not_to_xyxy
            boxes[:,:,2:] = boxes[:,:,2:] - boxes[:,:,:2]  # xyxy -> xywh
        boxes_norm = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
        target_sizes = target_sizes.type_as(boxes)
        # from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes_norm * scale_fct[:, None, :]
        # keypoints
        topk_keypoints = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        keypoints = torch.gather(out_keypoints, 1, topk_keypoints.unsqueeze(-1).repeat(1, 1, self.num_body_points*3))
        # Z_pred: interleaved normalised (x, y); V_pred: per-joint confidence.
        Z_pred = keypoints[:, :, :(self.num_body_points*2)]
        V_pred = keypoints[:, :, (self.num_body_points*2):]
        img_h, img_w = target_sizes.unbind(1)
        Z_pred = Z_pred * torch.stack([img_w, img_h], dim=1).repeat(1, self.num_body_points)[:, None, :]
        # Re-pack as (x, y, conf) triplets per joint.
        keypoints_res = torch.zeros_like(keypoints)
        keypoints_res[..., 0::3] = Z_pred[..., 0::2]
        keypoints_res[..., 1::3] = Z_pred[..., 1::2]
        keypoints_res[..., 2::3] = V_pred[..., 0::1]
        # smpl out_smpl_pose, out_smpl_beta, out_smpl_cam, out_smpl_kp3d
        topk_smpl = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        # Full pose is a flat 159-dim axis-angle vector per prediction.
        smpl_pose = torch.gather(out_smpl_pose, 1, topk_smpl[:,:,None].repeat(1, 1, 159))
        smpl_beta = torch.gather(out_smpl_beta, 1, topk_smpl[:,:,None].repeat(1, 1, 10))
        smpl_expr = torch.gather(out_smpl_expr, 1, topk_smpl[:,:,None].repeat(1, 1, 10))
        smpl_cam = torch.gather(out_smpl_cam, 1, topk_smpl[:,:,None].repeat(1, 1, 3))
        smpl_kp3d = torch.gather(out_smpl_kp3d, 1, topk_smpl[:,:,None, None].repeat(1, 1, out_smpl_kp3d.shape[-2],3))
        smpl_verts = torch.gather(out_smpl_verts, 1, topk_smpl[:,:,None, None].repeat(1, 1, out_smpl_verts.shape[-2],3))
        if False:  # debug visualisation, intentionally disabled
            import cv2
            import mmcv
            import ipdb;ipdb.set_trace()
            img = (data_batch_nc['img'][1].permute(1,2,0)*255).int().detach().cpu().numpy()
            # img = cv2.imread(data_batch_nc['img_metas'][1]['image_path'])
            tgt_bbox_center = torch.stack(data_batch_nc['body_bbox_center'])
            tgt_bbox_size = torch.stack(data_batch_nc['body_bbox_size']).cpu().numpy()
            tgt_bbox = torch.cat([tgt_bbox_center-tgt_bbox_size/2,tgt_bbox_center+tgt_bbox_size/2],dim=-1)
            tgt_img_shape = data_batch_nc['img_shape']
            bbox = tgt_bbox.cpu().numpy()*(tgt_img_shape.repeat(1,2).cpu().numpy()[:,::-1])
            render_img = mmcv.imshow_bboxes(img.copy(), boxes[1][:3].cpu().numpy(), show=False)
            cv2.imwrite('r_bbox.png',render_img)
            render_img = mmcv.imshow_bboxes(img.copy(), bbox, show=False)
            # cv2.imwrite('r_bbox.png',render_img)
            # from detrsmpl.core.visualization.visualize_keypoints3d import visualize_kp3d
            # visualize_kp3d(smpl_kp3d[1][[0]].cpu().numpy(),output_path='.',data_source='smpl_54')
            from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d
            import ipdb;ipdb.set_trace()
            visualize_kp2d(keypoints_res[0].reshape(-1,17,3)[[3]].cpu().numpy(), output_path='.', image_array=img.copy()[None], data_source='coco',overwrite=True)
        # TODO: align it with agora
        # ---- ground truth ----
        tgt_smpl_kp3d = data_batch_nc['joint_cam']
        tgt_smpl_kp3d_conf = data_batch_nc['joint_valid']
        tgt_smpl_pose = data_batch_nc['smplx_pose']
        tgt_smpl_beta = data_batch_nc['smplx_shape']
        tgt_smpl_expr = data_batch_nc['smplx_expr']
        tgt_keypoints = data_batch_nc['joint_img']
        tgt_img_shape = data_batch_nc['img_shape']
        tgt_ann_idx = data_batch_nc['ann_idx']
        # tgt_img_path = data_batch_nc['img_shape']
        # GT body boxes: normalised (centre, size) -> absolute xywh.
        tgt_bbox_center = torch.stack(data_batch_nc['body_bbox_center'])
        tgt_bbox_size = torch.stack(data_batch_nc['body_bbox_size'])
        tgt_bbox = torch.cat([tgt_bbox_center-tgt_bbox_size/2,tgt_bbox_size],dim=-1)
        tgt_bbox = tgt_bbox * scale_fct
        tgt_verts = data_batch_nc['smplx_mesh_cam']
        tgt_bb2img_trans = data_batch_nc['bb2img_trans']
        indices = []
        # pred
        pred_smpl_kp3d = []
        pred_smpl_pose = []
        pred_smpl_beta = []
        pred_smpl_verts = []
        pred_smpl_expr = []
        pred_scores = []
        pred_labels = []
        pred_boxes = []
        pred_keypoints = []
        pred_smpl_cam = []
        # gt
        gt_smpl_kp3d = []
        gt_smpl_pose = []
        gt_smpl_beta = []
        gt_smpl_expr = []
        gt_smpl_verts = []
        gt_boxes = []
        gt_keypoints = []
        gt_bb2img_trans = []
        image_idx = []
        results = []
        for i, kp3d in enumerate(tgt_smpl_kp3d):
            # kp3d
            conf = tgt_smpl_kp3d_conf[i][...,]  # per-joint validity weights
            gt_kp3d = tgt_smpl_kp3d[i][...,:3]
            pred_kp3d = smpl_kp3d[i]
            pred_kp3d_match,_ = convert_kps(pred_kp3d,'smplx','smplx_137')
            # pred_kp3d_match = pred_kp3d
            # Validity-weighted pairwise L1 cost between every prediction
            # and the GT skeleton(s).
            cost_kp3d = torch.abs((pred_kp3d_match[:,None] -
                                   gt_kp3d[None])* conf[None]).sum([-2,-1])
            # bbox
            # xywh -> xyxy, in place on tgt_bbox.
            tgt_bbox[i][...,2] = tgt_bbox[i][...,0] + tgt_bbox[i][...,2]
            tgt_bbox[i][...,3] = tgt_bbox[i][...,1] + tgt_bbox[i][...,3]
            # [None] adds the instance dim: exactly one GT row is matched
            # per image here.
            gt_bbox = tgt_bbox[i][..., :4][None].float()
            pred_bbox = boxes[i]
            # box_iou = box_ops.box_iou(pred_bbox,gt_bbox)[0]
            # NOTE(review): cost_giou is computed but unused — matching below
            # is driven by the 3D keypoint cost only.
            cost_giou = -box_ops.generalized_box_iou(pred_bbox,gt_bbox)
            # cost_bbox = torch.cdist(pred_bbox, gt_bbox, p=1)
            indice = linear_sum_assignment(cost_kp3d.cpu())
            pred_ind, gt_ind = indice
            # NOTE(review): plain assignments below OVERWRITE the list
            # accumulators initialised above (nothing is appended), so each
            # name only ever holds the current image's data.
            indices=(indice)
            # pred
            pred_scores=(scores[i][pred_ind].detach().cpu().numpy())
            pred_labels=(labels[i][pred_ind].detach().cpu().numpy())
            pred_boxes=(boxes[i][pred_ind].detach().cpu().numpy())
            pred_keypoints=(keypoints_res[i][pred_ind].detach().cpu().numpy())
            pred_smpl_kp3d=(smpl_kp3d[i][pred_ind].detach().cpu().numpy())
            pred_smpl_pose=(smpl_pose[i][pred_ind].detach().cpu().numpy())
            pred_smpl_beta=(smpl_beta[i][pred_ind].detach().cpu().numpy())
            pred_smpl_cam=(smpl_cam[i][pred_ind].detach().cpu().numpy())
            pred_smpl_expr=(smpl_expr[i][pred_ind].detach().cpu().numpy())
            pred_smpl_verts=(smpl_verts[i][pred_ind].detach().cpu().numpy())
            # gt
            # gt_smpl_kp3d=(tgt_smpl_kp3d[i][gt_ind].detach().cpu().numpy())
            # gt_smpl_pose=(tgt_smpl_pose[i][gt_ind].detach().cpu().numpy())
            # gt_smpl_beta=(tgt_smpl_beta[i][gt_ind].detach().cpu().numpy())
            # gt_boxes=(tgt_bbox[i][gt_ind].detach().cpu().numpy())
            # gt_smpl_expr=(tgt_smpl_expr[i][gt_ind].detach().cpu().numpy())
            # gt_smpl_verts=(tgt_verts[i][gt_ind].detach().cpu().numpy())
            # gt_keypoints=(tgt_keypoints[i][gt_ind].detach().cpu().numpy())
            # gt_bb2img_trans=(tgt_bb2img_trans[i][gt_ind].detach().cpu().numpy())
            gt_smpl_kp3d=(tgt_smpl_kp3d[i].detach().cpu().numpy())
            gt_smpl_pose=(tgt_smpl_pose[i].detach().cpu().numpy())
            gt_smpl_beta=(tgt_smpl_beta[i].detach().cpu().numpy())
            gt_boxes=(tgt_bbox[i].detach().cpu().numpy())
            gt_smpl_expr=(tgt_smpl_expr[i].detach().cpu().numpy())
            gt_smpl_verts=(tgt_verts[i].detach().cpu().numpy())
            gt_ann_idx=(tgt_ann_idx[i].detach().cpu().numpy())
            gt_keypoints=(tgt_keypoints[i].detach().cpu().numpy())
            gt_img_shape=(tgt_img_shape[i].detach().cpu().numpy())
            gt_bb2img_trans=(tgt_bb2img_trans[i].detach().cpu().numpy())
            if 'image_id' in targets[i]:
                image_idx=(targets[i]['image_id'].detach().cpu().numpy())
            # pred_smpl_pose = np.concatenate(pred_smpl_pose,axis = 0)
            # gt_bb2img_trans = np.concatenate(gt_bb2img_trans,axis = 0)
            # gt_smpl_verts = np.concatenate(gt_smpl_verts,axis = 0)
            # pred_smpl_verts = np.concatenate(pred_smpl_verts, axis = 0)
            # pred_smpl_cam = np.concatenate(pred_smpl_cam, axis = 0)
            # import ipdb;ipdb.set_trace()
            # Split the flat 159-dim SMPL-X pose: root(3) + body(63) +
            # lhand(45) + rhand(45) + jaw(3).
            smplx_root_pose = pred_smpl_pose[:,:3]
            smplx_body_pose = pred_smpl_pose[:,3:66]
            smplx_lhand_pose = pred_smpl_pose[:,66:111]
            smplx_rhand_pose = pred_smpl_pose[:,111:156]
            smplx_jaw_pose = pred_smpl_pose[:,156:]
            # pred_smpl_kp3d = np.concatenate(pred_smpl_kp3d,axis = 0)
            pred_smpl_cam = torch.Tensor(pred_smpl_cam)
            pred_smpl_kp3d = torch.Tensor(pred_smpl_kp3d)
            # pred_smpl_kp2d = weak_perspective_projection(pred_smpl_kp3d, scale=pred_smpl_cam[:, :1], translation=pred_smpl_cam[:, 1:3])
            # pred_smpl_verts2d = weak_perspective_projection(pred_smpl_kp3d, scale=pred_smpl_cam[:, :1], translation=pred_smpl_cam[:, 1:3])
            # (h, w) -> (w, h) for the camera centre below.
            img_wh = tgt_img_shape[i].flip(-1)[None]
            # Project matched 3D joints to the image plane with a fixed
            # focal length of 5000 px.
            pred_smpl_kp2d = project_points_new(
                points_3d=pred_smpl_kp3d,
                pred_cam=pred_smpl_cam,
                focal_length=5000,
                camera_center=img_wh/2
            )
            pred_smpl_kp2d = pred_smpl_kp2d.numpy()
            pred_smpl_cam = pred_smpl_cam.numpy()
            # cam_trans = get_camera_trans(pred_smpl_cam)
            # pred_smpl_kp2d = (pred_smpl_kp2d+1)/2
            # pred_smpl_kp2d[:, :,0] = pred_smpl_kp2d[:, :, 0] * gt_img_shape[1]
            # pred_smpl_kp2d[:, :, 1] = pred_smpl_kp2d[:, :, 1] * gt_img_shape[0]
            # # joint_proj = np.dot(out['bb2img_trans'], joint_proj.transpose(1, 0)).transpose(1, 0)
            # # joint_proj[:, 0] = joint_proj[:, 0] / self.resolution[1] * 3840 # restore to original resolution
            # # joint_proj[:, 1] = joint_proj[:, 1] / self.resolution[0] * 2160 # restore to original resolution
            vis = False
            if vis:  # debug visualisation, intentionally disabled
                from pytorch3d.io import save_obj
                from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d
                from detrsmpl.core.visualization.visualize_smpl import visualize_smpl_hmr,render_smpl
                from detrsmpl.utils.demo_utils import get_default_hmr_intrinsic
                # img = (data_batch_nc['img'][i]*255).permute(1,2,0).int().detach().cpu().numpy()
                # (s, tx, ty) = (pred_smpl_cam[:, 0] + 1e-9), pred_smpl_cam[:, 1], pred_smpl_cam[:, 2]
                # depth, dx, dy = 1./s, tx/s, ty/s
                # cam_t = np.stack([dx, dy, depth], 1)
                # K = torch.Tensor(
                #     get_default_hmr_intrinsic(focal_length=5000,
                #                               det_height=750,
                #                               det_width=1333))
                # render_smpl(verts = pred_smpl_verts+cam_t[:,None,:],
                #             image_array=img.copy()[None],
                #             body_model=self.body_model,convention='opencv',
                #             output_path='.',overwrite=True,K=K)
                # save_obj(
                #     'pred.obj',
                #     torch.tensor(pred_smpl_verts[0]),
                #     torch.tensor(self.body_model.faces.astype(np.float)))
                import mmcv
                import cv2
                import numpy as np
                from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d
                from detrsmpl.core.visualization.visualize_smpl import visualize_smpl_hmr,render_smpl
                from detrsmpl.models.body_models.builder import build_body_model
                from pytorch3d.io import save_obj
                from detrsmpl.core.visualization.visualize_keypoints3d import visualize_kp3d
                img = mmcv.imdenormalize(
                    img=(data_batch_nc['img'][i].cpu().numpy()).transpose(1, 2, 0),
                    mean=np.array([123.675, 116.28, 103.53]),
                    std=np.array([58.395, 57.12, 57.375]),
                    to_bgr=True).astype(np.uint8)
                img = mmcv.imshow_bboxes(img,pred_boxes,show=False)
                img= visualize_kp2d(pred_smpl_kp2d, output_path='.', image_array=img.copy()[None], data_source='smplx',overwrite=True)[0]
                name = str(pred_smpl_kp2d[0,0,0]).replace('.','')
                cv2.imwrite('res_vis/%s.png'%name, img)
                # # joint_proj = np.dot(out['bb2img_trans'], joint_proj.transpose(1, 0)).transpose(1, 0)
                # # joint_proj[:, 0] = joint_proj[:, 0] / self.resolution[1] * 3840 # restore to original resolution
                # # joint_proj[:, 1] = joint_proj[:, 1] / self.resolution[0] * 2160 # restore to original resolution
                #
                # from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d
                # from detrsmpl.core.visualization.visualize_smpl import visualize_smpl_hmr,render_smpl
                # from detrsmpl.models.body_models.builder import build_body_model
                # body_model = dict(
                #     type='smplx',
                #     keypoint_src='smplx',
                #     num_expression_coeffs=10,
                #     keypoint_dst='smplx_137',
                #     model_path='data/body_models/smplx',
                #     use_pca=False,
                #     use_face_contour=True)
                # body_modeltest = build_body_model(body_model)
                # # device =gt_betas.device
                # # body_modeltest.to(device)
                # gt_output = body_modeltest(betas=torch.Tensor(gt_smpl_beta[None].reshape(-1, 10)),body_pose=torch.Tensor(gt_smpl_pose[3:66][None].reshape(-1, 21*3)), global_orient=torch.Tensor(gt_smpl_pose[:3][None].reshape(-1, 3)),left_hand_pose=torch.Tensor(gt_smpl_pose[66:111][None].reshape(-1, 15*3)),right_hand_pose=torch.Tensor(gt_smpl_pose[111:156][None].reshape(-1, 15*3)),jaw_pose=torch.Tensor(gt_smpl_pose[156:][None].reshape(-1, 3)),)
                # img = (data_batch_nc['img'][i]*255).permute(1,2,0).int().detach().cpu().numpy()
                # render_smpl(verts = gt_output['vertices'],image_array=img.copy()[None],body_model=self.body_model,convention='opencv',orig_cam = np.concatenate([pred_smpl_cam[:,:1],pred_smpl_cam[:,:1],pred_smpl_cam[:,1:]],axis=-1),output_path='.',overwrite=True)
                # img_new = visualize_smpl_hmr(
                #     cam_transl=pred_smpl_cam,
                #     verts = pred_smpl_verts,
                #     body_model=self.body_model,
                #     bbox = np.array([0,0,gt_img_shape[1],gt_img_shape[0]]),
                #     det_width = gt_img_shape[1],
                #     det_height=gt_img_shape[0],
                #     image_array=img.copy()[None],
                #     output_path='.',
                #     overwrite=True
                # )
            # NOTE(review): most entries take index [0] (the single matched
            # prediction) while 'scores'/'labels' keep the full arrays —
            # confirm this asymmetry is intended.
            results.append({
                'scores': pred_scores,
                'labels': pred_labels,
                'boxes': pred_boxes[0],
                'keypoints': pred_keypoints[0],
                'smplx_root_pose': smplx_root_pose[0],
                'smplx_body_pose': smplx_body_pose[0],
                'smplx_lhand_pose': smplx_lhand_pose[0],
                'smplx_rhand_pose': smplx_rhand_pose[0],
                'smplx_jaw_pose': smplx_jaw_pose[0],
                'smplx_shape': pred_smpl_beta[0],
                'smplx_expr': pred_smpl_expr[0],
                'cam_trans': pred_smpl_cam[0],
                'smplx_mesh_cam': pred_smpl_verts[0],
                'smplx_mesh_cam_target': gt_smpl_verts,
                'gt_ann_idx':gt_ann_idx,
                'gt_smpl_kp3d':gt_smpl_kp3d,
                'smplx_joint_proj': pred_smpl_kp2d[0],
                'image_idx': image_idx,
                'bb2img_trans': gt_bb2img_trans,
                'img_shape': gt_img_shape
            })
            # results.append({
            #     'scores': scores[i][pred_ind],
            #     'labels': labels[i][pred_ind],
            #     'boxes': boxes[i][pred_ind],
            #     'keypoints': keypoints_res[i][pred_ind],
            #     'pred_smpl_pose': smpl_pose[i][pred_ind],
            #     'pred_smpl_beta': tgt_smpl_beta[i][gt_ind],
            #     'pred_smpl_cam': smpl_cam[i][pred_ind],
            #     'pred_smpl_kp3d': smpl_kp3d[i][pred_ind],
            #     'gt_smpl_pose': tgt_smpl_pose[i][gt_ind],
            #     'gt_smpl_beta': tgt_smpl_beta[i][gt_ind],
            #     'gt_smpl_kp3d': tgt_smpl_kp3d[i][gt_ind],
            #     'gt_boxes': tgt_bbox[i][gt_ind],
            #     'gt_keypoints': tgt_keypoints[i][gt_ind],
            #     'image_idx': targets[i]['image_id'],
            #     }
            # )
        if self.nms_iou_threshold > 0:
            # NMS path is unimplemented; everything after the raise is dead.
            raise NotImplementedError
            item_indices = [nms(b, s, iou_threshold=self.nms_iou_threshold) for b,s in zip(boxes, scores)]
            # import pdb; pdb.set_trace()
            results = [{'scores': s[i], 'labels': l[i], 'boxes': b[i]} for s, l, b, i in zip(scores, labels, boxes, item_indices)]
        else:
            results = results
        return results
class PostProcess_SMPLX_Multi(nn.Module):
    """Convert multi-person SMPL-X model outputs into COCO-API-style results.

    Top-k predictions are gathered from the raw outputs, then matched to
    ground-truth instances with a Hungarian assignment on a 2D-keypoint cost,
    and finally packed (together with ground-truth meshes rebuilt from the
    gendered SMPL-X body models) into one result dict per batch.
    """
    def __init__(
        self,
        num_select=100,
        nms_iou_threshold=-1,
        num_body_points=17,
        body_model=dict(
            type='smplx',
            keypoint_src='smplx',
            num_expression_coeffs=10,
            num_betas=10,
            gender='neutral',
            keypoint_dst='smplx_137',
            model_path='data/body_models/smplx',
            use_pca=False,
            use_face_contour=True,
        ),
    ) -> None:
        """
        Args:
            num_select: number of top-scoring queries kept (note: forward()
                currently overrides this with 1).
            nms_iou_threshold: IoU threshold for NMS; <= 0 disables NMS
                (the NMS path is not implemented).
            num_body_points: number of 2D body keypoints per instance.
            body_model: config dict for building the SMPL-X body model.
        """
        super().__init__()
        self.num_select = num_select
        self.nms_iou_threshold = nms_iou_threshold
        self.num_body_points = num_body_points
        # BUGFIX: `body_model` is a mutable default argument and was mutated
        # in place below ('gender' set to 'male' then 'female'). The mutation
        # persisted on the shared default dict, so every later instantiation
        # built its "neutral" model with gender='female'. Work on a deep copy.
        body_model = copy.deepcopy(body_model)
        # Gender keys: -1 for neutral; 0 for male; 1 for female.
        gender_body_model = {}
        gender_body_model[-1] = build_body_model(body_model)
        body_model['gender'] = 'male'
        gender_body_model[0] = build_body_model(body_model)
        body_model['gender'] = 'female'
        gender_body_model[1] = build_body_model(body_model)
        self.body_model = gender_body_model

    def forward(self, outputs, target_sizes, targets, data_batch_nc, not_to_xyxy=False, test=False, dataset=None):
        """Post-process one batch of predictions.

        Args:
            outputs: dict of raw model tensors ('pred_logits', 'pred_boxes',
                'pred_keypoints', 'pred_smpl_*', ...), batch-first.
            target_sizes: (B, 2) tensor of (height, width) per image.
            targets: per-image target dicts (used for 'image_id' if present).
            data_batch_nc: non-collated ground-truth batch (lists of tensors).
            not_to_xyxy: if True, keep boxes in cxcywh instead of xyxy.
            test: if True, convert xyxy boxes to xywh.
            dataset: dataset object; UBody_MM gets confidence-weighted
                keypoint matching cost.

        Returns:
            list with a single dict of matched predictions and ground truths
            (numpy arrays).
        """
        batch_size = outputs['pred_keypoints'].shape[0]
        results = []
        device = outputs['pred_keypoints'].device
        for body_model in self.body_model.values():
            body_model.to(device)
        # NOTE: evaluation keeps only the single best query per image
        # (overrides self.num_select).
        num_select = 1
        out_logits, out_bbox, out_keypoints = \
            outputs['pred_logits'], outputs['pred_boxes'], \
            outputs['pred_keypoints']
        out_smpl_pose, out_smpl_beta, out_smpl_expr, out_smpl_cam, out_smpl_kp3d, out_smpl_verts = \
            outputs['pred_smpl_fullpose'], outputs['pred_smpl_beta'], outputs['pred_smpl_expr'], \
            outputs['pred_smpl_cam'], outputs['pred_smpl_kp3d'], outputs['pred_smpl_verts']
        # Project predicted 3D joints to 2D with the predicted weak-perspective
        # camera, per image (camera center = half the image size).
        out_smpl_kp2d = []
        for bs in range(batch_size):
            out_kp3d_i = out_smpl_kp3d[bs]
            out_cam_i = out_smpl_cam[bs]
            out_img_shape = data_batch_nc['img_shape'][bs].flip(-1)[None]
            out_kp2d_i = project_points_new(
                points_3d=out_kp3d_i,
                pred_cam=out_cam_i,
                focal_length=5000,
                camera_center=out_img_shape/2
            )
            out_smpl_kp2d.append(out_kp2d_i.detach().cpu().numpy())
        out_smpl_kp2d = torch.tensor(out_smpl_kp2d).to(device)
        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2
        prob = out_logits.sigmoid()
        topk_values, topk_indexes = \
            torch.topk(prob.view(out_logits.shape[0], -1), num_select, dim=1)
        scores = topk_values
        # top-k indexes run over (query, class); recover the query index and
        # the class label.
        topk_boxes = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        if not_to_xyxy:
            boxes = out_bbox
        else:
            boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        if test:
            assert not not_to_xyxy
            boxes[:, :, 2:] = boxes[:, :, 2:] - boxes[:, :, :2]
        boxes_norm = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
        target_sizes = target_sizes.type_as(boxes)
        # From relative [0, 1] to absolute pixel coordinates.
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes_norm * scale_fct[:, None, :]
        # Gather projected SMPL-X 2D keypoints (137 joints) for the kept queries.
        topk_smpl_kp2d = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        pred_smpl_kp2d = torch.gather(
            out_smpl_kp2d, 1,
            topk_smpl_kp2d.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 137, 2))
        # Gather regressed 2D keypoints (x,y interleaved, then visibility).
        topk_keypoints = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        keypoints = torch.gather(
            out_keypoints, 1,
            topk_keypoints.unsqueeze(-1).repeat(1, 1, self.num_body_points * 3))
        Z_pred = keypoints[:, :, :(self.num_body_points * 2)]
        V_pred = keypoints[:, :, (self.num_body_points * 2):]
        img_h, img_w = target_sizes.unbind(1)
        Z_pred = Z_pred * torch.stack([img_w, img_h], dim=1).repeat(1, self.num_body_points)[:, None, :]
        # Interleave to COCO (x, y, v) triplets.
        keypoints_res = torch.zeros_like(keypoints)
        keypoints_res[..., 0::3] = Z_pred[..., 0::2]
        keypoints_res[..., 1::3] = Z_pred[..., 1::2]
        keypoints_res[..., 2::3] = V_pred[..., 0::1]
        # Gather SMPL-X parameters for the kept queries.
        topk_smpl = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        smpl_pose = torch.gather(out_smpl_pose, 1, topk_smpl[:, :, None].repeat(1, 1, 159))
        smpl_beta = torch.gather(out_smpl_beta, 1, topk_smpl[:, :, None].repeat(1, 1, 10))
        smpl_expr = torch.gather(out_smpl_expr, 1, topk_smpl[:, :, None].repeat(1, 1, 10))
        smpl_cam = torch.gather(out_smpl_cam, 1, topk_smpl[:, :, None].repeat(1, 1, 3))
        smpl_kp3d = torch.gather(out_smpl_kp3d, 1, topk_smpl[:, :, None, None].repeat(1, 1, out_smpl_kp3d.shape[-2], 3))
        smpl_verts = torch.gather(out_smpl_verts, 1, topk_smpl[:, :, None, None].repeat(1, 1, out_smpl_verts.shape[-2], 3))
        # Ground-truth fields (lists of per-image tensors).
        tgt_smpl_kp3d = data_batch_nc['joint_cam']
        tgt_smpl_pose = data_batch_nc['smplx_pose']
        tgt_smpl_beta = data_batch_nc['smplx_shape']
        tgt_smpl_expr = data_batch_nc['smplx_expr']
        tgt_keypoints = data_batch_nc['joint_img']
        tgt_img_shape = data_batch_nc['img_shape']
        tgt_bb2img_trans = data_batch_nc['bb2img_trans']
        tgt_ann_idx = data_batch_nc['ann_idx']
        pred_indice_list = []
        gt_indice_list = []
        tgt_verts = []
        tgt_kp3d = []
        tgt_bbox = []
        # Per image: rebuild GT meshes, compute matching costs, and run the
        # Hungarian assignment between predictions and GT instances.
        for bbox_center, bbox_size, pose, \
            beta, expr, gender, gt_kp2d, _, pred_kp2d, pred_kp3d, boxe, scale \
            in zip(
                data_batch_nc['body_bbox_center'],
                data_batch_nc['body_bbox_size'],
                data_batch_nc['smplx_pose'],
                data_batch_nc['smplx_shape'],
                data_batch_nc['smplx_expr'],
                data_batch_nc['gender'],
                data_batch_nc['joint_img'],
                data_batch_nc['joint_cam'],
                pred_smpl_kp2d, smpl_kp3d, boxes, scale_fct,
            ):
            # Build SMPL-X vertices/joints for each GT instance with the
            # gender-matched body model.
            gt_verts = []
            gt_kp3d = []
            gt_bbox = []
            gender_ = gender.cpu().numpy()
            for i, g in enumerate(gender_):
                gt_out = self.body_model[g](
                    betas=beta[i].reshape(-1, 10),
                    global_orient=pose[i, :3].reshape(-1, 3).unsqueeze(1),
                    body_pose=pose[i, 3:66].reshape(-1, 21 * 3),
                    left_hand_pose=pose[i, 66:111].reshape(-1, 15 * 3),
                    right_hand_pose=pose[i, 111:156].reshape(-1, 15 * 3),
                    jaw_pose=pose[i, 156:159].reshape(-1, 3),
                    leye_pose=torch.zeros_like(pose[i, 156:159]),
                    reye_pose=torch.zeros_like(pose[i, 156:159]),
                    expression=expr[i].reshape(-1, 10),
                )
                gt_verts.append(gt_out['vertices'][0].detach().cpu().numpy())
                gt_kp3d.append(gt_out['joints'][0].detach().cpu().numpy())
            tgt_verts.append(gt_verts)
            tgt_kp3d.append(gt_kp3d)
            # GT bbox: center/size -> xywh, scale to pixels, then xywh -> xyxy.
            gt_bbox = torch.cat(
                [bbox_center - bbox_size / 2, bbox_size], dim=-1)
            gt_bbox = gt_bbox * scale
            gt_bbox[..., 2] = gt_bbox[..., 0] + gt_bbox[..., 2]
            gt_bbox[..., 3] = gt_bbox[..., 1] + gt_bbox[..., 3]
            tgt_bbox.append(gt_bbox[..., :4].float())
            pred_bbox = boxe.clone()
            # Alternative matching costs (only cost_keypoints is used below).
            cost_giou = -box_ops.generalized_box_iou(pred_bbox, gt_bbox)
            cost_bbox = torch.cdist(
                box_ops.box_xyxy_to_cxcywh(pred_bbox)/scale,
                box_ops.box_xyxy_to_cxcywh(gt_bbox)/scale, p=1)
            # GT 2D keypoints come in a 12x16-normalized grid space; rescale
            # to pixels. (assumes joint_img is grid-normalized — TODO confirm)
            gt_kp2d_conf = gt_kp2d[:, :, 2:3]
            gt_kp2d_ = (gt_kp2d[:, :, :2] * scale[:2]) / torch.tensor([12, 16]).to(device)
            gt_kp2d_body = gt_kp2d_[:, smpl_x.joint_part['body']]
            gt_kp2d_body_conf = gt_kp2d_conf[:, smpl_x.joint_part['body']]
            pred_kp2d_body = pred_kp2d[:, smpl_x.joint_part['body']]
            # UBody uses confidence-weighted keypoint L1; others use plain L1.
            if dataset.__class__.__name__ == 'UBody_MM':
                cost_keypoints = torch.abs(
                    (pred_kp2d_body[:, None]/scale[:2] - gt_kp2d_body[None]/scale[:2])*gt_kp2d_body_conf[None]
                ).sum([-2, -1])/gt_kp2d_body_conf[None].sum()
            else:
                cost_keypoints = torch.abs(
                    (pred_kp2d_body[:, None]/scale[:2] - gt_kp2d_body[None]/scale[:2])
                ).sum([-2, -1])
            # Root-relative 3D keypoint cost (computed but unused for matching).
            gt_kp3d_ = torch.tensor(np.array(gt_kp3d) - np.array(gt_kp3d)[:, [0]]).to(device)
            pred_kp3d_ = (pred_kp3d - pred_kp3d[:, [0]])
            cost_kp3d = torch.abs((pred_kp3d_[:, None] - gt_kp3d_[None])).sum([-2, -1])
            # Hungarian matching on the 2D keypoint cost.
            indice = linear_sum_assignment(cost_keypoints.cpu())
            pred_ind, gt_ind = indice
            pred_indice_list.append(pred_ind)
            gt_indice_list.append(gt_ind)
        # Collect matched predictions across the batch.
        pred_scores = torch.cat(
            [t[i] for t, i in zip(scores, pred_indice_list)]
        ).detach().cpu().numpy()
        pred_labels = torch.cat(
            [t[i] for t, i in zip(labels, pred_indice_list)]
        ).detach().cpu().numpy()
        pred_boxes = torch.cat(
            [t[i] for t, i in zip(boxes, pred_indice_list)]
        ).detach().cpu().numpy()
        pred_keypoints = torch.cat(
            [t[i] for t, i in zip(keypoints_res, pred_indice_list)]
        ).detach().cpu().numpy()
        # Re-project matched 3D joints per image with that image's size.
        pred_smpl_kp2d = []
        pred_smpl_kp3d = []
        pred_smpl_cam = []
        img_wh_list = []
        for i, img_wh in enumerate(tgt_img_shape):
            kp3d = smpl_kp3d[i][pred_indice_list[i]]
            cam = smpl_cam[i][pred_indice_list[i]]
            img_wh = img_wh.flip(-1)[None]
            kp2d = project_points_new(
                points_3d=kp3d,
                pred_cam=cam,
                focal_length=5000,
                camera_center=img_wh/2
            )
            num_instance = kp2d.shape[0]
            img_wh_list.append(img_wh.repeat(num_instance, 1).cpu().numpy())
            pred_smpl_kp2d.append(kp2d.detach().cpu().numpy())
            pred_smpl_kp3d.append(kp3d.detach().cpu().numpy())
            pred_smpl_cam.append(cam.detach().cpu().numpy())
        pred_smpl_pose = torch.cat(
            [t[i] for t, i in zip(smpl_pose, pred_indice_list)]
        ).detach().cpu().numpy()
        pred_smpl_beta = torch.cat(
            [t[i] for t, i in zip(smpl_beta, pred_indice_list)]
        ).detach().cpu().numpy()
        pred_smpl_expr = torch.cat(
            [t[i] for t, i in zip(smpl_expr, pred_indice_list)]
        ).detach().cpu().numpy()
        pred_smpl_verts = torch.cat(
            [t[i] for t, i in zip(smpl_verts, pred_indice_list)]
        ).detach().cpu().numpy()
        pred_smpl_kp2d = np.concatenate(pred_smpl_kp2d, 0)
        pred_smpl_kp3d = np.concatenate(pred_smpl_kp3d, 0)
        pred_smpl_cam = np.concatenate(pred_smpl_cam, 0)
        img_wh_list = np.concatenate(img_wh_list, 0)
        # Collect matched ground truths.
        gt_smpl_kp3d = torch.cat(tgt_smpl_kp3d).detach().cpu().numpy()
        gt_smpl_pose = torch.cat(tgt_smpl_pose).detach().cpu().numpy()
        gt_smpl_beta = torch.cat(tgt_smpl_beta).detach().cpu().numpy()
        gt_boxes = torch.cat(tgt_bbox).detach().cpu().numpy()
        gt_smpl_expr = torch.cat(tgt_smpl_expr).detach().cpu().numpy()
        gt_smpl_verts = np.concatenate(
            [np.array(t)[i] for t, i in zip(tgt_verts, gt_indice_list)], 0)
        gt_ann_idx = torch.cat([t.repeat(len(i)) for t, i in zip(tgt_ann_idx, gt_indice_list)], dim=0).cpu().numpy()
        gt_keypoints = torch.cat(tgt_keypoints).detach().cpu().numpy()
        gt_bb2img_trans = torch.stack(tgt_bb2img_trans).detach().cpu().numpy()
        # NOTE(review): `i` here is the stale index from the loop above, so
        # only the last image's id is read; image_idx is currently unused in
        # the result dict.
        if 'image_id' in targets[i]:
            image_idx = (targets[i]['image_id'].detach().cpu().numpy())
        # Split the 159-dim full pose into its SMPL-X components.
        smplx_root_pose = pred_smpl_pose[:, :3]
        smplx_body_pose = pred_smpl_pose[:, 3:66]
        smplx_lhand_pose = pred_smpl_pose[:, 66:111]
        smplx_rhand_pose = pred_smpl_pose[:, 111:156]
        smplx_jaw_pose = pred_smpl_pose[:, 156:]
        results.append({
            'scores': pred_scores,
            'labels': pred_labels,
            'boxes': pred_boxes,
            'keypoints': pred_keypoints,
            'smplx_root_pose': smplx_root_pose,
            'smplx_body_pose': smplx_body_pose,
            'smplx_lhand_pose': smplx_lhand_pose,
            'smplx_rhand_pose': smplx_rhand_pose,
            'smplx_jaw_pose': smplx_jaw_pose,
            'smplx_shape': pred_smpl_beta,
            'smplx_expr': pred_smpl_expr,
            'cam_trans': pred_smpl_cam,
            'smplx_mesh_cam': pred_smpl_verts,
            'smplx_mesh_cam_target': gt_smpl_verts,
            'gt_smpl_kp3d': gt_smpl_kp3d,
            'smplx_joint_proj': pred_smpl_kp2d,
            "img": data_batch_nc['img'].cpu().numpy(),
            'bb2img_trans': gt_bb2img_trans,
            'img_shape': img_wh_list,
            'gt_ann_idx': gt_ann_idx
        })
        if self.nms_iou_threshold > 0:
            # NMS-based selection was never implemented for this post-processor.
            raise NotImplementedError
        return results
class PostProcess_SMPLX_Multi_Infer(nn.Module):
    """Convert multi-person SMPL-X model outputs into COCO-API-style results
    for inference (no ground-truth matching).

    Keeps the top-k scoring queries per image and packs their boxes (body,
    hands, face), keypoints, SMPL-X parameters, vertices, and camera
    translation into one result dict per image.
    """
    def __init__(
        self,
        num_select=100,
        nms_iou_threshold=-1,
        num_body_points=17,
        body_model=dict(
            type='smplx',
            keypoint_src='smplx',
            num_expression_coeffs=10,
            num_betas=10,
            gender='neutral',
            keypoint_dst='smplx_137',
            model_path='data/body_models/smplx',
            use_pca=False,
            use_face_contour=True)
    ) -> None:
        """
        Args:
            num_select: number of top-scoring queries kept per image.
            nms_iou_threshold: IoU threshold for NMS; <= 0 disables NMS
                (the NMS path is not implemented).
            num_body_points: number of 2D body keypoints per instance.
            body_model: config dict for building the SMPL-X body model.
        """
        super().__init__()
        self.num_select = num_select
        self.nms_iou_threshold = nms_iou_threshold
        self.num_body_points = num_body_points
        # BUGFIX: `body_model` is a mutable default argument and was mutated
        # in place below ('gender' set to 'male' then 'female'), leaking the
        # last gender into every later instantiation's "neutral" model.
        # Work on a deep copy instead.
        body_model = copy.deepcopy(body_model)
        # Gender keys: -1 for neutral; 0 for male; 1 for female.
        gender_body_model = {}
        gender_body_model[-1] = build_body_model(body_model)
        body_model['gender'] = 'male'
        gender_body_model[0] = build_body_model(body_model)
        body_model['gender'] = 'female'
        gender_body_model[1] = build_body_model(body_model)
        self.body_model = gender_body_model

    def forward(self, outputs, target_sizes, targets, data_batch_nc, image_shape=None, not_to_xyxy=False, test=False):
        """Post-process one batch of predictions for inference.

        Args:
            outputs: dict of raw model tensors ('pred_logits', 'pred_boxes',
                'pred_lhand_boxes', 'pred_rhand_boxes', 'pred_face_boxes',
                'pred_keypoints', 'pred_smpl_*', ...), batch-first.
            target_sizes: (B, 2) tensor of (height, width) of the input images.
            targets: unused here; kept for interface compatibility.
            data_batch_nc: non-collated batch with 'img', 'img_shape',
                'bb2img_trans', 'img2bb_trans' and optionally 'ann_idx'.
            image_shape: unused here; kept for interface compatibility.
            not_to_xyxy: if True, keep boxes in cxcywh instead of xyxy.
            test: unused here; kept for interface compatibility.

        Returns:
            list with one result dict per image.
        """
        batch_size = outputs['pred_keypoints'].shape[0]
        results = []
        device = outputs['pred_keypoints'].device
        pred_kp_coco = outputs['pred_keypoints']
        num_select = self.num_select
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
        out_body_bbox, out_lhand_bbox, out_rhand_bbox, out_face_bbox = \
            outputs['pred_boxes'], outputs['pred_lhand_boxes'], \
            outputs['pred_rhand_boxes'], outputs['pred_face_boxes']
        out_smpl_pose, out_smpl_beta, out_smpl_expr, out_smpl_cam, out_smpl_kp3d, out_smpl_verts = \
            outputs['pred_smpl_fullpose'], outputs['pred_smpl_beta'], outputs['pred_smpl_expr'], \
            outputs['pred_smpl_cam'], outputs['pred_smpl_kp3d'], outputs['pred_smpl_verts']
        # Project predicted 3D joints to 2D with the predicted weak-perspective
        # camera, per image (camera center = half the image size).
        out_smpl_kp2d = []
        for bs in range(batch_size):
            out_kp3d_i = out_smpl_kp3d[bs]
            out_cam_i = out_smpl_cam[bs]
            out_img_shape = data_batch_nc['img_shape'][bs].flip(-1)[None]
            out_kp2d_i = project_points_new(
                points_3d=out_kp3d_i,
                pred_cam=out_cam_i,
                focal_length=5000,
                camera_center=out_img_shape/2
            )
            out_smpl_kp2d.append(out_kp2d_i.detach().cpu().numpy())
        out_smpl_kp2d = torch.tensor(out_smpl_kp2d).to(device)
        prob = out_logits.sigmoid()
        topk_values, topk_indexes = \
            torch.topk(prob.view(out_logits.shape[0], -1), num_select, dim=1)
        scores = topk_values
        # top-k indexes run over (query, class); recover the query index and
        # the class label.
        topk_boxes = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        if not_to_xyxy:
            boxes = out_bbox
        else:
            boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        out_body_bbox = box_ops.box_cxcywh_to_xyxy(out_body_bbox)
        out_lhand_bbox = box_ops.box_cxcywh_to_xyxy(out_lhand_bbox)
        out_rhand_bbox = box_ops.box_cxcywh_to_xyxy(out_rhand_bbox)
        out_face_bbox = box_ops.box_cxcywh_to_xyxy(out_face_bbox)
        # Gather the kept queries' boxes and scale from relative [0, 1] to
        # absolute pixel coordinates.
        target_sizes = target_sizes.type_as(boxes)
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes_norm = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
        boxes = boxes_norm * scale_fct[:, None, :]
        body_bbox_norm = torch.gather(out_body_bbox, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
        body_boxes = body_bbox_norm * scale_fct[:, None, :]
        lhand_bbox_norm = torch.gather(out_lhand_bbox, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
        lhand_boxes = lhand_bbox_norm * scale_fct[:, None, :]
        rhand_bbox_norm = torch.gather(out_rhand_bbox, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
        rhand_boxes = rhand_bbox_norm * scale_fct[:, None, :]
        face_bbox_norm = torch.gather(out_face_bbox, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
        face_boxes = face_bbox_norm * scale_fct[:, None, :]
        # Gather projected SMPL-X 2D keypoints (144 joints) for the kept queries.
        topk_smpl_kp2d = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        pred_smpl_kp2d = torch.gather(
            out_smpl_kp2d, 1,
            topk_smpl_kp2d.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 144, 2))
        # COCO-style 17-keypoint (x, y) predictions, reshaped from the flat layout.
        pred_kp_coco = pred_kp_coco[..., 0:17*2].reshape(pred_kp_coco.shape[0], pred_kp_coco.shape[1], 17, 2)
        # Gather SMPL-X parameters for the kept queries.
        topk_smpl = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        smpl_pose = torch.gather(out_smpl_pose, 1, topk_smpl[:, :, None].repeat(1, 1, 159))
        smpl_beta = torch.gather(out_smpl_beta, 1, topk_smpl[:, :, None].repeat(1, 1, 10))
        smpl_expr = torch.gather(out_smpl_expr, 1, topk_smpl[:, :, None].repeat(1, 1, 10))
        smpl_cam = torch.gather(out_smpl_cam, 1, topk_smpl[:, :, None].repeat(1, 1, 3))
        smpl_kp3d = torch.gather(out_smpl_kp3d, 1, topk_smpl[:, :, None, None].repeat(1, 1, out_smpl_kp3d.shape[-2], 3))
        smpl_verts = torch.gather(out_smpl_verts, 1, topk_smpl[:, :, None, None].repeat(1, 1, out_smpl_verts.shape[-2], 3))
        # Convert weak-perspective camera (s, tx, ty) to a camera translation
        # (dx, dy, depth); epsilon guards against division by zero scale.
        (s, tx, ty) = (smpl_cam[..., 0] + 1e-9), smpl_cam[..., 1], smpl_cam[..., 2]
        depth, dx, dy = 1./s, tx/s, ty/s
        transl = torch.stack([dx, dy, depth], -1)
        # Split the 159-dim full pose into its SMPL-X components.
        smplx_root_pose = smpl_pose[:, :, :3]
        smplx_body_pose = smpl_pose[:, :, 3:66]
        smplx_lhand_pose = smpl_pose[:, :, 66:111]
        smplx_rhand_pose = smpl_pose[:, :, 111:156]
        smplx_jaw_pose = smpl_pose[:, :, 156:]
        if 'ann_idx' in data_batch_nc:
            image_idx = [target.cpu().numpy()[0] for target in data_batch_nc['ann_idx']]
        else:
            # BUGFIX: image_idx was read unconditionally below, raising a
            # NameError when 'ann_idx' was absent from the batch.
            image_idx = [None] * batch_size
        for bs in range(batch_size):
            results.append({
                'scores': scores[bs],
                'labels': labels[bs],
                'keypoints_coco': pred_kp_coco[bs],
                'smpl_kp3d': smpl_kp3d[bs],
                'smplx_root_pose': smplx_root_pose[bs],
                'smplx_body_pose': smplx_body_pose[bs],
                'smplx_lhand_pose': smplx_lhand_pose[bs],
                'smplx_rhand_pose': smplx_rhand_pose[bs],
                'smplx_jaw_pose': smplx_jaw_pose[bs],
                'smplx_shape': smpl_beta[bs],
                'smplx_expr': smpl_expr[bs],
                'smplx_joint_proj': pred_smpl_kp2d[bs],
                'smpl_verts': smpl_verts[bs],
                'image_idx': image_idx[bs],
                'cam_trans': transl[bs],
                'body_bbox': body_boxes[bs],
                'lhand_bbox': lhand_boxes[bs],
                'rhand_bbox': rhand_boxes[bs],
                'face_bbox': face_boxes[bs],
                'bb2img_trans': data_batch_nc['bb2img_trans'][bs],
                'img2bb_trans': data_batch_nc['img2bb_trans'][bs],
                'img': data_batch_nc['img'][bs],
                'img_shape': data_batch_nc['img_shape'][bs]
            })
        if self.nms_iou_threshold > 0:
            # NMS-based selection was never implemented for this post-processor.
            raise NotImplementedError
        return results
| class PostProcess_SMPLX_Multi_Box(nn.Module): | |
| """ This module converts the model's output into the format expected by the coco api""" | |
| def __init__( | |
| self, | |
| num_select=100, | |
| nms_iou_threshold=-1, | |
| num_body_points=17, | |
| body_model= dict( | |
| type='smplx', | |
| keypoint_src='smplx', | |
| num_expression_coeffs=10, | |
| num_betas=10, | |
| gender='neutral', | |
| keypoint_dst='smplx_137', | |
| model_path='data/body_models/smplx', | |
| use_pca=False, | |
| use_face_contour=True) | |
| ) -> None: | |
| super().__init__() | |
| self.num_select = num_select | |
| self.nms_iou_threshold = nms_iou_threshold | |
| self.num_body_points=num_body_points | |
| # -1 for neutral; 0 for male; 1 for femal | |
| gender_body_model = {} | |
| gender_body_model[-1] = build_body_model(body_model) | |
| body_model['gender']='male' | |
| gender_body_model[0] = build_body_model(body_model) | |
| body_model['gender']='female' | |
| gender_body_model[1] = build_body_model(body_model) | |
| self.body_model = gender_body_model | |
| def forward(self, outputs, target_sizes, targets, data_batch_nc, not_to_xyxy=False, test=False): | |
| # import pdb; pdb.set_trace() | |
| batch_size = outputs['pred_smpl_beta'].shape[0] | |
| results = [] | |
| device = outputs['pred_smpl_beta'].device | |
| for body_model in self.body_model.values(): | |
| body_model.to(device) | |
| # test with instance num | |
| # num_select=data_batch_nc['joint_img'][0].shape[0] | |
| num_select = self.num_select | |
| out_logits, out_bbox= outputs['pred_logits'], outputs['pred_boxes'] | |
| out_smpl_pose, out_smpl_beta, out_smpl_expr, out_smpl_cam, out_smpl_kp3d, out_smpl_verts = \ | |
| outputs['pred_smpl_fullpose'], outputs['pred_smpl_beta'], outputs['pred_smpl_expr'], \ | |
| outputs['pred_smpl_cam'], outputs['pred_smpl_kp3d'], outputs['pred_smpl_verts'] | |
| out_smpl_kp2d = [] | |
| for bs in range(batch_size): | |
| out_kp3d_i = out_smpl_kp3d[bs] | |
| out_cam_i = out_smpl_cam[bs] | |
| out_img_shape = data_batch_nc['img_shape'][bs].flip(-1)[None] | |
| # out_kp3d_i = out_kp3d_i - out_kp3d_i[:, [0]] | |
| out_kp2d_i = project_points_new( | |
| points_3d=out_kp3d_i, | |
| pred_cam=out_cam_i, | |
| focal_length=5000, | |
| camera_center=out_img_shape/2 | |
| ) | |
| out_smpl_kp2d.append(out_kp2d_i.detach().cpu().numpy()) | |
| out_smpl_kp2d = torch.tensor(out_smpl_kp2d).to(device) | |
| assert len(out_logits) == len(target_sizes) | |
| assert target_sizes.shape[1] == 2 | |
| prob = out_logits.sigmoid() | |
| topk_values, topk_indexes = \ | |
| torch.topk(prob.view(out_logits.shape[0], -1), num_select, dim=1) | |
| scores = topk_values | |
| # bbox | |
| topk_boxes = topk_indexes // out_logits.shape[2] | |
| labels = topk_indexes % out_logits.shape[2] | |
| if not_to_xyxy: | |
| boxes = out_bbox | |
| else: | |
| boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) | |
| if test: | |
| assert not not_to_xyxy | |
| boxes[:,:,2:] = boxes[:,:,2:] - boxes[:,:,:2] | |
| # gather gt bbox | |
| boxes_norm = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4)) | |
| target_sizes = target_sizes.type_as(boxes) | |
| # from relative [0, 1] to absolute [0, height] coordinates | |
| img_h, img_w = target_sizes.unbind(1) | |
| scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) | |
| boxes = boxes_norm * scale_fct[:, None, :] | |
| # smplx kp2d | |
| topk_smpl_kp2d = topk_indexes // out_logits.shape[2] | |
| labels = topk_indexes % out_logits.shape[2] | |
| pred_smpl_kp2d = torch.gather( | |
| out_smpl_kp2d, 1, | |
| topk_smpl_kp2d.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 137, 2)) | |
| # smpl out_smpl_pose, out_smpl_beta, out_smpl_cam, out_smpl_kp3d | |
| topk_smpl = topk_indexes // out_logits.shape[2] | |
| labels = topk_indexes % out_logits.shape[2] | |
| smpl_pose = torch.gather(out_smpl_pose, 1, topk_smpl[:,:,None].repeat(1, 1, 159)) | |
| smpl_beta = torch.gather(out_smpl_beta, 1, topk_smpl[:,:,None].repeat(1, 1, 10)) | |
| smpl_expr = torch.gather(out_smpl_expr, 1, topk_smpl[:,:,None].repeat(1, 1, 10)) | |
| smpl_cam = torch.gather(out_smpl_cam, 1, topk_smpl[:,:,None].repeat(1, 1, 3)) | |
| smpl_kp3d = torch.gather(out_smpl_kp3d, 1, topk_smpl[:,:,None, None].repeat(1, 1, out_smpl_kp3d.shape[-2],3)) | |
| smpl_verts = torch.gather(out_smpl_verts, 1, topk_smpl[:,:,None, None].repeat(1, 1, out_smpl_verts.shape[-2],3)) | |
| tgt_smpl_kp3d = data_batch_nc['joint_cam'] | |
| tgt_smpl_pose = data_batch_nc['smplx_pose'] | |
| tgt_smpl_beta = data_batch_nc['smplx_shape'] | |
| tgt_smpl_expr = data_batch_nc['smplx_expr'] | |
| tgt_keypoints = data_batch_nc['joint_img'] | |
| tgt_img_shape = data_batch_nc['img_shape'] | |
| tgt_bb2img_trans = data_batch_nc['bb2img_trans'] | |
| tgt_ann_idx = data_batch_nc['ann_idx'] | |
| pred_indice_list = [] | |
| gt_indice_list = [] | |
| tgt_verts = [] | |
| tgt_kp3d = [] | |
| tgt_bbox = [] | |
| for bbox_center, bbox_size, pose, \ | |
| beta, expr, gender, gt_kp2d, _, pred_kp2d, pred_kp3d, boxe, scale \ | |
| in zip( | |
| data_batch_nc['body_bbox_center'], | |
| data_batch_nc['body_bbox_size'], | |
| data_batch_nc['smplx_pose'], | |
| data_batch_nc['smplx_shape'], | |
| data_batch_nc['smplx_expr'], | |
| data_batch_nc['gender'], | |
| data_batch_nc['joint_img'], | |
| data_batch_nc['joint_cam'], | |
| pred_smpl_kp2d, smpl_kp3d, boxes, scale_fct, | |
| ): | |
| # build smplx verts | |
| gt_verts = [] | |
| gt_kp3d = [] | |
| gt_bbox = [] | |
| gender_ = gender.cpu().numpy() | |
| for i, g in enumerate(gender_): | |
| gt_out = self.body_model[g]( | |
| betas=beta[i].reshape(-1, 10), | |
| global_orient=pose[i, :3].reshape(-1, 3).unsqueeze(1), | |
| body_pose=pose[i, 3:66].reshape(-1, 21 * 3), | |
| left_hand_pose=pose[i, 66:111].reshape(-1, 15 * 3), | |
| right_hand_pose=pose[i, 111:156].reshape(-1, 15 * 3), | |
| jaw_pose=pose[i, 156:159].reshape(-1, 3), | |
| leye_pose=torch.zeros_like(pose[i, 156:159]), | |
| reye_pose=torch.zeros_like(pose[i, 156:159]), | |
| expression=expr[i].reshape(-1, 10), | |
| ) | |
| gt_verts.append(gt_out['vertices'][0].detach().cpu().numpy()) | |
| gt_kp3d.append(gt_out['joints'][0].detach().cpu().numpy()) | |
| tgt_verts.append(gt_verts) | |
| tgt_kp3d.append(gt_kp3d) | |
| # bbox | |
| gt_bbox = torch.cat( | |
| [bbox_center - bbox_size / 2, bbox_size ], dim=-1) | |
| gt_bbox = gt_bbox * scale | |
| # xywh2xyxy | |
| gt_bbox[..., 2] = gt_bbox[..., 0] + gt_bbox[..., 2] | |
| gt_bbox[..., 3] = gt_bbox[..., 1] + gt_bbox[..., 3] | |
| tgt_bbox.append(gt_bbox[..., :4].float()) | |
| pred_bbox = boxe.clone() | |
| # box_iou = box_ops.box_iou(pred_bbox,gt_bbox)[0] | |
| cost_giou = -box_ops.generalized_box_iou(pred_bbox, gt_bbox) | |
| cost_bbox = torch.cdist( | |
| box_ops.box_xyxy_to_cxcywh(pred_bbox)/scale, | |
| box_ops.box_xyxy_to_cxcywh(gt_bbox)/scale, p=1) | |
| # smpl kp2d | |
| gt_kp2d_conf = gt_kp2d[:,:,2:3] | |
| gt_kp2d_ = (gt_kp2d[:, :, :2] * scale[:2]) /torch.tensor([12, 16]).to(device) | |
| # gt_kp2d_conf, _ = convert_kps(gt_kp2d_conf,'smplx_137', 'coco', approximate=True) | |
| # cost_keypoints = torch.abs( | |
| # (pred_kp2d_body[:, None]/scale[:2] - gt_kp2d_body[None]/scale[:2])*gt_kp2d_conf[None] | |
| # ).sum([-2,-1])/gt_kp2d_conf[None].sum() | |
| gt_kp2d_body, _ = convert_kps(gt_kp2d_,'smplx_137', 'coco', approximate=True) | |
| pred_kp2d_body, _ = convert_kps(pred_kp2d,'smplx_137', 'coco', approximate=True) | |
| cost_keypoints = torch.abs( | |
| (pred_kp2d_body[:, None]/scale[:2] - gt_kp2d_body[None]/scale[:2]) | |
| ).sum([-2,-1]) | |
| # cost_keypoints = torch.abs( | |
| # (pred_kp2d_body[:, None]/scale[:2] - gt_kp2d_body[None]/scale[:2])*gt_kp2d_body_conf[None] | |
| # ).sum([-2,-1])/gt_kp2d_body_conf[None].sum() | |
| # coco kp2d | |
| # gt_kp2d_conf, _ = convert_kps(gt_kp2d_conf,'smplx_137', 'coco', approximate=True) | |
| # keypoints_coco = Z_pred.reshape(num_select, 17,2) | |
| # ubody | |
| # cost_keypoints_coco = torch.abs( | |
| # (keypoints_coco[:, None]/scale[:2] - gt_kp2d_body[None]/scale[:2])*gt_kp2d_conf[None] | |
| # ).sum([-2,-1])/gt_kp2d_conf[None].sum() | |
| # others | |
| # cost_keypoints_coco = torch.abs( | |
| # (keypoints_coco[:, None]/scale[:2] - gt_kp2d_body[None]/scale[:2]) | |
| # ).sum([-2,-1]) | |
| # smpl kp3d | |
| gt_kp3d_ = torch.tensor(np.array(gt_kp3d) - np.array(gt_kp3d)[:, [0]]).to(device) | |
| pred_kp3d_ = (pred_kp3d - pred_kp3d[:, [0]]) | |
| cost_kp3d = torch.abs((pred_kp3d_[:, None] - gt_kp3d_[None])).sum([-2,-1]) | |
| # 1. kps | |
| indice = linear_sum_assignment(cost_keypoints.cpu()) | |
| # 2. bbox giou | |
| # indice = linear_sum_assignment(cost_giou.cpu()) | |
| # 3. bbox | |
| # indice = linear_sum_assignment(cost_bbox.cpu()) | |
| # 4. all | |
| # indice = linear_sum_assignment( | |
| # 10* (cost_keypoints).cpu() + 5 * cost_bbox.cpu()) | |
| # 5. kp3d | |
| # indice = linear_sum_assignment(cost_kp3d.cpu()) | |
| # 5. kp2d coco | |
| # indice = linear_sum_assignment(cost_keypoints_coco.cpu()) | |
| pred_ind, gt_ind = indice | |
| pred_indice_list.append(pred_ind) | |
| gt_indice_list.append(gt_ind) | |
| pred_scores = torch.cat( | |
| [t[i] for t, i in zip(scores, pred_indice_list)] | |
| ).detach().cpu().numpy() | |
| pred_labels = torch.cat( | |
| [t[i] for t, i in zip(labels, pred_indice_list)] | |
| ).detach().cpu().numpy() | |
| pred_boxes = torch.cat( | |
| [t[i] for t, i in zip(boxes, pred_indice_list)] | |
| ).detach().cpu().numpy() | |
| # pred_keypoints = torch.cat( | |
| # [t[i] for t, i in zip(keypoints_res, pred_indice_list)] | |
| # ).detach().cpu().numpy() | |
| pred_smpl_kp2d = [] | |
| pred_smpl_kp3d = [] | |
| pred_smpl_cam = [] | |
| img_wh_list = [] | |
| for i, img_wh in enumerate(tgt_img_shape): | |
| kp3d = smpl_kp3d[i][pred_indice_list[i]] | |
| cam = smpl_cam[i][pred_indice_list[i]] | |
| img_wh = img_wh.flip(-1)[None] | |
| kp2d = project_points_new( | |
| points_3d=kp3d, | |
| pred_cam=cam, | |
| focal_length=5000, | |
| camera_center=img_wh/2 | |
| ) | |
| num_instance = kp2d.shape[0] | |
| img_wh_list.append(img_wh.repeat(num_instance,1).cpu().numpy()) | |
| pred_smpl_kp2d.append(kp2d.detach().cpu().numpy()) | |
| pred_smpl_kp3d.append(kp3d.detach().cpu().numpy()) | |
| pred_smpl_cam.append(cam.detach().cpu().numpy()) | |
| # pred_smpl_cam = torch.cat( | |
| # [t[i] for t, i in zip(smpl_cam, pred_indice_list)] | |
| # ).detach().cpu().numpy() | |
| # pred_smpl_kp3d = torch.cat( | |
| # [t[i] for t, i in zip(smpl_kp3d, pred_indice_list)] | |
| # ) | |
| pred_smpl_pose = torch.cat( | |
| [t[i] for t, i in zip(smpl_pose, pred_indice_list)] | |
| ).detach().cpu().numpy() | |
| pred_smpl_beta = torch.cat( | |
| [t[i] for t, i in zip(smpl_beta, pred_indice_list)] | |
| ).detach().cpu().numpy() | |
| pred_smpl_expr = torch.cat( | |
| [t[i] for t, i in zip(smpl_expr, pred_indice_list)] | |
| ).detach().cpu().numpy() | |
| pred_smpl_verts = torch.cat( | |
| [t[i] for t, i in zip(smpl_verts, pred_indice_list)] | |
| ).detach().cpu().numpy() | |
| pred_smpl_kp2d = np.concatenate(pred_smpl_kp2d, 0) | |
| pred_smpl_kp3d = np.concatenate(pred_smpl_kp3d, 0) | |
| pred_smpl_cam = np.concatenate(pred_smpl_cam, 0) | |
| img_wh_list = np.concatenate(img_wh_list, 0) | |
| gt_smpl_kp3d = torch.cat(tgt_smpl_kp3d).detach().cpu().numpy() | |
| gt_smpl_pose = torch.cat(tgt_smpl_pose).detach().cpu().numpy() | |
| gt_smpl_beta = torch.cat(tgt_smpl_beta).detach().cpu().numpy() | |
| gt_boxes = torch.cat(tgt_bbox).detach().cpu().numpy() | |
| gt_smpl_expr = torch.cat(tgt_smpl_expr).detach().cpu().numpy() | |
| # gt_img_shape = torch.cat(tgt_img_shape).detach().cpu().numpy() | |
| gt_smpl_verts = np.concatenate( | |
| [np.array(t)[i] for t, i in zip(tgt_verts, gt_indice_list)], 0) | |
| gt_ann_idx = torch.cat([t.repeat(len(i)) for t, i in zip(tgt_ann_idx, gt_indice_list)],dim=0).cpu().numpy() | |
| gt_keypoints = torch.cat(tgt_keypoints).detach().cpu().numpy() | |
| # gt_img_shape = tgt_img_shape.detach().cpu().numpy() | |
| gt_bb2img_trans = torch.stack(tgt_bb2img_trans).detach().cpu().numpy() | |
| if 'image_id' in targets[i]: | |
| image_idx=(targets[i]['image_id'].detach().cpu().numpy()) | |
| smplx_root_pose = pred_smpl_pose[:,:3] | |
| smplx_body_pose = pred_smpl_pose[:,3:66] | |
| smplx_lhand_pose = pred_smpl_pose[:,66:111] | |
| smplx_rhand_pose = pred_smpl_pose[:,111:156] | |
| smplx_jaw_pose = pred_smpl_pose[:,156:] | |
| results.append({ | |
| 'scores': pred_scores, | |
| 'labels': pred_labels, | |
| 'boxes': pred_boxes, | |
| # 'keypoints': pred_keypoints, | |
| 'smplx_root_pose': smplx_root_pose, | |
| 'smplx_body_pose': smplx_body_pose, | |
| 'smplx_lhand_pose': smplx_lhand_pose, | |
| 'smplx_rhand_pose': smplx_rhand_pose, | |
| 'smplx_jaw_pose': smplx_jaw_pose, | |
| 'smplx_shape': pred_smpl_beta, | |
| 'smplx_expr': pred_smpl_expr, | |
| 'cam_trans': pred_smpl_cam, | |
| 'smplx_mesh_cam': pred_smpl_verts, | |
| 'smplx_mesh_cam_target': gt_smpl_verts, | |
| 'gt_smpl_kp3d':gt_smpl_kp3d, | |
| 'smplx_joint_proj': pred_smpl_kp2d, | |
| # 'image_idx': image_idx, | |
| "img": data_batch_nc['img'].cpu().numpy(), | |
| 'bb2img_trans': gt_bb2img_trans, | |
| 'img_shape': img_wh_list, | |
| 'gt_ann_idx': gt_ann_idx | |
| }) | |
| if self.nms_iou_threshold > 0: | |
| raise NotImplementedError | |
| item_indices = [nms(b, s, iou_threshold=self.nms_iou_threshold) for b,s in zip(boxes, scores)] | |
| # import pdb; pdb.set_trace() | |
| results = [{'scores': s[i], 'labels': l[i], 'boxes': b[i]} for s, l, b, i in zip(scores, labels, boxes, item_indices)] | |
| else: | |
| results = results | |
| return results | |
class PostProcess_SMPLX_Multi_Infer_Box(nn.Module):
    """Convert multi-person SMPL-X model outputs into per-image result dicts.

    Inference-time variant that, in addition to the SMPL-X parameters, also
    gathers the predicted body / left-hand / right-hand / face boxes for the
    top-``num_select`` queries of each image.
    """

    def __init__(
        self,
        num_select=100,
        nms_iou_threshold=-1,
        num_body_points=17,
        body_model=dict(
            type='smplx',
            keypoint_src='smplx',
            num_expression_coeffs=10,
            num_betas=10,
            gender='neutral',
            keypoint_dst='smplx_137',
            model_path='data/body_models/smplx',
            use_pca=False,
            use_face_contour=True)
    ) -> None:
        """Build one body model per gender.

        Args:
            num_select: number of top-scoring queries kept per image.
            nms_iou_threshold: >0 would enable NMS filtering (not implemented).
            num_body_points: number of 2D body keypoints (kept for interface
                compatibility; not used in :meth:`forward`).
            body_model: config dict passed to ``build_body_model``.
        """
        super().__init__()
        self.num_select = num_select
        self.nms_iou_threshold = nms_iou_threshold
        self.num_body_points = num_body_points
        # BUGFIX: the original mutated the (shared, mutable-default) config
        # dict in place, so a later instantiation using the default would
        # start from gender='female'. Work on a private copy instead.
        body_model = copy.deepcopy(body_model)
        # Key convention: -1 for neutral; 0 for male; 1 for female.
        gender_body_model = {}
        gender_body_model[-1] = build_body_model(body_model)
        body_model['gender'] = 'male'
        gender_body_model[0] = build_body_model(body_model)
        body_model['gender'] = 'female'
        gender_body_model[1] = build_body_model(body_model)
        self.body_model = gender_body_model

    def forward(self, outputs, target_sizes, targets, data_batch_nc,
                image_shape=None, not_to_xyxy=False, test=False):
        """Select the top queries per image and pack them into result dicts.

        Args:
            outputs: dict of batched model predictions (``pred_logits``,
                ``pred_boxes``, per-part boxes and ``pred_smpl_*`` tensors).
            target_sizes: per-image (height, width) used to scale the
                normalized boxes to pixels (established by ``unbind(1)``
                below yielding ``img_h, img_w``).
            targets: unchanged pass-through (kept for interface parity;
                unused here).
            data_batch_nc: batch metadata; ``img_shape``, ``bb2img_trans``,
                ``img2bb_trans``, ``img`` and optionally ``ann_idx`` are read.
            image_shape: unused (kept for backward compatibility).
            not_to_xyxy: if True, keep the main boxes in cxcywh format.
            test: unused (kept for backward compatibility).

        Returns:
            list of one dict per image with scores, labels, SMPL-X
            parameters, projected joints, vertices and per-part boxes.
        """
        num_select = self.num_select
        batch_size = outputs['pred_smpl_beta'].shape[0]
        device = outputs['pred_smpl_beta'].device
        results = []

        out_logits = outputs['pred_logits']
        out_bbox = outputs['pred_boxes']
        out_body_bbox = outputs['pred_boxes']
        out_lhand_bbox = outputs['pred_lhand_boxes']
        out_rhand_bbox = outputs['pred_rhand_boxes']
        out_face_bbox = outputs['pred_face_boxes']
        out_smpl_pose = outputs['pred_smpl_fullpose']
        out_smpl_beta = outputs['pred_smpl_beta']
        out_smpl_expr = outputs['pred_smpl_expr']
        out_smpl_cam = outputs['pred_smpl_cam']
        out_smpl_kp3d = outputs['pred_smpl_kp3d']
        out_smpl_verts = outputs['pred_smpl_verts']

        # Project predicted 3D joints to 2D, per image, with a fixed-focal
        # perspective camera centred on the image.
        kp2d_per_image = []
        for bs in range(batch_size):
            # img_shape is (h, w); flip to (w, h) for the camera centre.
            img_wh = data_batch_nc['img_shape'][bs].flip(-1)[None]
            kp2d = project_points_new(
                points_3d=out_smpl_kp3d[bs],
                pred_cam=out_smpl_cam[bs],
                focal_length=5000,
                camera_center=img_wh / 2)
            kp2d_per_image.append(kp2d.detach())
        # Stay on-device; the original round-tripped through numpy/CPU.
        out_smpl_kp2d = torch.stack(kp2d_per_image, 0).to(device)

        # Top-k over the flattened (query, class) score matrix.
        prob = out_logits.sigmoid()
        topk_values, topk_indexes = torch.topk(
            prob.view(out_logits.shape[0], -1), num_select, dim=1)
        scores = topk_values
        # Flattened index -> (query index, class label). Computed once
        # (the original recomputed the same quotient/remainder three times).
        topk_queries = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]

        if not_to_xyxy:
            boxes = out_bbox
        else:
            boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        # Part boxes are always converted, matching the original behaviour.
        out_body_bbox = box_ops.box_cxcywh_to_xyxy(out_body_bbox)
        out_lhand_bbox = box_ops.box_cxcywh_to_xyxy(out_lhand_bbox)
        out_rhand_bbox = box_ops.box_cxcywh_to_xyxy(out_rhand_bbox)
        out_face_bbox = box_ops.box_cxcywh_to_xyxy(out_face_bbox)

        # From relative [0, 1] to absolute pixel coordinates.
        target_sizes = target_sizes.type_as(boxes)
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)

        def _select_and_scale(bb):
            """Keep the top-k queries of a box tensor and scale to pixels."""
            picked = torch.gather(
                bb, 1, topk_queries.unsqueeze(-1).repeat(1, 1, 4))
            return picked * scale_fct[:, None, :]

        boxes = _select_and_scale(boxes)
        body_boxes = _select_and_scale(out_body_bbox)
        lhand_boxes = _select_and_scale(out_lhand_bbox)
        rhand_boxes = _select_and_scale(out_rhand_bbox)
        face_boxes = _select_and_scale(out_face_bbox)

        # Gather SMPL-X quantities for the selected queries. The keypoint
        # count is taken from the tensor (the original hard-coded 144).
        pred_smpl_kp2d = torch.gather(
            out_smpl_kp2d, 1,
            topk_queries[:, :, None, None].repeat(
                1, 1, out_smpl_kp2d.shape[-2], 2))
        smpl_pose = torch.gather(
            out_smpl_pose, 1, topk_queries[:, :, None].repeat(1, 1, 159))
        smpl_beta = torch.gather(
            out_smpl_beta, 1, topk_queries[:, :, None].repeat(1, 1, 10))
        smpl_expr = torch.gather(
            out_smpl_expr, 1, topk_queries[:, :, None].repeat(1, 1, 10))
        smpl_cam = torch.gather(
            out_smpl_cam, 1, topk_queries[:, :, None].repeat(1, 1, 3))
        smpl_kp3d = torch.gather(
            out_smpl_kp3d, 1,
            topk_queries[:, :, None, None].repeat(
                1, 1, out_smpl_kp3d.shape[-2], 3))
        smpl_verts = torch.gather(
            out_smpl_verts, 1,
            topk_queries[:, :, None, None].repeat(
                1, 1, out_smpl_verts.shape[-2], 3))

        # Weak-perspective camera (s, tx, ty) -> translation (dx, dy, depth).
        s = smpl_cam[..., 0] + 1e-9  # epsilon guards against division by 0
        tx, ty = smpl_cam[..., 1], smpl_cam[..., 2]
        transl = torch.stack([tx / s, ty / s, 1. / s], -1)

        # SMPL-X full-pose layout: 3 root + 63 body + 45 lhand + 45 rhand
        # + 3 jaw axis-angle parameters.
        smplx_root_pose = smpl_pose[:, :, :3]
        smplx_body_pose = smpl_pose[:, :, 3:66]
        smplx_lhand_pose = smpl_pose[:, :, 66:111]
        smplx_rhand_pose = smpl_pose[:, :, 111:156]
        smplx_jaw_pose = smpl_pose[:, :, 156:]

        # BUGFIX: the original defined image_idx only when 'ann_idx' was
        # present but read it unconditionally below, raising NameError
        # otherwise. Fall back to None per image.
        if 'ann_idx' in data_batch_nc:
            image_idx = [t.cpu().numpy()[0]
                         for t in data_batch_nc['ann_idx']]
        else:
            image_idx = [None] * batch_size

        for bs in range(batch_size):
            results.append({
                'scores': scores[bs],
                'labels': labels[bs],
                'smpl_kp3d': smpl_kp3d[bs],
                'smplx_root_pose': smplx_root_pose[bs],
                'smplx_body_pose': smplx_body_pose[bs],
                'smplx_lhand_pose': smplx_lhand_pose[bs],
                'smplx_rhand_pose': smplx_rhand_pose[bs],
                'smplx_jaw_pose': smplx_jaw_pose[bs],
                'smplx_shape': smpl_beta[bs],
                'smplx_expr': smpl_expr[bs],
                'smplx_joint_proj': pred_smpl_kp2d[bs],
                'smpl_verts': smpl_verts[bs],
                'image_idx': image_idx[bs],
                'cam_trans': transl[bs],
                'body_bbox': body_boxes[bs],
                'lhand_bbox': lhand_boxes[bs],
                'rhand_bbox': rhand_boxes[bs],
                'face_bbox': face_boxes[bs],
                'bb2img_trans': data_batch_nc['bb2img_trans'][bs],
                'img2bb_trans': data_batch_nc['img2bb_trans'][bs],
                'img': data_batch_nc['img'][bs],
                'img_shape': data_batch_nc['img_shape'][bs]
            })

        if self.nms_iou_threshold > 0:
            # NMS filtering is not supported by this post-processor; the
            # unreachable code that followed the raise has been removed.
            raise NotImplementedError
        return results