import io
import IPython.display
import PIL.Image
import os
from pprint import pformat

import numpy as np


def imgrid(imarray, cols=4, pad=1, padval=255, row_major=True):
    """Lay out a batch of images as a single image grid.

    Args:
      imarray: array(-like) of shape [N, H, W, C] holding N images.
      cols: number of grid columns; the number of rows is ceil(N / cols).
      pad: pixels of padding inserted between neighbouring images (no
        padding is left on the outer border).
      padval: value used for the padding and for empty trailing cells.
      row_major: if True (default), images fill the grid left-to-right and
        then top-to-bottom. If False, they fill top-to-bottom and then
        left-to-right (trailing columns may then be entirely blank).

    Returns:
      For N >= 1, an array of shape
      [rows * (H + pad) - pad, cols * (W + pad) - pad, C].

    Raises:
      ValueError: if `pad` is negative or `cols` is less than 1.
    """
    pad = int(pad)
    if pad < 0:
        raise ValueError('pad must be non-negative')
    cols = int(cols)
    if cols < 1:
        raise ValueError('cols must be positive')
    imarray = np.asarray(imarray)
    N, H, W, C = imarray.shape
    rows = N // cols + int(N % cols != 0)  # ceil(N / cols)
    batch_pad = rows * cols - N
    # Append blank cells so the batch exactly fills the grid. Also pad every
    # image on its bottom/right edge so neighbours end up `pad` pixels apart.
    pad_arg = [[0, batch_pad], [0, pad], [0, pad], [0, 0]]
    imarray = np.pad(imarray, pad_arg, 'constant', constant_values=padval)
    H += pad
    W += pad
    if row_major:
        # Cell index i sits at row i // cols, column i % cols.
        cells = imarray.reshape(rows, cols, H, W, C).transpose(0, 2, 1, 3, 4)
    else:
        # Cell index i sits at column i // rows, row i % rows.
        cells = imarray.reshape(cols, rows, H, W, C).transpose(1, 2, 0, 3, 4)
    grid = cells.reshape(rows * H, cols * W, C)
    if pad:
        # Drop the trailing padding on the outer bottom/right edge.
        grid = grid[:-pad, :-pad]
    return grid


def interleave(*args):
    """Interleave same-shaped arrays along the batch (first) axis.

    Given k arrays a0..a{k-1} of shape [N, ...], returns an array of shape
    [N * k, ...] ordered a0[0], a1[0], ..., a0[1], a1[1], ...

    Raises:
      ValueError: if no arrays are given, the shapes differ, or the inputs
        are 0-dimensional.
    """
    if not args:
        raise ValueError('At least one argument is required.')
    arrays = [np.asarray(a) for a in args]
    a0 = arrays[0]
    if any(a.shape != a0.shape for a in arrays):
        raise ValueError('All inputs must have the same shape.')
    if not a0.shape:
        raise ValueError('Inputs must have at least one axis.')
    # Stacking on axis 1 gives [N, k, ...]. Merging the first two axes
    # then yields the interleaved order.
    out = np.stack(arrays, axis=1)
    return out.reshape(-1, *a0.shape[1:])


def imshow(a, format='png', jpeg_fallback=True):
    """Encode an image array and display it inline via IPython.

    Args:
      a: image array(-like) that PIL.Image.fromarray accepts after a uint8
        cast. Values are cast, not clipped or rescaled; convert [-1, 1]
        floats with image_to_uint8 first.
      format: PIL format name used to encode the image (e.g. 'png', 'jpeg').
      jpeg_fallback: if True and displaying in `format` raises IOError,
        retry once with JPEG encoding.

    Returns:
      The value returned by IPython.display.display.

    Raises:
      IOError: if display fails and no JPEG fallback applies.
    """
    a = np.asarray(a).astype(np.uint8)
    data = io.BytesIO()
    PIL.Image.fromarray(a).save(data, format)
    im_data = data.getvalue()
    try:
        disp = IPython.display.display(IPython.display.Image(im_data))
    except IOError:
        if jpeg_fallback and format != 'jpeg':
            # Python 3 print function: format the message *before* printing.
            print('Warning: image was too large to display in format "{}"; '
                  'trying jpeg instead.'.format(format))
            return imshow(a, format='jpeg')
        raise
    return disp


def image_to_uint8(x):
    """Convert a float array in [-1, 1] to uint8 in [0, 255].

    Applies 128 * (x + 1), so -1 -> 0 and 0 -> 128. The value 1 maps to 256
    and is clipped to 255. Out-of-range inputs are clipped to [0, 255]
    before the uint8 cast, which truncates the fractional part.
    """
    x = np.asarray(x)
    x = (256. / 2.) * (x + 1.)
    x = np.clip(x, 0, 255)
    return x.astype(np.uint8)