HE-to-IHC / asp /util /fdlutil.py
antoinedelplace
First commit
207ef6f
import importlib.util
import os
import sys
from pylab import *
import matplotlib as mpl
# Select matplotlib's non-interactive 'Agg' backend so figures are rendered
# off-screen and written to files with savefig; no display is required.
# Use 'TkAgg' instead when plotting to a window.
# NOTE(review): pyplot has already been imported by `from pylab import *`
# above; confirm the installed matplotlib version honours this late switch.
# NOTE(review): under Agg, helpers that call plt.ion()/plt.pause() presumably
# will not open a window - confirm whether that interactive use still matters.
# #### mpl.use('TkAgg') # Don't use this unless emergency. More trouble than it's worth
mpl.use('Agg')
def _imshow_limits(vmax, vmin, i):
    """Return the (vmax, vmin) pair quick_imshow uses for subplot ``i``.

    (None, None) means "autoscale to the image data".
    """
    if isinstance(vmax, tuple) and isinstance(vmin, tuple):
        if vmax[i] is not None and vmin[i] is not None:
            return vmax[i], vmin[i]
        return None, None
    if isinstance(vmax, tuple) and vmin is None:
        # A per-subplot vmax without vmin implies a lower limit of 0.
        if vmax[i] is not None:
            return vmax[i], 0
        return None, None
    # Scalars (or None) apply to every subplot unchanged.
    return vmax, vmin


def _save_figure(fig, saveas):
    """Save ``fig`` to ``saveas`` (skipped when empty), creating its folder."""
    if not saveas:
        return
    dirname = os.path.dirname(saveas)
    # dirname is '' for a bare file name, and os.makedirs('') would raise.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    fig.savefig(saveas)


def quick_imshow(nrows, ncols=1, images=None, titles=None, colorbar=True, colormap='jet',
                 vmax=None, vmin=None, figsize=None, figtitle=None, visibleaxis=True,
                 saveas='/home/ubuntu/tempimshow.png', tight=False, dpi=250.0):
    """-------------------------------------------------------------------------
    Desc.:      convenience function that makes a grid of imshow subplots
    Args.:      nrows - number of rows, or a single np.ndarray image, which is
                        then drawn alone in a 1x1 figure
                ncols - number of cols
                images - list of images, one per subplot (row-major order)
                titles - list of titles (for a 1x1 figure a plain string is
                         also accepted)
                colorbar - draw a colorbar next to every subplot
                colormap - matplotlib colormap name
                vmax - tuple of vmax for the colormap. If scalar,
                       the same value is used for all subplots. If one
                       of the entries is None, that subplot is autoscaled
                       to its own data.
                vmin - tuple of vmin (or a scalar, or None). When vmax is a
                       tuple and vmin is None, vmin defaults to 0.
                figsize - figure size in inches; defaults to 5 inches per
                          subplot (+0.5 inch of height when figtitle is set)
                figtitle - optional figure-level title
                visibleaxis - show the x/y axes of each subplot
                saveas - output path; '' skips saving. Missing parent
                         folders are created.
                tight - apply tight_layout before saving
                dpi - figure resolution (used for every layout)
    Returns:    f - the figure handle
                axes - axes or array of axes objects
                caxes - tuple of axes images (a single image for a 1x1 grid)
    -------------------------------------------------------------------------"""
    if isinstance(nrows, np.ndarray):
        # Shortcut: quick_imshow(img) plots one image.
        images = nrows
        nrows = 1
        ncols = 1
    if figsize is None:
        side = 5.0
        extra_height = 0.5 if figtitle else 0.0
        figsize = (side * ncols, side * nrows + extra_height)

    if nrows == ncols == 1:
        image = images[0] if isinstance(images, list) else images
        title = titles
        if isinstance(titles, (list, tuple)):
            # Unwrap a one-element title list instead of printing it verbatim.
            title = titles[0] if titles else None
        hi, lo = _imshow_limits(vmax, vmin, 0)
        f, ax = plt.subplots(figsize=figsize, dpi=dpi)
        cax = ax.imshow(image, cmap=colormap, vmax=hi, vmin=lo)
        if colorbar:
            f.colorbar(cax, ax=ax)
        if title is not None:
            ax.set_title(title)
        if figtitle is not None:
            f.suptitle(figtitle)
        ax.get_xaxis().set_visible(visibleaxis)
        ax.get_yaxis().set_visible(visibleaxis)
        if tight:
            f.tight_layout()
        _save_figure(f, saveas)
        return f, ax, cax

    f, axes = plt.subplots(nrows, ncols, figsize=figsize, dpi=dpi)
    caxes = []
    for i, (ax, img) in enumerate(zip(axes.flat, images)):
        hi, lo = _imshow_limits(vmax, vmin, i)
        cax = ax.imshow(img, cmap=colormap, vmax=hi, vmin=lo)
        if titles is not None:
            ax.set_title(titles[i])
        if colorbar:
            f.colorbar(cax, ax=ax)
        caxes.append(cax)
        ax.get_xaxis().set_visible(visibleaxis)
        ax.get_yaxis().set_visible(visibleaxis)
    if figtitle is not None:
        f.suptitle(figtitle)
    if tight:
        f.tight_layout()
    _save_figure(f, saveas)
    return f, axes, tuple(caxes)
