File size: 3,424 Bytes
bcb05d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import numpy as np
from PIL import Image
import torch
import random
from torchvision import transforms
import torchvision.transforms.functional as TF


def apply_joint_transforms(rgb, mask, img_size, img_aug=True, test=True):
    """Run one geometric pipeline over an RGB image and its mask in lockstep.

    Steps, applied identically to both inputs:
      1. pad to a square canvas (white for ``rgb``, 0 for ``mask``);
      2. if ``img_aug`` and not ``test``: with probability 0.1 rotate by a
         random angle in [-10, 10] degrees, then with probability 0.1 take a
         random square crop of 90-100% of the padded side;
      3. resize to ``(img_size, img_size)`` (bilinear for ``rgb``, nearest
         for ``mask``);
      4. pad by a border (16 px when ``test``, otherwise random 0-32 px) and
         resize back to ``(img_size, img_size)``.

    Args:
        rgb: presumably a PIL image; ``rgb.size`` is read as (width, height).
        mask: image of the same size as ``rgb``.
        img_size: side length of the square output.
        img_aug: enable the random rotate/crop augmentations.
        test: deterministic border and no augmentation when True.

    Returns:
        Tuple ``(rgb_tensor, mask_tensor)`` produced by ``TF.to_tensor``.
    """
    # The border is drawn first so the RNG call order matches the pipeline.
    border = 16 if test else random.randint(0, 32)

    width, height = rgb.size[:2]
    side = max(height, width)
    pad_top = (side - height) // 2
    pad_left = (side - width) // 2
    padding = (
        pad_left,
        pad_top,
        side - width - pad_left,
        side - height - pad_top,
    )

    # 1. square canvas: white background for the image, empty for the mask
    rgb = TF.pad(rgb, padding, fill=255)
    mask = TF.pad(mask, padding, fill=0)

    if img_aug and not test:
        # 2a. occasional small rotation
        if random.random() < 0.1:
            theta = random.uniform(-10, 10)
            rgb = TF.rotate(rgb, theta, fill=255)
            mask = TF.rotate(mask, theta, fill=0)

        # 2b. occasional square crop at the same location in both inputs
        if random.random() < 0.1:
            crop_side = int(side * random.uniform(0.9, 1.0))
            top, left, crop_h, crop_w = transforms.RandomCrop.get_params(
                rgb, (crop_side, crop_side)
            )
            rgb = TF.crop(rgb, top, left, crop_h, crop_w)
            mask = TF.crop(mask, top, left, crop_h, crop_w)

    out_size = (img_size, img_size)

    def _resize_pair(image, alpha):
        # Bilinear keeps the image smooth; nearest keeps mask values discrete.
        return (
            TF.resize(image, out_size, interpolation=TF.InterpolationMode.BILINEAR),
            TF.resize(alpha, out_size, interpolation=TF.InterpolationMode.NEAREST),
        )

    # 3. bring both to the target resolution
    rgb, mask = _resize_pair(rgb, mask)

    # 4. add a border, then shrink back so the content ends up smaller
    rgb, mask = _resize_pair(
        TF.pad(rgb, border, fill=255),
        TF.pad(mask, border, fill=0),
    )

    return TF.to_tensor(rgb), TF.to_tensor(mask)

def crop_recenter(image_no_bg, thereshold=100):
    image_no_bg_np = np.array(image_no_bg)
    mask = (image_no_bg_np[..., -1]).astype(np.uint8)
    mask_bin = mask > thereshold
    
    H, W = image_no_bg_np.shape[:2]
    
    valid_pixels = mask_bin.astype(np.float32).nonzero() # [N, 2]
    if np.sum(mask_bin) < (H*W) * 0.001:
        min_h =0
        max_h = H - 1
        min_w = 0
        max_w = W -1
    else:
        min_h, max_h = valid_pixels[0].min(), valid_pixels[0].max()
        min_w, max_w = valid_pixels[1].min(), valid_pixels[1].max()
        
    if min_h < 0:
        min_h = 0
    if min_w < 0:
        min_w = 0
    if max_h > H:
        max_h = H 
    if max_w > W:
        max_w = W

    image_no_bg_np = image_no_bg_np[min_h:max_h+1, min_w:max_w+1]
    return image_no_bg_np

def preprocess_image(img, img_size=518):
    """Load an image, crop it to its foreground and return an RGB+mask tensor.

    The image is brought to RGBA, cropped with ``crop_recenter`` (any
    non-zero alpha counts as foreground), composited onto a white background
    and passed through ``apply_joint_transforms`` in test mode.

    Args:
        img: a file path, a PIL image, or a numpy array in the 0-255 range.
            Arrays may be (H, W) grayscale, (H, W, 3) RGB, or (H, W, 4) RGBA
            with alpha in the last channel.
        img_size: side length passed to ``apply_joint_transforms``
            (default 518, the previously hard-coded value).

    Returns:
        torch.Tensor: the RGB tensor and the mask tensor from
        ``apply_joint_transforms`` concatenated along dim 0.
    """
    if isinstance(img, str):
        # The context manager closes the file handle PIL otherwise keeps open.
        with Image.open(img) as pil_img:
            img = np.array(pil_img.convert("RGBA"))
    elif isinstance(img, Image.Image):
        # Normalizes L/LA/P/RGB/CMYK/... to 4 channels with a real alpha.
        img = np.array(img.convert("RGBA"))

    if img.ndim == 2:
        # Grayscale array: replicate into three color channels.
        img = np.stack([img] * 3, axis=-1)

    if img.shape[-1] == 3:
        # Fully opaque alpha. It must be 255, not 1: everything is divided by
        # 255 below, and an alpha of 1/255 would wash the image out to white.
        alpha = np.full_like(img[..., 0:1], 255)
        img = np.concatenate([img, alpha], axis=-1)

    img = crop_recenter(img, thereshold=0) / 255.

    mask = img[..., 3]
    # Alpha-composite onto a white background.
    rgb = img[..., :3] * img[..., 3:] + (1 - img[..., 3:])

    # Round (not truncate) so float round-off like 254.9999 maps back to 255.
    rgb_pil = Image.fromarray(np.clip(np.round(rgb * 255), 0, 255).astype(np.uint8))
    mask_pil = Image.fromarray(np.clip(np.round(mask * 255), 0, 255).astype(np.uint8))

    rgb_tensor, mask_tensor = apply_joint_transforms(
        rgb_pil, mask_pil, img_size=img_size, img_aug=False, test=True
    )
    return torch.cat([rgb_tensor, mask_tensor], dim=0)