# HE-to-IHC/asp/util/util.py
# antoinedelplace
# First commit
# 207ef6f
"""This module contains simple helper functions """
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os
import importlib
import argparse
from argparse import Namespace
import torchvision
import cv2 as cv
def str2bool(v):
    """Interpret a command-line style value as a boolean.

    Parameters:
        v (bool or str) -- a bool (returned unchanged) or a yes/no style
                           string, matched case-insensitively

    Returns True for 'yes', 'true', 't', 'y', '1' and False for
    'no', 'false', 'f', 'n', '0'. Any other string raises
    argparse.ArgumentTypeError.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def copyconf(default_opt, **kwargs):
    """Return a copy of an options object with selected attributes overridden.

    Parameters:
        default_opt (Namespace) -- options object whose attributes are copied
        **kwargs                -- attribute values to set (or add) on the copy

    The copy is shallow; default_opt itself is not modified.
    """
    merged = dict(vars(default_opt))
    merged.update(kwargs)
    return Namespace(**merged)
def find_class_in_module(target_cls_name, module):
    """Look up an attribute of a module by case-insensitive, underscore-free name.

    Parameters:
        target_cls_name (str) -- name to search for; underscores are removed
                                 and the result is lowercased before matching
        module (str)          -- dotted module path handed to importlib

    Returns the matching object. If several attribute names match, the last
    one in the module's namespace wins. An AssertionError is raised when
    nothing matches.
    """
    wanted = target_cls_name.replace('_', '').lower()
    namespace = vars(importlib.import_module(module))

    matches = [obj for attr, obj in namespace.items() if attr.lower() == wanted]
    cls = matches[-1] if matches else None

    assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, wanted)
    return cls
def tensor2im(input_image, imtype=np.uint8):
    """Convert a tensor batch into a numpy image array.

    Parameters:
        input_image (tensor) -- the input image tensor array; only the first
                                element of the batch is converted
        imtype (type)        -- the desired type of the converted numpy array

    A numpy array is only cast to imtype. A torch tensor is clamped to
    [-1, 1], single-channel data is tiled to three channels, and the result
    is moved to channel-last order and rescaled to [0, 255]. Any other
    object is returned untouched.
    """
    if isinstance(input_image, np.ndarray):
        # already a numpy array: nothing to convert, only cast
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    first = input_image.data[0].clamp(-1.0, 1.0).cpu().float().numpy()
    if first.shape[0] == 1:
        # grayscale to RGB
        first = np.tile(first, (3, 1, 1))
    channel_last = np.transpose(first, (1, 2, 0))
    scaled = (channel_last + 1) / 2.0 * 255.0
    return scaled.astype(imtype)
def diagnose_network(net, name='network'):
    """Print a network's name followed by its average mean-absolute gradient.

    Parameters:
        net (torch network) -- network whose parameters are inspected
        name (str)          -- label printed before the value

    Parameters whose .grad is None are skipped. The printed value is the
    average, over the remaining parameters, of each parameter's mean
    absolute gradient, or 0.0 when no parameter has a gradient.
    """
    per_param = [
        torch.mean(torch.abs(p.grad.data))
        for p in net.parameters()
        if p.grad is not None
    ]
    average = 0.0
    if per_param:
        average = sum(per_param, 0.0) / len(per_param)
    print(name)
    print(average)
def save_image(image_numpy, image_path, aspect_ratio=1.0):
    """Save a numpy image to the disk, optionally stretching it first.

    Parameters:
        image_numpy (numpy array) -- input numpy array, indexed (height, width[, channels])
        image_path (str)          -- the path of the image
        aspect_ratio (float)      -- > 1.0 multiplies the width by this factor,
                                     < 1.0 divides the height by it;
                                     1.0 or None saves at the original size
    """
    image_pil = Image.fromarray(image_numpy)
    # Only the two leading axes matter, so 2-D grayscale arrays work as well.
    h, w = image_numpy.shape[:2]

    # PIL's resize expects (width, height). The previous code passed the
    # height first, which transposed non-square images and scaled the
    # wrong axis.
    if aspect_ratio is None:
        pass
    elif aspect_ratio > 1.0:
        image_pil = image_pil.resize((int(w * aspect_ratio), h), Image.BICUBIC)
    elif aspect_ratio < 1.0:
        image_pil = image_pil.resize((w, int(h / aspect_ratio)), Image.BICUBIC)
    image_pil.save(image_path)
def print_numpy(x, val=True, shp=False):
    """Print summary statistics and/or the shape of a numpy array.

    Parameters:
        x (numpy array) -- the array to describe; it is cast to float64 first
        val (bool)      -- if print the mean, min, max, median and std of the values
        shp (bool)      -- if print the shape of the numpy array
    """
    as_float = x.astype(np.float64)
    if shp:
        print('shape,', as_float.shape)
    if not val:
        return

    flat = as_float.flatten()
    stats = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
    print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)
def mkdirs(paths):
    """Create one or several directories if they don't exist.

    Parameters:
        paths (str or str list) -- a single directory path, or a list of them
    """
    # A list is treated as a collection of paths; anything else is handed
    # to mkdir as one path.
    targets = paths if isinstance(paths, list) else [paths]
    for target in targets:
        mkdir(target)
def mkdir(path):
    """Create a single directory (and missing parents) if it didn't exist.

    Parameters:
        path (str) -- a single directory path

    An already existing path is left alone without raising.
    """
    # Create first and tolerate "already exists" instead of checking with
    # os.path.exists beforehand: another process may create the directory
    # between the check and the creation, which used to raise.
    try:
        os.makedirs(path)
    except FileExistsError:
        pass
def correct_resize_label(t, size):
    """Resize a batch of label maps with nearest-neighbour sampling.

    Parameters:
        t (tensor) -- batch of label maps; only the first channel of each
                      sample is used
        size       -- target size, passed unchanged to PIL's Image.resize

    Returns a long tensor stacking the resized maps along dim 0, moved back
    to the device t came from. Labels are cast to uint8 before resizing.
    """
    target_device = t.device
    cpu_batch = t.detach().cpu()

    def _resize_one(sample):
        # keep the first channel, go channel-last, then drop the channel axis
        channel_last = np.transpose(sample[:1].numpy().astype(np.uint8), (1, 2, 0))
        label_map = Image.fromarray(channel_last[:, :, 0])
        shrunk = label_map.resize(size, Image.NEAREST)
        return torch.from_numpy(np.array(shrunk)).long()

    resized = [_resize_one(cpu_batch[i]) for i in range(cpu_batch.size(0))]
    return torch.stack(resized, dim=0).to(target_device)
def correct_resize(t, size, mode=Image.BICUBIC):
    """Resize a batch of image tensors through PIL.

    Parameters:
        t (tensor) -- batch of images; each sample is converted with tensor2im
        size       -- target size, passed unchanged to PIL's Image.resize
        mode       -- PIL resampling filter used for the resize

    Returns the resized samples stacked along dim 0 and moved back to the
    device t came from. Each sample is rescaled with ``* 2 - 1.0`` after
    to_tensor (presumably mapping [0, 1] to [-1, 1]; confirm against
    torchvision).
    """
    device = t.device
    t = t.detach().cpu()
    resized = []
    for i in range(t.size(0)):
        # keep the batch axis: tensor2im converts the first element of a batch
        one_t = t[i:i + 1]
        # honour the caller's resampling filter (it was previously ignored
        # in favour of a hard-coded Image.BICUBIC)
        one_image = Image.fromarray(tensor2im(one_t)).resize(size, mode)
        resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
        resized.append(resized_t)
    return torch.stack(resized, dim=0).to(device)
def expand_as_one_hot(input, C, ignore_index=None):
    """
    Converts NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector.
    It is assumed that the batch dimension is present.
    Args:
        input (torch.Tensor): 3D/4D input image of integer labels (used as the scatter index)
        C (int): number of channels/labels
        ignore_index (int): ignore index to be kept during the expansion; positions holding
            this label are set to ignore_index in every output channel
    Returns:
        4D/5D output torch.Tensor (NxCxSPATIAL) created by torch.zeros on the input's device
    """
    # expand the input tensor to Nx1xSPATIAL before scattering
    input = input.unsqueeze(1)
    # create output tensor shape (NxCxSPATIAL)
    shape = list(input.size())
    shape[1] = C

    mask = None
    if ignore_index is not None:
        # remember where the ignored label sits, across all C channels
        mask = input.expand(shape) == ignore_index
        # clone so the caller's tensor is untouched, then point ignored
        # positions at a valid channel so scatter_ accepts them
        input = input.clone()
        input[input == ignore_index] = 0

    # allocate directly on the target device instead of on CPU followed by a copy
    result = torch.zeros(shape, device=input.device).scatter_(1, input, 1)

    if mask is not None:
        # bring back the ignore_index in the result
        result[mask] = ignore_index
    return result
def standardize(ref, I, threshold=50):
    """
    Shift the mean lightness of image I to that of a reference image.

    Both images are converted from RGB to LAB and the means of their first
    (L) channels are compared. When the two means differ by more than
    ``threshold``, the L channel of I is offset so that its mean equals the
    reference mean, clipped to [0, 255], and the image is converted back to RGB.

    :param ref: Reference image (RGB) that provides the target mean L value.
    :param I: Image uint8 RGB.
    :param threshold: Absolute difference between the two mean L values above which I is adjusted.
    :return: Image uint8 RGB, adjusted when the threshold is exceeded.
    """
    # mean of the reference image's L channel
    ref_m = cv.cvtColor(ref, cv.COLOR_RGB2LAB)[:, :, 0].astype(float).mean()
    I_LAB = cv.cvtColor(I, cv.COLOR_RGB2LAB)
    # work in float so the offset below cannot wrap around in uint8
    L_float = I_LAB[:, :, 0].astype(float)
    tgt_m = L_float.mean()
    # NOTE(review): the nesting was reconstructed from flattened source --
    # confirm that the write-back and the LAB -> RGB conversion belong
    # inside this branch (i.e. that I is returned unchanged otherwise).
    if np.abs(tgt_m - ref_m) > threshold:
        L_float = L_float - tgt_m + ref_m
        I_LAB[:, :, 0] = np.clip(L_float, 0, 255).astype(np.uint8)
        I = cv.cvtColor(I_LAB, cv.COLOR_LAB2RGB)
    return I