Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
# Copyright (c) Alibaba, Inc. and its affiliates. | |
import random | |
import numpy as np | |
from .utils import align_frames | |
class FrameRefExtractAnnotator:
    """Keep a few reference frames of a clip and blank out every other frame.

    Kept frames are passed through unchanged with an all-zero mask; every
    other frame is replaced by a constant ``ref_color`` image with an
    all-255 mask. Frames are indexed as ``frame[:, :, 0]`` to build masks,
    so each frame must be an array with at least three dimensions.
    """

    para_dict = {}

    def __init__(self, cfg, device=None):
        """Read default settings from ``cfg``.

        Args:
            cfg: Object exposing ``get(key, default)``. Keys read:
                ``REF_CFG``, ``REF_NUM``, ``REF_COLOR``, ``RETURN_DICT``
                and ``RETURN_MASK``.
            device: Unused by this class.
        """
        # first / last / firstlast / random
        self.ref_cfg = cfg.get('REF_CFG', [{"mode": "first", "proba": 0.1},
                                           {"mode": "last", "proba": 0.1},
                                           {"mode": "firstlast", "proba": 0.1},
                                           {"mode": "random", "proba": 0.1}])
        self.ref_num = cfg.get('REF_NUM', 1)
        self.ref_color = cfg.get('REF_COLOR', 127.5)
        self.return_dict = cfg.get('RETURN_DICT', True)
        self.return_mask = cfg.get('RETURN_MASK', True)

    def forward(self, frames, ref_cfg=None, ref_num=None, return_mask=None, return_dict=None):
        """Select reference frames and blank out the rest.

        Args:
            frames: Sequence of frame arrays.
            ref_cfg: One config entry or a list of entries, each optionally
                holding ``'mode'`` (``first`` / ``last`` / ``firstlast`` /
                ``random``) and ``'proba'`` (sampling weight). One entry is
                drawn at random according to the weights. Defaults to the
                instance setting.
            ref_num: Number of frames to keep per side (``first``/``last``/
                ``firstlast``) or in total (``random``). Values larger than
                the clip length are clamped. Defaults to the instance
                setting.
            return_mask: Whether to also return masks. Defaults to the
                instance setting.
            return_dict: Whether to return a dict instead of a list/tuple.
                Defaults to the instance setting.

        Returns:
            ``{"frames": [...], "masks": [...]}`` (``masks`` only when
            requested) if ``return_dict`` is true; otherwise
            ``(frames, masks)`` or just ``frames``.

        Raises:
            ValueError: If ``ref_num`` is negative.
            NotImplementedError: If the drawn entry has no supported mode.
        """
        return_mask = self.return_mask if return_mask is None else return_mask
        return_dict = self.return_dict if return_dict is None else return_dict
        ref_cfg = self.ref_cfg if ref_cfg is None else ref_cfg
        if not isinstance(ref_cfg, list):
            ref_cfg = [ref_cfg]

        # Entries without an explicit weight get a uniform share.
        uniform = 1.0 / len(ref_cfg)
        probas = [item['proba'] if 'proba' in item else uniform for item in ref_cfg]
        sel_ref_cfg = random.choices(ref_cfg, weights=probas, k=1)[0]
        mode = sel_ref_cfg['mode'] if 'mode' in sel_ref_cfg else 'original'

        ref_num = int(ref_num) if ref_num is not None else self.ref_num
        if ref_num < 0:
            raise ValueError(f"ref_num must be non-negative, got {ref_num}")

        frame_num = len(frames)
        # Clamp so every mode behaves the same for oversized requests, and
        # compute tail indices explicitly: a `[-0:]` slice would select the
        # whole clip instead of nothing when ref_num is 0.
        keep = min(ref_num, frame_num)
        head = range(keep)
        tail = range(frame_num - keep, frame_num)
        if mode == "first":
            sel_idx = set(head)
        elif mode == "last":
            sel_idx = set(tail)
        elif mode == "firstlast":
            sel_idx = set(head) | set(tail)
        elif mode == "random":
            sel_idx = set(random.sample(range(frame_num), keep))
        else:
            raise NotImplementedError(f"Unsupported reference mode: {mode!r}")

        out_frames, out_masks = [], []
        for i, frame in enumerate(frames):
            if i in sel_idx:
                # Reference frame: keep content, mask value 0.
                out_frames.append(frame)
                out_masks.append(np.zeros_like(frame[:, :, 0]))
            else:
                # Blanked frame: constant fill, mask value 255.
                out_frames.append(np.ones_like(frame) * self.ref_color)
                out_masks.append(np.ones_like(frame[:, :, 0]) * 255)

        if return_dict:
            ret_data = {"frames": out_frames}
            if return_mask:
                ret_data['masks'] = out_masks
            return ret_data
        if return_mask:
            return out_frames, out_masks
        return out_frames
