|
from PIL import Image |
|
|
|
import numpy as np |
|
import torch |
|
from torchvision import transforms |
|
from skimage.transform import resize |
|
|
|
from .u2net import U2NET |
|
|
|
|
|
def get_mask_u2net(pil_im, output_dir, u2net_path, device="cpu"): |
|
|
|
w, h = pil_im.size[0], pil_im.size[1] |
|
im_size = min(w, h) |
|
data_transforms = transforms.Compose([ |
|
transforms.Resize(min(320, im_size), interpolation=transforms.InterpolationMode.BICUBIC), |
|
transforms.ToTensor(), |
|
transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), |
|
std=(0.26862954, 0.26130258, 0.27577711)), |
|
]) |
|
input_im_trans = data_transforms(pil_im).unsqueeze(0).to(device) |
|
|
|
|
|
net = U2NET(in_ch=3, out_ch=1) |
|
net.load_state_dict(torch.load(u2net_path)) |
|
net.to(device) |
|
net.eval() |
|
|
|
|
|
with torch.no_grad(): |
|
d1, d2, d3, d4, d5, d6, d7 = net(input_im_trans.detach()) |
|
pred = d1[:, 0, :, :] |
|
pred = (pred - pred.min()) / (pred.max() - pred.min()) |
|
predict = pred |
|
predict[predict < 0.5] = 0 |
|
predict[predict >= 0.5] = 1 |
|
mask = torch.cat([predict, predict, predict], dim=0).permute(1, 2, 0) |
|
mask = mask.cpu().numpy() |
|
mask = resize(mask, (h, w), anti_aliasing=False) |
|
mask[mask < 0.5] = 0 |
|
mask[mask >= 0.5] = 1 |
|
|
|
|
|
im = Image.fromarray((mask[:, :, 0] * 255).astype(np.uint8)).convert('RGB') |
|
save_path_ = output_dir / "mask.png" |
|
im.save(save_path_) |
|
|
|
im_np = np.array(pil_im) |
|
im_np = im_np / im_np.max() |
|
im_np = mask * im_np |
|
im_np[mask == 0] = 1 |
|
im_final = (im_np / im_np.max() * 255).astype(np.uint8) |
|
im_final = Image.fromarray(im_final) |
|
|
|
|
|
del net |
|
torch.cuda.empty_cache() |
|
|
|
return im_final, predict |
|
|