# gorgeous / face_parsing / inference.py
# Kam-Woh Ng
# push to hf
# 5ba0490
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from .model import BiSeNet
# Location of the pretrained BiSeNet face-parsing weights. This is a relative
# path, so it is resolved against the process's current working directory.
CHECKPOINT = "face_parsing/cp/79999_iter.pth"
# Shared model instance; stays None until the first call to get_face_mask
# builds and loads it.
net = None
def vis_parsing_maps(im, parsing_anno, stride):
    """Render a face-parsing label map as a black/white RGB face mask.

    Every class label is looked up in a fixed palette. The inner-face classes
    (skin, l_brow, r_brow, l_eye, r_eye, eye_g, nose, mouth, u_lip, l_lip)
    become white (255, 255, 255). All other classes (bg, l_ear, r_ear, ear_r,
    neck, neck_l, cloth, hair, hat, and the unlabeled ids 19-23) become
    black (0, 0, 0).

    Args:
        im: The image the labels were predicted from. It is accepted for
            interface compatibility only and does not affect the output.
        parsing_anno: Array of per-pixel integer class labels; assumed 2-D
            (H, W). Values are converted to uint8 before lookup.
        stride: Nearest-neighbour scale factor applied to the label map
            before colouring. With ``1`` the map is left unchanged.

    Returns:
        A ``PIL.Image`` built from an (H*stride, W*stride, 3) uint8 array
        whose values are only 0 or 255.

    Raises:
        IndexError: If a (uint8-converted) label is 24 or larger, i.e.
            outside the palette.
    """
    # Palette indexed by class id. Entries after 'hat' have no class name and
    # are padding.
    palette = np.array([
        [0, 0, 0],        # bg
        [255, 255, 255],  # skin
        [255, 255, 255],  # l_brow
        [255, 255, 255],  # r_brow
        [255, 255, 255],  # l_eye
        [255, 255, 255],  # r_eye
        [255, 255, 255],  # eye_g
        [0, 0, 0],        # l_ear
        [0, 0, 0],        # r_ear
        [0, 0, 0],        # ear_r
        [255, 255, 255],  # nose
        [255, 255, 255],  # mouth
        [255, 255, 255],  # u_lip
        [255, 255, 255],  # l_lip
        [0, 0, 0],        # neck
        [0, 0, 0],        # neck_l
        [0, 0, 0],        # cloth
        [0, 0, 0],        # hair
        [0, 0, 0],        # hat
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
    ], dtype=np.uint8)

    labels = np.asarray(parsing_anno).astype(np.uint8)
    if stride != 1:
        # Nearest-neighbour so no new (blended) label values are introduced.
        # With stride == 1 this resize is an identity copy, so it is skipped.
        labels = cv2.resize(labels, None, fx=stride, fy=stride,
                            interpolation=cv2.INTER_NEAREST)

    # A single vectorised lookup: each pixel's label selects its palette row,
    # producing an (H, W, 3) uint8 image. This avoids one np.where scan of the
    # whole image per class.
    colored = palette[labels]
    return Image.fromarray(colored)
def _load_net():
    """Return the shared BiSeNet model, building and loading it on first use.

    The module-level ``net`` is assigned only after the checkpoint has loaded
    successfully. If loading fails (e.g. the checkpoint file is missing), the
    exception propagates and the next call retries. Without this, later calls
    would silently reuse a model whose weights were never loaded.
    """
    global net
    if net is None:
        model = BiSeNet(n_classes=19)
        # Load on CPU first so the checkpoint opens on machines without a GPU,
        # then move the weights to the GPU when one is available.
        model.load_state_dict(torch.load(CHECKPOINT, map_location='cpu'))
        if torch.cuda.is_available():
            model.cuda()
        model.eval()
        net = model
    return net


def get_face_mask(pil_img):
    """Predict a black/white face mask for ``pil_img`` using BiSeNet.

    The image is resized to 512x512, normalised with the ImageNet mean/std,
    and segmented. Per-pixel classes are then coloured by
    ``vis_parsing_maps``, and the resulting mask is resized back to the input
    size.

    Args:
        pil_img: A ``PIL.Image``. It is converted to RGB first, so grayscale,
            RGBA and palette images are accepted too.

    Returns:
        A ``PIL.Image`` mask with the same size as ``pil_img``. The final
        bilinear resize may introduce intermediate grey values along mask
        edges.
    """
    model = _load_net()
    # Send the input to wherever the model's weights live.
    device = next(model.parameters()).device
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    origw, origh = pil_img.size
    # The normaliser and network expect 3 channels. For RGB input, convert()
    # is just a copy, so results are unchanged.
    image = pil_img.convert('RGB').resize((512, 512), Image.BILINEAR)
    img = to_tensor(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # Element 0 of the network output is used as the class-score map.
        # Presumably shaped (1, n_classes, 512, 512) -- TODO confirm against
        # BiSeNet.forward.
        out = model(img)[0]
    # Drop the batch axis and take the highest-scoring class per pixel.
    parsing = out.squeeze(0).cpu().numpy().argmax(0)

    mask = vis_parsing_maps(image, parsing, stride=1)
    return mask.resize((origw, origh), Image.BILINEAR)