Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np | |
import torch | |
from .utils import convert_to_numpy | |
class FaceAnnotator:
    """Detect faces with insightface and return raw detections or bbox crops.

    Behaviour is driven by the config passed to ``__init__``:

    * ``RETURN_RAW`` (default ``True``): ``forward`` returns whatever
      ``FaceAnalysis.get`` produced, untouched.
    * ``RETURN_MASK`` (default ``False``): also return one mask per face.
    * ``RETURN_DICT`` (default ``False``): when masks are returned, pack the
      result as ``{'image': ..., 'mask': ...}`` instead of a tuple.
    * ``MULTI_FACE`` (default ``True``): when false, only the first detected
      face is used and results are returned unwrapped (not in a list).
    """

    def __init__(self, cfg, device=None):
        """Build and prepare the insightface ``FaceAnalysis`` model.

        Args:
            cfg: Config object supporting ``.get(key, default)``, item access
                (``cfg['PRETRAINED_MODEL']``) and attribute access
                (``cfg.MODEL_NAME``).
            device: Anything accepted by ``torch.device`` (a ``torch.device``
                or a device string). Defaults to CUDA when available,
                otherwise CPU.
        """
        # Imported lazily so the module can be imported without insightface.
        from insightface.app import FaceAnalysis

        self.return_raw = cfg.get('RETURN_RAW', True)
        self.return_mask = cfg.get('RETURN_MASK', False)
        self.return_dict = cfg.get('RETURN_DICT', False)
        self.multi_face = cfg.get('MULTI_FACE', True)
        pretrained_model = cfg['PRETRAINED_MODEL']

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Normalise so that plain strings such as "cuda:0" are accepted too.
        self.device = torch.device(device)
        self.device_id = self.device.index if self.device.type == 'cuda' else None

        if self.device.type == 'cuda':
            # A bare "cuda" device has no index; fall back to the first GPU.
            ctx_id = self.device_id if self.device_id is not None else 0
        else:
            # insightface treats a negative ctx_id as "run on CPU"; passing 0
            # here would let it pick the CUDA provider despite a CPU device.
            ctx_id = -1

        self.model = FaceAnalysis(
            name=cfg.MODEL_NAME,
            root=pretrained_model,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
        )
        self.model.prepare(ctx_id=ctx_id, det_size=(640, 640))

    @staticmethod
    def _clip_bbox(bbox, height, width):
        """Convert an inclusive float bbox to slice bounds clamped to the image.

        Returns ``(left, top, right, bottom)`` where ``right``/``bottom`` are
        exclusive, so the region is ``image[top:bottom, left:right]``.
        """
        x_min, y_min, x_max, y_max = bbox
        left = min(max(int(x_min), 0), width)
        top = min(max(int(y_min), 0), height)
        # +1 keeps the max coordinate inside the crop (inclusive bbox).
        right = min(max(int(x_max) + 1, 0), width)
        bottom = min(max(int(y_max) + 1, 0), height)
        return left, top, right, bottom

    def forward(self, image=None, return_mask=None, return_dict=None):
        """Run face detection on ``image``.

        Args:
            image: Input passed through ``convert_to_numpy``; the result is
                indexed as ``image[:, :, 0]``, so it must have at least three
                dimensions.
            return_mask: Overrides the configured ``RETURN_MASK`` when not
                ``None``.
            return_dict: Overrides the configured ``RETURN_DICT`` when not
                ``None``.

        Returns:
            * The raw result of ``self.model.get(image)`` when ``RETURN_RAW``
              is set.
            * ``None`` when no face was detected.
            * Otherwise the face crops (a list, or a single array when
              ``MULTI_FACE`` is false); with ``return_mask`` also the masks
              (255 inside the bbox, 0 elsewhere), as a ``(crops, masks)``
              tuple or as ``{'image': crops, 'mask': masks}`` when
              ``return_dict`` is set.
        """
        return_mask = self.return_mask if return_mask is None else return_mask
        return_dict = self.return_dict if return_dict is None else return_dict

        image = convert_to_numpy(image)
        # Per the original author's note, each face carries the keys:
        # 'bbox', 'kps', 'det_score', 'landmark_3d_68', 'pose',
        # 'landmark_2d_106', 'gender', 'age', 'embedding'.
        faces = self.model.get(image)
        if self.return_raw:
            return faces
        if len(faces) == 0:
            return None

        if not self.multi_face:
            faces = faces[:1]

        height, width = image.shape[:2]
        crop_face_list, mask_list = [], []
        for face in faces:
            # Detections may extend past the frame. An unclamped negative
            # start index would wrap around and yield an empty/wrong crop.
            left, top, right, bottom = self._clip_bbox(
                face['bbox'].tolist(), height, width
            )
            crop_face_list.append(image[top:bottom, left:right])
            mask = np.zeros_like(image[:, :, 0])
            mask[top:bottom, left:right] = 255
            mask_list.append(mask)

        if not self.multi_face:
            # Single-face mode returns the bare arrays rather than lists.
            crop_face_list = crop_face_list[0]
            mask_list = mask_list[0]

        if not return_mask:
            return crop_face_list
        if return_dict:
            return {'image': crop_face_list, 'mask': mask_list}
        return crop_face_list, mask_list