def _update_limits(img, vmax, vmin, i):
    """Return the (vmin, vmax) colour limits update_subplots applies to image i."""
    if isinstance(vmax, tuple) and isinstance(vmin, tuple):
        if vmax[i] is not None and vmin[i] is not None:
            return vmin[i], vmax[i]
    elif isinstance(vmax, tuple) and vmin is None:
        # A per-subplot vmax without vmin implies a lower limit of 0.
        if vmax[i] is not None:
            return 0, vmax[i]
    elif vmax is not None or vmin is not None:
        # Scalars apply to every subplot; a None side keeps its current value.
        return vmin, vmax
    # No explicit limits: rescale to the new image's data range.
    return img.min(), img.max()


def update_subplots(images, caxes, f=None, axes=None, indices=(), vmax=None,
                    vmin=None):
    """-------------------------------------------------------------------------
    Desc.:      update subplots in a figure (e.g. one made by quick_imshow)
    Args.:      images - new images to plot
                caxes - caxes returned at figure creation
                f - optional figure handle; when given its layout is
                    tightened, otherwise the current figure's is
                axes - unused, kept for backward compatibility
                indices - specific indices of subplots to be updated:
                          images[i] goes to caxes[indices[i]]. When empty,
                          images[i] goes to caxes[i].
                vmax, vmin - colour limits, same conventions as quick_imshow;
                             a subplot without explicit limits is rescaled to
                             the new image's min/max
    Returns:    None
    -------------------------------------------------------------------------"""
    for i, img in enumerate(images):
        ind = indices[i] if len(indices) > 0 else i
        artist = caxes[ind]
        artist.set_data(img)
        lo, hi = _update_limits(img, vmax, vmin, i)
        # Set the limits on the image itself. The colorbar follows its
        # mappable, and Colorbar has no set_clim in recent matplotlib.
        artist.set_clim(lo, hi)
        cbar = artist.colorbar
        # The colorbar is None when the figure was built with colorbar=False.
        if cbar is not None:
            cbar.update_normal(artist)
    plt.pause(0.01)
    if f is not None:
        f.tight_layout()
    else:
        plt.tight_layout()
def slide_show(image, dt=0.01, vmax=None, vmin=None):
    """
    Slide show for visualizing an image volume, one 2D slice at a time.

    :param image: 3D array; the slides are image[:, :, i] along the last
        (depth) axis. Boolean volumes are shown as 0/1 floats; the caller's
        array is never modified.
    :param dt: pause in seconds between slides.
    :param vmax: upper colour limit shared by all slides (default image.max()).
    :param vmin: lower colour limit shared by all slides (default image.min()).
    :return: None
    """
    if image.dtype == bool:
        # The old in-place `image *= 1.0` cannot store float results in a
        # bool array (numpy raises a casting error), and it would also have
        # mutated the caller's data. Work on a float copy instead.
        image = image.astype(float)
    if vmax is None:
        vmax = image.max()
    if vmin is None:
        vmin = image.min()
    depth = image.shape[2]
    plt.ion()
    fig = plt.figure()
    for i in range(depth):
        plt.cla()
        shown = plt.imshow(image[:, :, i], cmap='jet', vmin=vmin, vmax=vmax)
        plt.title('Slice: %i/%i' % (i, depth - 1))
        if i == 0:
            # Limits are fixed for all slices, so one colorbar is enough.
            fig.colorbar(shown, ax=plt.gca())
        plt.pause(dt)
    plt.draw()
