# AiOS/detrsmpl/data/datasets/human_video_dataset.py
import copy
from typing import Optional, Union

import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from skimage.util.shape import view_as_windows

from .builder import DATASETS
from .human_image_dataset import HumanImageDataset


def get_vid_name(image_path: str):
    """Get the video name (base directory) of the given image path."""
    content = image_path.split('/')
    vid_name = '/'.join(content[:-1])
    return vid_name
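
# A minimal doctest-style sketch (not part of the original module) of what
# get_vid_name returns for a hypothetical frame path:
#   >>> get_vid_name('vidA/images/000001.jpg')
#   'vidA/images'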


def split_into_chunks(data_infos: list, seq_len: int, stride: int,
                      test_mode: bool, only_vid_name: bool):
"""Split annotations into chunks.
Adapted from https://github.com/mkocabas/VIBE
Args:
data_infos (list): parsed annotations.
seq_len (int): the length of each chunk.
stride (int): the interval between chunks.
test_mode (bool): if test_mode is true, then an additional chunk
will be added to cover all frames. Otherwise, last few frames
will be dropped.
only_vid_name (bool): if only_vid_name is true, image_path only
contains the video name. Otherwise, image_path contains both
video_name and frame index.
Return:
list:
shape: [N, 4]. Each chunk contains four parameters: start_frame,
end_frame, valid_start_frame, valid_end_frame. The last two
parameters are used to suppress redundant frames.
"""
vid_names = []
for image_path in data_infos:
if only_vid_name:
vid_name = image_path
else:
vid_name = get_vid_name(image_path)
vid_names.append(vid_name)
vid_names = np.array(vid_names)
video_start_end_indices = []
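    # Group frame indices by video, preserving the original frame order.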
video_names, group = np.unique(vid_names, return_index=True)
perm = np.argsort(group)
video_names, group = video_names[perm], group[perm]
indices = np.split(np.arange(0, vid_names.shape[0]), group[1:])
for idx in range(len(video_names)):
indexes = indices[idx]
if indexes.shape[0] < seq_len:
continue
chunks = view_as_windows(indexes, (seq_len, ), step=stride)
start_finish = chunks[:, (0, -1, 0, -1)].tolist()
video_start_end_indices += start_finish
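        # In test mode, append one extra (overlapping) chunk so the trailing
        # frames that do not fill a full window are still covered; the
        # valid_start/valid_end entries mark which of its frames are new.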
if chunks[-1][-1] < indexes[-1] and test_mode:
start_frame = indexes[-1] - seq_len + 1
end_frame = indexes[-1]
valid_start_frame = chunks[-1][-1] + 1
valid_end_frame = indexes[-1]
extra_start_finish = [[
start_frame, end_frame, valid_start_frame, valid_end_frame
]]
video_start_end_indices += extra_start_finish
return video_start_end_indices
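
# A worked example (illustrative, not from the original file), assuming a
# single hypothetical video 'vidA' with 10 frames, seq_len=4 and stride=4:
#   >>> frames = ['vidA/%03d.jpg' % i for i in range(10)]
#   >>> split_into_chunks(frames, seq_len=4, stride=4, test_mode=True,
#   ...                   only_vid_name=False)
#   [[0, 3, 0, 3], [4, 7, 4, 7], [6, 9, 8, 9]]
# The last chunk overlaps the previous one; only frames 8-9 of it are valid.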


@DATASETS.register_module()
class HumanVideoDataset(HumanImageDataset):
"""Human Video Dataset.
Args:
data_prefix (str): the prefix of data path.
pipeline (list): a list of dict, where each element represents
a operation defined in `mmhuman3d.datasets.pipelines`.
dataset_name (str | None): the name of dataset. It is used to
identify the type of evaluation metric. Default: None.
seq_len (int, optional): the length of input sequence. Default: 16.
overlap (float, optional): the overlap between different sequences.
Default: 0
only_vid_name (bool, optional): the format of image_path.
If only_vid_name is true, image_path only contains the video
name. Otherwise, image_path contains both video_name and frame
index.
body_model (dict | None, optional): the config for body model,
which will be used to generate meshes and keypoints.
Default: None.
ann_file (str | None, optional): the annotation file. When ann_file
is str, the subclass is expected to read from the ann_file. When
ann_file is None, the subclass is expected to read according
to data_prefix.
convention (str, optional): keypoints convention. Keypoints will be
converted from "human_data" to the given one.
Default: "human_data"
test_mode (bool, optional): in train mode or test mode. Default: False.
"""

    def __init__(self,
data_prefix: str,
pipeline: list,
dataset_name: str,
seq_len: Optional[int] = 16,
overlap: Optional[float] = 0.,
only_vid_name: Optional[bool] = False,
body_model: Optional[Union[dict, None]] = None,
ann_file: Optional[Union[str, None]] = None,
convention: Optional[str] = 'human_data',
test_mode: Optional[bool] = False):
super(HumanVideoDataset, self).__init__(data_prefix=data_prefix,
pipeline=pipeline,
dataset_name=dataset_name,
body_model=body_model,
convention=convention,
ann_file=ann_file,
test_mode=test_mode)
self.seq_len = seq_len
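        # Chunks start every `stride` frames; e.g. seq_len=16 with
        # overlap=0.5 gives stride=8, so consecutive chunks share 8 frames.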
self.stride = int(seq_len * (1 - overlap))
self.vid_indices = split_into_chunks(self.human_data['image_path'],
self.seq_len, self.stride,
test_mode, only_vid_name)
self.vid_indices = np.array(self.vid_indices)

    def __len__(self):
        """Return the number of chunks (video sequences)."""
        return len(self.vid_indices)

    def prepare_data(self, idx: int):
        """Prepare data for each chunk.

        Step 1: get the annotation of every frame in the chunk.
        Step 2: add the metas of the chunk.
        """
start_idx, end_idx = self.vid_indices[idx][:2]
batch_results = []
image_path = []
for frame_idx in range(start_idx, end_idx + 1):
frame_results = copy.deepcopy(self.prepare_raw_data(frame_idx))
image_path.append(frame_results.pop('image_path'))
if 'features' in self.human_data:
frame_results['features'] = \
copy.deepcopy(self.human_data['features'][frame_idx])
frame_results = self.pipeline(frame_results)
batch_results.append(frame_results)
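        # Collate per-frame results; tensors are stacked along a new time
        # dimension so each value has shape [seq_len, ...].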
video_results = {}
for key in batch_results[0].keys():
batch_anno = []
for item in batch_results:
batch_anno.append(item[key])
if isinstance(batch_anno[0], torch.Tensor):
batch_anno = torch.stack(batch_anno, dim=0)
video_results[key] = batch_anno
img_metas = {
'frame_idx': self.vid_indices[idx],
'image_path': image_path
}
video_results['img_metas'] = DC(img_metas, cpu_only=True)
return video_results
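

# A minimal usage sketch (not part of the original file; the dataset_name,
# ann_file and pipeline values below are hypothetical placeholders that only
# illustrate the constructor arguments):
#   from mmcv.utils import build_from_cfg
#   dataset = build_from_cfg(
#       dict(
#           type='HumanVideoDataset',
#           data_prefix='data',
#           pipeline=[],
#           dataset_name='pw3d',
#           ann_file='sample_annotation.npz',
#           seq_len=16,
#           overlap=0.5),
#       DATASETS)
#   video_sample = dataset[0]  # dict of per-chunk results; 'img_metas' is a
#                              # DataContainer holding frame_idx and paths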