# HD-Painter / lib/utils/iimage.py
# Andranik Sargsyan
# enable fp16, move SR to cuda:1
# da1e12f
import io
import math
import os
import warnings
import PIL.Image
import numpy as np
import cv2
import torch
import torchvision.transforms.functional as tvF
from scipy.ndimage import binary_dilation
def stack(images, axis=0):
    """Concatenate several images into a single ``IImage``.

    Args:
        images: Iterable of objects exposing a ``.data`` numpy array.
        axis: Axis along which the arrays are concatenated (default ``0``).

    Returns:
        A new ``IImage`` built from the concatenated array.
    """
    arrays = [image.data for image in images]
    merged = np.concatenate(arrays, axis=axis)
    return IImage(merged)
def torch2np(x, vmin=-1, vmax=1):
    """Convert an image tensor into a ``uint8`` numpy array laid out (B, H, W, C).

    Tensors with fewer than four dimensions are expanded with leading
    singleton axes (a warning is emitted), i.e. a 3D input is read as
    (C, H, W) and a 2D input as (H, W).

    Args:
        x: ``torch.Tensor`` expected in (B, C, H, W) layout.
        vmin: Value mapped to 0 for non-``uint8`` tensors.
        vmax: Value mapped to 255 for non-``uint8`` tensors.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with shape (B, H, W, C).

    Raises:
        NotImplementedError: If ``x`` is not ``uint8`` and ``vmin`` or
            ``vmax`` is ``None``.
    """
    if x.ndim != 4:
        # raise Exception("Please only use (B,C,H,W) torch tensors!")
        warnings.warn(
            "Warning! Shape of the image was not provided in (B,C,H,W) format, the shape was inferred automatically!")
    if x.ndim == 3:
        x = x[None]
    if x.ndim == 2:
        x = x[None, None]
    x = x.detach().cpu()
    # The dtype has to be inspected BEFORE casting to float; otherwise the
    # uint8 branch can never trigger and 0..255 data gets clipped to [vmin, vmax].
    if x.dtype == torch.uint8:
        # Already 8-bit: only the layout changes, matching the float branch.
        return x.permute(0, 2, 3, 1).numpy().astype(np.uint8)
    if vmin is None or vmax is None:
        raise NotImplementedError()
    # Cast to float32 so half-precision inputs are handled on CPU.
    x = x.float()
    x = (255 * (x.clip(vmin, vmax) - vmin) / (vmax - vmin))
    x = x.permute(0, 2, 3, 1).to(torch.uint8)
    return x.numpy()