def quick_collage(images, nrows=3, ncols=2, normalize=False, figsize=(20.0, 10.0), figtitle=None, colorbar=True,
                  tight=True, saveas='/home/ubuntu/tempcollage.png'):
    """Show evenly spaced depth slices of one or more volumes side by side.

    The volumes are concatenated along their second axis, with a one-pixel
    NaN column between neighbours as a visual separator. Then nrows * ncols
    slices along the third axis are plotted with quick_imshow.

    :param images: a 3D array, or a list of 3D arrays whose first and third
        dimensions match. The caller's list is not modified.
    :param nrows: rows of the subplot grid (one slice per subplot).
    :param ncols: columns of the subplot grid.
    :param normalize: rescale every volume to [0, 1] before plotting.
        Otherwise the shared colour range is the global min/max of all volumes.
    :param figsize: figure size in inches.
    :param figtitle: optional figure-level title.
    :param colorbar: draw a colorbar per subplot.
    :param tight: apply tight_layout.
    :param saveas: output path; '' skips saving. Missing parent folders are
        created.
    :return: None. The figure is closed before returning.
    """
    def zero_to_one(x):
        if x.min() == x.max():
            return x - x.min()
        return (x.astype(float) - x.min()) / (x.max() - x.min())

    if isinstance(images, np.ndarray):
        images = [images]
    else:
        # Copy so the separator columns added below never leak into the
        # caller's list.
        images = list(images)
    img_shp = images[0].shape
    # Normalize every image, or share one colour range across all of them.
    if normalize:
        images = [zero_to_one(image) for image in images]
        vmax, vmin = 1.0, 0.0
    else:
        vmax = max(img.max() for img in images)
        vmin = min(img.min() for img in images)
    # Highlight the boundaries between volumes with a NaN column.
    separator = np.full((img_shp[0], 1, img_shp[2]), np.nan)
    padded = [np.hstack([img, separator]) for img in images[:-1]] + images[-1:]
    collage = np.hstack(padded)
    # Pick n_slices evenly spaced slice indices along the depth axis.
    depth = collage.shape[2]
    n_slices = nrows * ncols
    z = [int(depth / (n_slices + 1) * i - 1) for i in range(1, n_slices + 1)]
    titles = ['Slice %d/%d' % (i, depth) for i in z]
    # saveas='' stops quick_imshow from also writing to its own default path.
    f = quick_imshow(
        nrows, ncols,
        [collage[:, :, k] for k in z],
        titles=titles,
        figtitle=figtitle,
        figsize=figsize,
        vmax=vmax, vmin=vmin,
        colorbar=colorbar, tight=tight,
        saveas='')[0]
    if saveas:
        dirname = os.path.dirname(saveas)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        f.savefig(saveas)
    plt.close(f)
def quick_plot(x_data, y_data=None, fmt='', color=None, xlim=None, ylim=None,
               label='', legends=False, x_label='', y_label='', figtitle='', annotation=None, figsize=(20, 10),
               f=None, ax=None, saveas=''):
    """Line-plot y_data against x_data on a new or given axes.

    :param x_data: x values, or the y values when y_data is None (they are
        then plotted against their index 0..n-1).
    :param y_data: y values.
    :param fmt: matplotlib format string (e.g. 'r--').
    :param color: line colour.
    :param xlim, ylim: optional axis limits.
    :param label: legend label of this line.
    :param legends: draw a legend to the right of the axes.
    :param x_label, y_label, figtitle: optional labels ('' means none).
    :param annotation: optional per-point texts, drawn 10 points above each point.
    :param figsize: size of a newly created figure.
    :param f, ax: existing figure and axes to draw on. A new pair is created
        unless both are given.
    :param saveas: output path; '' skips saving.
    :return: (f, ax)
    """
    if f is None or ax is None:
        f, ax = plt.subplots(figsize=figsize)
    if y_data is None:
        # Single series: plot it against its index.
        y_data = x_data
        x_data = list(range(len(y_data)))
    ax.plot(x_data, y_data, fmt, label=label, color=color)
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    if annotation is not None:
        # Annotate the axes drawn on, not whatever pyplot considers current.
        for i, (x, y) in enumerate(zip(x_data, y_data)):
            ax.annotate(annotation[i], (x, y),
                        textcoords='offset points', xytext=(0, 10), ha='center')
    if x_label:
        ax.set_xlabel(x_label)
    if y_label:
        ax.set_ylabel(y_label)
    if figtitle:
        f.suptitle(figtitle)
    if legends:
        ax.legend(loc='center left', bbox_to_anchor=(1.04, 0.5))
    # Enable the grid explicitly. A bare ax.grid() toggles visibility, so the
    # former pair of calls left the returned axes without a grid.
    ax.grid(True)
    if saveas:
        f.savefig(saveas, bbox_inches='tight')
    return f, ax
