|
import cv2 |
|
import numpy as np |
|
import torch |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
|
|
from .model import BiSeNet |
|
|
|
# Path of the weights file that get_face_mask() hands to torch.load().
# It is a relative path, so it is resolved against the current working
# directory of the process, not against this module's location.
CHECKPOINT = 'face_parsing/cp/79999_iter.pth'

# Module-level cache for the BiSeNet instance; stays None until the first
# get_face_mask() call builds the network and stores it here.
net = None
|
|
|
def vis_parsing_maps(im, parsing_anno, stride):
    """Render a per-pixel class-label map as a black/white RGB image.

    Labels 1-6 and 10-13 are painted white (255, 255, 255); every other
    label in the 24-entry colour table is painted black (0, 0, 0).

    Args:
        im: Unused. Kept only so existing callers that pass the source
            image positionally keep working.
        parsing_anno: 2-D array-like of integer class labels. It is cast
            to ``uint8`` before use, so values above 255 wrap around.
        stride: Scale factor applied to both axes of the label map
            (nearest-neighbour, so labels are never blended).

    Returns:
        PIL.Image.Image: RGB image with the same height/width as the
        (scaled) label map.

    Raises:
        IndexError: If the label map contains a value of 24 or more, for
            which no colour is defined.
    """
    # Colour lookup table indexed by class label: black everywhere except
    # for the labels listed below, which are white.
    white_labels = [1, 2, 3, 4, 5, 6, 10, 11, 12, 13]
    palette = np.zeros((24, 3), dtype=np.uint8)
    palette[white_labels] = 255

    # astype() always returns a fresh array, so the caller's data is never
    # modified.
    labels = np.asarray(parsing_anno).astype(np.uint8)
    labels = cv2.resize(labels, None, fx=stride, fy=stride,
                        interpolation=cv2.INTER_NEAREST)

    # One vectorised lookup replaces the per-class np.where loop; it also
    # avoids doing "+ 1" arithmetic on a uint8 maximum, which can wrap.
    colored = palette[labels]

    return Image.fromarray(colored)
|
|
|
def get_face_mask(pil_img):
    """Compute a black/white face mask for a PIL image.

    The image is resized to 512x512, run through the cached BiSeNet
    model, and the per-pixel arg-max labels are coloured by
    ``vis_parsing_maps``. The result is resized back to the input size.

    The network is created lazily on the first call (weights read from
    ``CHECKPOINT``) and cached in the module-level ``net``. CUDA is used
    whenever ``torch.cuda.is_available()`` is true.

    Args:
        pil_img: Input ``PIL.Image.Image``. Any mode is accepted; it is
            converted to RGB before being fed to the network.

    Returns:
        PIL.Image.Image: RGB mask with the same size as ``pil_img``. It is
        upscaled with bilinear resampling, so mask boundaries may contain
        intermediate grey values.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise when the
        checkpoint is missing or incompatible. In that case nothing is
        cached, so the next call retries the load.
    """
    global net

    use_cuda = torch.cuda.is_available()

    if net is None:
        # Build and fully initialise a local model first, and only then
        # publish it. Assigning the global before the weights are loaded
        # would leave a randomly initialised network cached if loading
        # failed, and later calls would silently use it.
        model = BiSeNet(n_classes=19)
        model.load_state_dict(torch.load(CHECKPOINT, map_location='cpu'))
        if use_cuda:
            model.cuda()
        model.eval()
        net = model

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    origw, origh = pil_img.size

    # Normalize() above is configured with 3-channel statistics, so force
    # RGB; grayscale or RGBA inputs would otherwise fail there.
    image = pil_img.convert('RGB').resize((512, 512), Image.BILINEAR)

    with torch.no_grad():
        img = torch.unsqueeze(to_tensor(image), 0)
        if use_cuda:
            img = img.cuda()
        # The model returns a sequence; only its first element is used.
        out = net(img)[0]
        # Drop the batch axis, then take the arg-max over the leading
        # axis (presumably the class axis -- confirm against BiSeNet).
        parsing = out.squeeze(0).cpu().numpy().argmax(0)

    mask = vis_parsing_maps(image, parsing, stride=1)
    return mask.resize((origw, origh), Image.BILINEAR)
|
|