|
""" |
|
Handy functions for flattening nested containers containing numpy |
|
arrays. The main purpose is to make examples and optimizers simpler. |
|
""" |
|
from autograd import make_vjp |
|
from autograd.builtins import type |
|
import autograd.numpy as np |
|
|
|
def flatten(value):
    """Flatten a nested container of arrays/scalars into one 1D numpy array.

    ``value`` may be any nesting of tuples, lists, or dicts whose leaves are
    numpy arrays or scalars. Dict entries are visited in sorted-key order, so
    dict keys must be sortable. Because all leaves end up in a single
    concatenated array, mixed numeric types (e.g. ints alongside floats) are
    not preserved.

    Returns:
        A pair ``(flat_value, unflatten)``. ``flat_value`` is the 1D array.
        ``unflatten`` is the vector-Jacobian product of the flattening map. It
        takes a 1D vector and rebuilds a container shaped like ``value``.
    """
    # make_vjp returns (vjp_function, forward_value). Flattening is linear, so
    # its VJP is exactly the inverse reshaping back into the original nesting.
    vjp_fn, flat = make_vjp(_flatten)(value)
    return flat, vjp_fn
|
|
|
def _flatten(value):
    """Recursively ravel ``value`` into a single 1D array.

    Lists and tuples are flattened element by element, in order. Dicts are
    flattened by value, in sorted-key order. Any other value is handed to
    ``np.ravel``.
    """
    # `type` here is autograd.builtins.type. Unlike the builtin, it reports the
    # underlying container type even when make_vjp has boxed `value` for tracing.
    kind = type(value)
    if kind is dict:
        return _concatenate(_flatten(value[key]) for key in sorted(value))
    if kind in (list, tuple):
        return _concatenate(_flatten(item) for item in value)
    return np.ravel(value)
|
|
|
def _concatenate(lst): |
|
lst = list(lst) |
|
return np.concatenate(lst) if lst else np.array([]) |
|
|
|
def flatten_func(func, example):
    """Wrap ``func`` so that it consumes and produces flat 1D arrays.

    ``example`` is a sample input for ``func``. Its nesting structure defines
    how a flat vector is unflattened before ``func`` is called.

    Returns:
        A triple ``(flat_func, unflatten, flat_example)``:

        - ``flat_func(flat_x, *args)`` unflattens ``flat_x`` into the structure
          of ``example``, calls ``func`` on it with ``*args``, and returns the
          flattened result.
        - ``unflatten`` is the function that rebuilds ``example``'s structure.
        - ``flat_example`` is ``example`` flattened to 1D.
    """
    flat_example, unflatten = flatten(example)

    def flat_func(flat_x, *args):
        output = func(unflatten(flat_x), *args)
        # Only the flat array is needed; the output's own unflatten is discarded.
        flat_output, _ = flatten(output)
        return flat_output

    return flat_func, unflatten, flat_example
|
|