def quick_scatter(x_data, y_data=None, xlim=None, ylim=None,
                  label='', legends=False, x_label='', y_label='', figtitle='', annotation=None,
                  f=None, ax=None, saveas='', figsize=None):
    """Scatter-plot y_data against x_data on a new or given axes.

    :param x_data: x values, or the y values when y_data is None (they are
        then plotted against their index 0..n-1).
    :param y_data: y values.
    :param xlim, ylim: optional axis limits.
    :param label: legend label of this point set.
    :param legends: draw a legend.
    :param x_label, y_label, figtitle: optional labels ('' means none).
    :param annotation: optional per-point texts, drawn 10 points above each point.
    :param f, ax: existing figure and axes to draw on. A new pair is created
        unless both are given.
    :param saveas: output path; '' skips saving.
    :param figsize: size of a newly created figure (None = matplotlib default).
    :return: (f, ax)
    """
    if f is None or ax is None:
        f, ax = plt.subplots(figsize=figsize)
    if y_data is None:
        # Single series: plot it against its index.
        y_data = x_data
        x_data = list(range(len(y_data)))
    ax.scatter(x_data, y_data, label=label)
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    if annotation is not None:
        # Annotate the axes drawn on, not whatever pyplot considers current.
        for i, (x, y) in enumerate(zip(x_data, y_data)):
            ax.annotate(annotation[i], (x, y),
                        textcoords='offset points', xytext=(0, 10), ha='center')
    if x_label:
        ax.set_xlabel(x_label)
    if y_label:
        ax.set_ylabel(y_label)
    if figtitle:
        f.suptitle(figtitle)
    if legends:
        ax.legend()
    # Enable explicitly: a bare ax.grid() toggles, so repeated calls on the
    # same axes would flip the grid on and off.
    ax.grid(True)
    if saveas:
        f.savefig(saveas)
    return f, ax
def quick_load(file_path, fits_field=1):
    """Load data written by quick_save.

    Only ``npz`` archives are supported. The array stored under 'arr_0' is
    returned. If that array has object dtype (e.g. a dict was saved), the
    wrapped object is returned instead of the wrapping array.

    :param file_path: path of the file to read.
    :param fits_field: unused; kept for backward compatibility with the
        disabled FITS reader.
    :return: the stored array or object.
    :raises NotImplementedError: if file_path does not end with 'npz'.
    """
    if not file_path.endswith('npz'):
        raise NotImplementedError(
            "Only npz is supported! Got: %s" % file_path)
    # SECURITY: allow_pickle=True lets np.load unpickle object arrays, which
    # can execute arbitrary code. Only load files from trusted sources.
    with np.load(file_path, allow_pickle=True) as archive:
        data = archive['arr_0']
    # Take care of the case where a dictionary is saved in npz format: it
    # comes back wrapped in an object array.
    if isinstance(data, np.ndarray) and data.dtype == object:
        data = data.flatten()[0]
    return data
def quick_save(file_path, data):
    """Save ``data`` as a compressed ``npz`` archive (under key 'arr_0').

    Missing parent folders are created. Other formats (pickle, FITS, HDF5)
    are not available in this module.

    :param file_path: destination path; must end with 'npz'.
    :param data: array or object to store (objects are pickled by numpy).
    :raises NotImplementedError: if file_path does not end with 'npz'. In that
        case nothing is created on disk.
    """
    # Reject unsupported formats before touching the filesystem.
    if not file_path.endswith('npz'):
        raise NotImplementedError(
            "Only npz is supported! Got: %s" % file_path)
    dir_name = os.path.dirname(file_path)
    # dir_name is '' for a bare file name, and os.makedirs('') would raise.
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    np.savez_compressed(file_path, data)
def import_module(name, path):
    """
    Import a Python source file as a module dynamically (Python 3).

    The module is executed but not registered in sys.modules.

    :param name: name given to the module instance.
    :param path: path to the module's source file.
    :return: module: the executed module instance.
    :raises ImportError: if no module loader can handle ``path`` (e.g. its
        suffix is not a recognised Python module suffix).
    """
    spec = importlib.util.spec_from_file_location(name, path)
    # spec_from_file_location returns None when no loader matches the path.
    # Without this check the failure surfaces as an opaque AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(
            "Cannot import %r from %r: no module loader for this path" % (name, path))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
