File size: 2,882 Bytes
5ba0490
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from .model import BiSeNet

# Pretrained BiSeNet weights. The path is relative, so it is resolved against
# the process's current working directory, not against this module's location.
CHECKPOINT: str = "face_parsing/cp/79999_iter.pth"

# Module-wide model cache: starts empty and is filled lazily the first time a
# face mask is requested, so importing this module does not load the weights.
net: "BiSeNet | None" = None

def vis_parsing_maps(im, parsing_anno, stride):
    """Render a face-parsing label map as an RGB mask image.

    Every pixel's class label is looked up in a fixed palette. Skin, brows,
    eyes, eyeglasses, nose, mouth and both lips map to white. Every other
    label maps to black: background, ears, earring, neck, necklace, cloth,
    hair, hat, and the unnamed labels 19-23. The result is therefore a
    black/white face mask.

    Args:
        im: Unused. Kept only so existing callers that pass the source image
            keep working.
        parsing_anno: 2-D array of integer class labels. Presumably this is
            the argmax over BiSeNet's class dimension; confirm against the
            caller. Labels are cast to uint8 before lookup.
        stride: Scale factor applied to both axes (nearest-neighbour, so
            labels are never blended) before colouring.

    Returns:
        A ``PIL.Image.Image`` built from an ``(H, W, 3)`` uint8 array.

    Raises:
        IndexError: if any label is >= the palette length (24).
    """
    part_colors = [[0, 0, 0],  # bg
                   [255, 255, 255],  # skin
                   [255, 255, 255],  # l_brow
                   [255, 255, 255],  # r_brow
                   [255, 255, 255],  # l_eye
                   [255, 255, 255],  # r_eye
                   [255, 255, 255],  # eye_g
                   [0, 0, 0],  # l_ear
                   [0, 0, 0],  # r_ear
                   [0, 0, 0],  # ear_r
                   [255, 255, 255],  # nose
                   [255, 255, 255],  # mouth
                   [255, 255, 255],  # u_lip
                   [255, 255, 255],  # l_lip
                   [0, 0, 0],  # neck
                   [0, 0, 0],  # neck_l
                   [0, 0, 0],  # cloth
                   [0, 0, 0],  # hair
                   [0, 0, 0],  # hat
                   [0, 0, 0],
                   [0, 0, 0],
                   [0, 0, 0],
                   [0, 0, 0],
                   [0, 0, 0]]
    # Row i is the colour for label i, so the palette doubles as a lookup table.
    palette = np.array(part_colors, dtype=np.uint8)

    labels = parsing_anno.astype(np.uint8)
    labels = cv2.resize(labels, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)

    # Integer-array indexing colours every pixel in one vectorised pass
    # (replacing a per-class np.where loop) and yields an (H, W, 3) uint8 array.
    color = palette[labels]
    return Image.fromarray(color)

def _load_net():
    """Build a 19-class BiSeNet, load the CHECKPOINT weights and return it in eval mode."""
    n_classes = 19
    model = BiSeNet(n_classes=n_classes)
    if torch.cuda.is_available():
        model.cuda()
    # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts;
    # load_state_dict then copies the tensors onto the model's own device.
    model.load_state_dict(torch.load(CHECKPOINT, map_location='cpu'))
    model.eval()
    return model

def get_face_mask(pil_img):
    """Predict a face mask for a PIL image with the lazily loaded BiSeNet.

    The image is resized to 512x512, normalised with the ImageNet mean/std,
    and passed through the network. The per-pixel argmax labels are rendered
    by ``vis_parsing_maps``, and the mask is resized back to the input size.

    Args:
        pil_img: Input ``PIL.Image.Image``. Non-RGB modes (e.g. RGBA, L, P)
            are converted to RGB first, because the normaliser expects
            exactly three channels.

    Returns:
        An RGB ``PIL.Image.Image`` with the same size as ``pil_img``. The
        mask is upscaled bilinearly, so pixels on region borders may hold
        intermediate grey values.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise if the checkpoint
        is missing or incompatible. In that case the cache stays empty and
        the load is retried on the next call.
    """
    global net

    if net is None:
        # Publish the model only after its weights have loaded. Otherwise a
        # failed load would leave an untrained model cached for all later calls.
        net = _load_net()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    if pil_img.mode != 'RGB':
        pil_img = pil_img.convert('RGB')

    with torch.no_grad():
        origw, origh = pil_img.size

        image = pil_img.resize((512, 512), Image.BILINEAR)
        img = torch.unsqueeze(to_tensor(image), 0)  # add a batch dimension
        if torch.cuda.is_available():
            img = img.cuda()
        # First element of the network output; presumably the main logits of
        # shape (1, 19, 512, 512). Confirm against BiSeNet.forward.
        out = net(img)[0]
        parsing = out.squeeze(0).cpu().numpy().argmax(0)

        mask = vis_parsing_maps(image, parsing, stride=1)
        mask = mask.resize((origw, origh), Image.BILINEAR)

    return mask