class FrameRefExpandAnnotator:
    """Pad reference image(s) or clip(s) with constant-colour filler frames.

    Source frames are passed through with an all-zero mask; the inserted
    filler frames are ``ref_color`` images with an all-255 mask. Frames are
    indexed as ``frame[:, :, 0]`` to build masks, so each frame must be an
    array with at least three dimensions.
    """

    para_dict = {}

    def __init__(self, cfg, device=None):
        """Read default settings from ``cfg``.

        Args:
            cfg: Object exposing ``get(key, default)``. Keys read:
                ``REF_COLOR``, ``RETURN_MASK``, ``RETURN_DICT`` and ``MODE``.
            device: Unused by this class.
        """
        # first / last / firstlast
        self.ref_color = cfg.get('REF_COLOR', 127.5)
        self.return_mask = cfg.get('RETURN_MASK', True)
        self.return_dict = cfg.get('RETURN_DICT', True)
        self.mode = cfg.get('MODE', "firstframe")
        assert self.mode in ["firstframe", "lastframe", "firstlastframe", "firstclip", "lastclip", "firstlastclip", "all"]

    @staticmethod
    def _as_frame_list(value):
        """Return ``value`` as a list of frames; ``None`` stays ``None``."""
        if value is None or isinstance(value, list):
            return value
        return [value]

    def forward(self, image=None, image_2=None, frames=None, frames_2=None, mode=None, expand_num=None, return_mask=None, return_dict=None):
        """Build the expanded frame sequence and matching masks.

        Args:
            image: Source image or list of images; used (instead of
                ``frames``) when ``'frame'`` is in ``mode``.
            image_2: Second image or list of images; used (instead of
                ``frames_2``) when ``'frame'`` is in ``mode``.
            frames: Source clip, used by the ``*clip`` modes.
            frames_2: Second clip, used by ``firstlastclip``.
            mode: One of ``firstframe``, ``lastframe``, ``firstlastframe``,
                ``firstclip``, ``lastclip``, ``firstlastclip``. Defaults to
                the instance setting.
            expand_num: Number of filler frames to insert (required).
            return_mask: Whether to also return masks. Defaults to the
                instance setting.
            return_dict: Whether to return a dict instead of a list/tuple.
                Defaults to the instance setting.

        Returns:
            ``{"frames": [...], "masks": [...]}`` (``masks`` only when
            requested) if ``return_dict`` is true; otherwise
            ``(frames, masks)`` or just ``frames``. The filler frames (and
            likewise the filler masks and the source masks) are repeated
            references to one shared array, not independent copies.

        Raises:
            TypeError: If a required input for the chosen mode is missing.
            NotImplementedError: If ``mode`` is not a supported mode.
        """
        mode = mode if mode is not None else self.mode
        return_mask = return_mask if return_mask is not None else self.return_mask
        return_dict = return_dict if return_dict is not None else self.return_dict

        if 'frame' in mode:
            # Wrap based on the image's own type. Testing `frames` here
            # double-wrapped list inputs and left single images unwrapped
            # whenever a `frames` list was also supplied.
            frames = self._as_frame_list(image)
            frames_2 = self._as_frame_list(image_2)

        if frames is None:
            required = "image" if 'frame' in mode else "frames"
            raise TypeError(f"'{required}' is required for mode {mode!r}")
        if expand_num is None:
            raise TypeError("'expand_num' is required")
        expand_num = int(expand_num)

        # Normalise to a list so concatenation below also works for tuples
        # or stacked arrays.
        source_frames = list(frames)
        template = source_frames[0]
        expand_frames = [np.ones_like(template) * self.ref_color] * expand_num
        expand_masks = [np.ones_like(template[:, :, 0]) * 255] * expand_num
        source_masks = [np.zeros_like(template[:, :, 0])] * len(source_frames)

        if mode in ["firstframe", "firstclip"]:
            out_frames = source_frames + expand_frames
            out_masks = source_masks + expand_masks
        elif mode in ["lastframe", "lastclip"]:
            out_frames = expand_frames + source_frames
            out_masks = expand_masks + source_masks
        elif mode in ["firstlastframe", "firstlastclip"]:
            if frames_2 is None:
                required = "image_2" if 'frame' in mode else "frames_2"
                raise TypeError(f"'{required}' is required for mode {mode!r}")
            # Each trailing frame is aligned against the first source frame.
            source_frames_2 = [align_frames(template, f2) for f2 in frames_2]
            source_masks_2 = [np.zeros_like(f2[:, :, 0]) for f2 in source_frames_2]
            out_frames = source_frames + expand_frames + source_frames_2
            out_masks = source_masks + expand_masks + source_masks_2
        else:
            raise NotImplementedError(f"Unsupported expand mode: {mode!r}")

        if return_dict:
            ret_data = {"frames": out_frames}
            if return_mask:
                ret_data['masks'] = out_masks
            return ret_data
        if return_mask:
            return out_frames, out_masks
        return out_frames