def _resolve_type_name(name):
    """Map a type name ("list" or "pkg.mod.Class") to the named object."""
    import builtins
    import importlib
    if hasattr(builtins, name):
        return getattr(builtins, name)
    if '.' in name:
        # Dotted path: import the module part if needed, then fetch the attr.
        # NOTE: this imports modules named in the config; use trusted configs.
        module_name, _, attr = name.rpartition('.')
        module = sys.modules.get(module_name)
        if module is None:
            module = importlib.import_module(module_name)
        return getattr(module, attr)
    # Fallback kept from the original lookup.
    return sys.modules[name]


def obj_from_dict(info, parent=None, default_args=None):
    """Initialize an object from dict.
    The dict must contain the key "type", which indicates the object type, it
    can be either a string or type, such as "list" or ``list``. Remaining
    fields are treated as the arguments for constructing the object.
    Args:
        info (dict): Object types and arguments. It is not modified.
        parent (:class:`module`): Module which may contain the expected
            object classes. When given, a string "type" is looked up as an
            attribute of it. Otherwise it is resolved as a builtin name
            (e.g. "list") or a dotted "module.attr" path.
        default_args (dict, optional): Default arguments for initializing the
            object. Values in ``info`` take precedence.
    Returns:
        any type: Object built from the dict.
    Raises:
        TypeError: if "type" is neither a str nor a type.
    """
    assert isinstance(info, dict) and 'type' in info
    assert isinstance(default_args, dict) or default_args is None
    args = info.copy()
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        if parent is not None:
            obj_type = getattr(parent, obj_type)
        else:
            obj_type = _resolve_type_name(obj_type)
    elif not isinstance(obj_type, type):
        raise TypeError('type must be a str or valid type, but '
                        f'got {type(obj_type)}')
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)
def pad_nd_image(image, new_shape=None, mode="edge", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None):
    """
    Pad the trailing axes of an n-dimensional array up to a minimum shape.
    Originally written by Fabian Isensee.

    :param image: n-dimensional array.
    :param new_shape: minimum size of the LAST len(new_shape) axes of image.
        Leading axes are never padded, and an axis already larger than
        requested is left as is (never cropped).
        Examples:
            image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> (10, 1, 768, 768)
            image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> (10, 1, 512, 768)
        May be None when shape_must_be_divisible_by is a list/tuple/ndarray.
        The trailing len(shape_must_be_divisible_by) axes are then used.
    :param mode: padding mode forwarded to np.pad.
    :param kwargs: extra keyword arguments forwarded to np.pad.
    :param return_slicer: if True, also return a list of slices that crops
        the padded array back to the original shape.
    :param shape_must_be_divisible_by: int (applied to every padded axis) or
        one value per padded axis. After applying new_shape, each axis is
        grown to the next multiple of it (kept if already a multiple), so the
        result may exceed new_shape.
    :return: the padded array, or (padded array, slicer) if return_slicer.
        Padding is split evenly, with any odd remainder going after the data.
    """
    pad_kwargs = {} if kwargs is None else kwargs

    if new_shape is None:
        assert shape_must_be_divisible_by is not None
        assert isinstance(shape_must_be_divisible_by,
                          (list, tuple, np.ndarray))
        new_shape = image.shape[-len(shape_must_be_divisible_by):]
        old_shape = new_shape
    else:
        old_shape = np.array(image.shape[-len(new_shape):])

    n_untouched = len(image.shape) - len(new_shape)
    # Never shrink: every padded axis is at least its current size.
    target = np.array([max(new_shape[axis], old_shape[axis])
                       for axis in range(len(new_shape))])

    if shape_must_be_divisible_by is not None:
        if isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)):
            assert len(shape_must_be_divisible_by) == len(target)
            divisors = shape_must_be_divisible_by
        else:
            divisors = [shape_must_be_divisible_by] * len(target)
        # Step exact multiples down one notch so the round-up below leaves
        # them unchanged instead of adding a whole extra divisor.
        for axis in range(len(target)):
            if target[axis] % divisors[axis] == 0:
                target[axis] -= divisors[axis]
        target = np.array([target[axis] + divisors[axis] - target[axis] % divisors[axis]
                           for axis in range(len(target))])

    difference = target - old_shape
    below = difference // 2
    above = difference // 2 + difference % 2
    pad_list = [[0, 0]] * n_untouched + [list(pair) for pair in zip(below, above)]
    padded = np.pad(image, pad_list, mode, **pad_kwargs)

    if not return_slicer:
        return padded
    # Convert each [before, after] pad amount into [start, stop) crop bounds.
    bounds = np.array(pad_list)
    bounds[:, 1] = np.array(padded.shape) - bounds[:, 1]
    return padded, [slice(*pair) for pair in bounds]