class IImage:
    """Batch of 8-bit images held as a numpy array in (B, H, W, C) layout.

    The pixels live in ``self.data``. ``self.device`` is only a hint that
    :meth:`torch` uses to place the tensor it produces; the numpy data
    itself never moves.
    """

    @staticmethod
    def open(path):
        """Load the image file at ``path`` and wrap it in an ``IImage``."""
        # The context manager releases the file handle once pixels are copied.
        with PIL.Image.open(path) as img:
            data = np.array(img)
        if data.ndim == 3:
            # (H,W,C) -> (B,H,W,C); 2D arrays are expanded by the constructor.
            data = data[None]
        return IImage(data)

    @staticmethod
    def normalized(x, dims=(-1, -2)):
        """Min-max normalize ``x`` over ``dims`` and wrap the result.

        ``x`` must provide ``amin``/``amax`` (presumably a torch tensor --
        confirm against callers). The result is interpreted with ``vmin=0``.
        A constant input yields a division by zero.
        """
        lo = x.amin(dims, True)
        hi = x.amax(dims, True)
        return IImage((x - lo) / (hi - lo), 0)

    def numpy(self):
        """Return the underlying array (not a copy)."""
        return self.data

    def torch(self, vmin=-1, vmax=1):
        """Return the image as a float tensor scaled to ``[vmin, vmax]``.

        4D data is returned as (B, C, H, W); 3D data as (C, H, W). The
        tensor is moved to ``self.device``.
        """
        if self.data.ndim == 3:
            data = self.data.transpose(2, 0, 1) / 255.
        else:
            data = self.data.transpose(0, 3, 1, 2) / 255.
        return vmin + torch.from_numpy(data).float().to(self.device) * (vmax - vmin)

    def to(self, device):
        """Set the device used by :meth:`torch`; returns ``self``."""
        self.device = device
        return self

    def cuda(self):
        """Shortcut for ``to('cuda')``."""
        self.device = 'cuda'
        return self

    def cpu(self):
        """Shortcut for ``to('cpu')``."""
        self.device = 'cpu'
        return self

    def pil(self):
        """Convert to PIL: one image for a batch of one, otherwise a list."""
        ans = []
        for x in self.data:
            if x.shape[-1] == 1:
                # fromarray expects (H,W) for single-channel data.
                x = x[..., 0]
            ans.append(PIL.Image.fromarray(x))
        if len(ans) == 1:
            return ans[0]
        return ans

    def is_iimage(self):
        """Marker method; always returns ``True``."""
        return True

    @property
    def shape(self):
        """Shape of the underlying array."""
        return self.data.shape

    @property
    def size(self):
        """``(width, height)`` taken from the array shape."""
        return (self.data.shape[-2], self.data.shape[-3])

    def __init__(self, x, vmin=-1, vmax=1):
        """Build an image batch from an array, tensor, PIL image or ``IImage``.

        Args:
            x: ``numpy.ndarray`` (cast to ``uint8``), ``torch.Tensor``
                (converted with ``torch2np``), ``PIL.Image.Image`` or
                another ``IImage`` (its data is copied).
            vmin: Lower bound of the value range, used for tensors only.
            vmax: Upper bound of the value range, used for tensors only.

        Raises:
            TypeError: If ``x`` is none of the supported types.
        """
        if isinstance(x, np.ndarray):
            self.data = x.copy().astype(np.uint8)
            if self.data.ndim == 2:
                self.data = self.data[None, ..., None]
            if self.data.ndim == 3:
                warnings.warn(
                    "Inferred dimensions for a 3D array as (H,W,C), but could've been (B,H,W)")
                self.data = self.data[None]
        elif isinstance(x, IImage):
            self.data = x.data.copy()  # Simple Copy
        elif isinstance(x, torch.Tensor):
            self.data = torch2np(x, vmin, vmax)
        elif isinstance(x, PIL.Image.Image):
            self.data = np.array(x)
            if self.data.ndim == 2:
                self.data = self.data[..., None]  # (H,W,C)
            self.data = self.data[None]  # (B,H,W,C)
        else:
            # Previously the object was left without ``data`` and failed later.
            raise TypeError(
                f"Cannot build an IImage from {type(x).__name__}")
        self.device = 'cpu'

    def resize(self, size, *args, **kwargs):
        """Resize every image in the batch with PIL.

        Args:
            size: ``None`` (no-op), an ``(H, W)`` pair, or an int. An int
                sets the longer edge, or the shorter edge when
                ``use_small_edge_when_int=True``, keeping the aspect ratio.
            *args: Extra positional arguments forwarded to PIL's ``resize``.
            **kwargs: ``resample`` (or legacy ``filter``) selects the PIL
                filter, default ``PIL.Image.BICUBIC``; the rest is
                forwarded to PIL's ``resize``.

        Returns:
            ``self`` when no resizing is needed, otherwise a new ``IImage``.
        """
        if size is None:
            return self
        use_small_edge_when_int = kwargs.pop('use_small_edge_when_int', False)
        resample = kwargs.pop('filter', PIL.Image.BICUBIC)  # Backward compatibility
        resample = kwargs.pop('resample', resample)
        if isinstance(size, int):
            h, w = self.data.shape[1:3]
            aspect_ratio = h / w
            # max -> ``size`` is the short edge, min -> ``size`` is the long edge.
            pick = max if use_small_edge_when_int else min
            size = (pick(size, int(size * aspect_ratio)),
                    pick(size, int(size / aspect_ratio)))
        # PIL expects (W,H); a tuple makes the comparison work for list sizes too.
        target_wh = tuple(size[::-1])
        if self.size == target_wh:
            return self
        return stack([IImage(x.pil().resize(target_wh, *args, resample=resample, **kwargs)) for x in self])

    def pad(self, padding, *args, **kwargs):
        """Pad via ``torchvision.transforms.functional.pad``; returns a new image.

        ``padding`` plus any extra arguments are forwarded positionally, so
        e.g. a fill value can be given after ``padding``.
        """
        return IImage(tvF.pad(self.torch(0), padding, *args, **kwargs), 0)

    def padx(self, multiplier, *args, **kwargs):
        """Pad so that width and height become multiples of ``multiplier``."""
        size = np.array(self.size)
        extra = np.ceil(size / multiplier).astype(int) * multiplier - size
        # tolist() yields plain Python ints rather than numpy scalars.
        return self.pad([0, 0, *extra.tolist()], *args, **kwargs)

    def pad2wh(self, w=0, h=0, **kwargs):
        """Pad up to at least ``w`` x ``h``; never shrinks."""
        cw, ch = self.size
        return self.pad([0, 0, max(0, w - cw), max(0, h - ch)], **kwargs)

    def pad2square(self, *args, **kwargs):
        """Pad the shorter side on both ends to make the image square."""
        width, height = self.size
        if width > height:
            dx = width - height
            return self.pad([0, dx // 2, 0, dx - dx // 2], *args, **kwargs)
        elif width < height:
            dx = height - width
            return self.pad([dx // 2, 0, dx - dx // 2, 0], *args, **kwargs)
        return self

    def alpha(self):
        """Return the last channel as a single-channel image."""
        return IImage(self.data[..., -1, None])

    def rgb(self):
        """Return an RGB version of every image in the batch."""
        # Converting frame by frame keeps this working for batches > 1,
        # where ``pil()`` returns a list instead of a single image.
        return stack([IImage(frame.pil().convert('RGB')) for frame in self])

    def dilate(self, iterations=1, *args, **kwargs):
        """Binary-dilate the image; non-zero pixels become 255.

        Extra positional and keyword arguments are forwarded to
        ``scipy.ndimage.binary_dilation``.
        """
        # NOTE(review): the whole 4D array is dilated with scipy's default
        # structuring element, so batch and channel axes are included.
        dilated = binary_dilation(
            self.data, *args, iterations=iterations, **kwargs)
        return IImage((dilated * 255.).astype(np.uint8))

    def save(self, path):
        """Write the first image of the batch to ``path``; returns ``self``."""
        data = self.data if self.data.ndim == 3 else self.data[0]
        if data.shape[-1] == 1:
            # PIL cannot build an image from (H,W,1); drop the channel axis.
            data = data[..., 0]
        PIL.Image.fromarray(data).save(path)
        return self

    def crop(self, bbox):
        """Crop all images to ``bbox`` = ``(w, h)`` or ``(x, y, w, h)``."""
        assert len(bbox) in [2, 4]
        if len(bbox) == 2:
            x, y = 0, 0
            w, h = bbox
        else:
            x, y, w, h = bbox
        return IImage(self.data[:, y:y + h, x:x + w, :])

    def __getitem__(self, idx):
        """Return image ``idx`` as a batch of one, or a sub-batch for a slice.

        An out-of-range integer raises ``IndexError``, which is what lets
        ``for image in iimage`` terminate.
        """
        if isinstance(idx, slice):
            return IImage(self.data[idx])
        return IImage(self.data[None, idx])