|
"""Convenience functions built on top of `make_vjp`.""" |
|
from __future__ import absolute_import |
|
from functools import partial |
|
from collections import OrderedDict |
|
try: |
|
from inspect import getfullargspec as _getargspec |
|
except ImportError: |
|
from inspect import getargspec as _getargspec |
|
import warnings |
|
|
|
from .wrap_util import unary_to_nary |
|
from .builtins import tuple as atuple |
|
from .core import make_vjp as _make_vjp, make_jvp as _make_jvp |
|
from .extend import primitive, defvjp_argnum, vspace |
|
|
|
import autograd.numpy as np |
|
|
|
# Public `make_vjp` / `make_jvp`: the core builders imported from `.core`,
# wrapped with `unary_to_nary`.
# NOTE(review): presumably the wrapper adds the `argnum` selection used by
# callers such as `make_vjp(fun, argnum)` in `checkpoint` -- confirm against
# `wrap_util.unary_to_nary`.
make_vjp = unary_to_nary(_make_vjp)
make_jvp = unary_to_nary(_make_jvp)
|
|
|
@unary_to_nary
def grad(fun, x):
    """
    Returns a function which computes the gradient of `fun` with respect to
    positional argument number `argnum`. The returned function takes the same
    arguments as `fun`, but returns the gradient instead. The function `fun`
    should be scalar-valued. The gradient has the same type as the argument.

    Raises a TypeError when the vector space of the output of `fun` does not
    have size 1.
    """
    pullback, out = _make_vjp(fun, x)
    out_space = vspace(out)
    if out_space.size == 1:
        # Seed the backward pass with a cotangent of ones.
        return pullback(out_space.ones())
    raise TypeError("Grad only applies to real scalar-output functions. "
                    "Try jacobian, elementwise_grad or holomorphic_grad.")
|
|
|
@unary_to_nary
def elementwise_grad(fun, x):
    """
    Returns a function that computes the sum of each column of the Jacobian of
    `fun`, in one pass. If the Jacobian is diagonal, then this is the diagonal
    of the Jacobian.

    Raises a TypeError when the vector space of the output is complex.
    """
    pullback, out = _make_vjp(fun, x)
    out_space = vspace(out)
    if out_space.iscomplex:
        raise TypeError("Elementwise_grad only applies to real-output functions.")
    # A single backward pass seeded with ones sums over the output entries.
    return pullback(out_space.ones())
|
|
|
@unary_to_nary
def deriv(fun, x):
    """
    Pushes a tangent of ones (taken from the vector space of `x`) through
    `fun` with `_make_jvp` and returns element 1 of the result.
    """
    push_forward = _make_jvp(fun, x)
    unit_tangent = vspace(x).ones()
    return push_forward(unit_tangent)[1]
|
|
|
@unary_to_nary
def jacobian(fun, x):
    """
    Returns a function which computes the Jacobian of `fun` with respect to
    positional argument number `argnum`, which must be a scalar or array. Unlike
    `grad` it is not restricted to scalar-output functions, but also it cannot
    take derivatives with respect to some argument types (like lists or dicts).
    If the input to `fun` has shape (in1, in2, ...) and the output has shape
    (out1, out2, ...) then the Jacobian has shape (out1, out2, ..., in1, in2, ...).
    """
    vjp, ans = _make_vjp(fun, x)
    ans_vspace = vspace(ans)
    jacobian_shape = ans_vspace.shape + vspace(x).shape
    # One backward pass per basis vector of the output space. Build a real
    # list: on Python 3 `map` is a lazy iterator, and `stack` is specified to
    # take a sequence of arrays rather than an arbitrary iterable.
    grads = [vjp(basis_vector) for basis_vector in ans_vspace.standard_basis()]
    return np.reshape(np.stack(grads), jacobian_shape)
|
|
|
@unary_to_nary
def holomorphic_grad(fun, x):
    """
    Returns `grad` of the real part of `fun`, evaluated at `x`. Emits a
    warning (but still proceeds) when the vector space of `x` is not complex.
    """
    if not vspace(x).iscomplex:
        warnings.warn("Input to holomorphic_grad is not complex")

    def real_part_of_fun(z):
        return np.real(fun(z))

    return grad(real_part_of_fun)(x)
|
|
|
def grad_named(fun, argname):
    '''Takes gradients with respect to a named argument.
    Doesn't work on *args or **kwargs.

    The name is looked up in the positional argument names of `fun`; a name
    that is not among them makes the list lookup raise ValueError.'''
    positional_names = _getargspec(fun).args
    return grad(fun, positional_names.index(argname))
|
|
|
@unary_to_nary
def hessian(fun, x):
    "Returns a function that computes the exact Hessian."
    # The Hessian is obtained as the Jacobian of the Jacobian.
    first_order = jacobian(fun)
    second_order = jacobian(first_order)
    return second_order(x)
|
|
|
@unary_to_nary
def make_hvp(fun, x):
    """Builds a function for evaluating the Hessian-vector product at a point,
    which may be useful when evaluating many Hessian-vector products at the same
    point while caching the results of the forward pass.

    The result is whatever `_make_vjp` returns for the gradient of `fun` at `x`."""
    gradient_of_fun = grad(fun)
    return _make_vjp(gradient_of_fun, x)
|
|
|
def hessian_tensor_product(fun, argnum=0):
    """Builds a function that returns the exact Hessian-tensor product.
    The returned function has arguments (*args, tensor, **kwargs), and for
    vectors takes roughly 4x as long to evaluate as the original function."""
    fun_grad = grad(fun, argnum)

    def vector_dot_grad(*args, **kwargs):
        # The trailing positional argument is the tensor; everything before
        # it is forwarded to the gradient of `fun`.
        fun_args, tensor = args[:-1], args[-1]
        gradient = fun_grad(*fun_args, **kwargs)
        return np.tensordot(gradient, tensor, np.ndim(tensor))

    # Differentiating <grad(fun), tensor> once more yields the product.
    return grad(vector_dot_grad, argnum)


hessian_vector_product = hessian_tensor_product
|
|
|
def tensor_jacobian_product(fun, argnum=0):
    """Builds a function that returns the exact tensor-Jacobian product, that
    is the Jacobian matrix left-multiplied by tensor. The returned function
    has arguments (*args, tensor, **kwargs)."""
    def vector_dot_fun(*args, **kwargs):
        # The trailing positional argument is the tensor; everything before
        # it is forwarded to `fun`.
        fun_args, tensor = args[:-1], args[-1]
        output = fun(*fun_args, **kwargs)
        return np.tensordot(tensor, output, axes=np.ndim(tensor))

    return jacobian(vector_dot_fun, argnum)


vector_jacobian_product = tensor_jacobian_product
|
|
|
@unary_to_nary
def make_jvp_reversemode(fun, x):
    """Builds a function for evaluating the Jacobian-vector product at a
    point. Roughly 1.5x more FLOPs than forward-mode, plus memory requirements
    that scale with the number of primitives applied in the evaluation of f, as
    well as other overheads. See j-towns.github.io/2017/06/12/A-new-trick.html."""
    vjp, y = _make_vjp(fun, x)
    # Differentiate the VJP itself, at a zero cotangent from the output space.
    zero_cotangent = vspace(y).zeros()
    jvp, _unused_value = _make_vjp(vjp, zero_cotangent)
    return jvp
|
|
|
|
|
def make_ggnvp(f, g=lambda x: 1./2*np.sum(x**2, axis=-1), f_argnum=0):
    """Builds a function for evaluating generalized-Gauss-Newton-vector products
    at a point. Slightly more expensive than mixed-mode."""
    @unary_to_nary
    def _make_ggnvp(f, x):
        f_vjp, f_x = _make_vjp(f, x)
        g_hvp, grad_g_x = _make_vjp(grad(g), f_x)
        # Differentiating the VJP of `f` at a zero cotangent gives its JVP.
        f_jvp, _ = _make_vjp(f_vjp, vspace(grad_g_x).zeros())

        def ggnvp(v):
            # Forward through f, through the Hessian of g, then back through f.
            pushed = f_jvp(v)
            curved = g_hvp(pushed)
            return f_vjp(curved)

        return ggnvp

    return _make_ggnvp(f, f_argnum)
|
|
|
@unary_to_nary
def value_and_grad(fun, x):
    """Returns a function that returns both value and gradient. Suitable for use
    in scipy.optimize

    Raises a TypeError when the vector space of the output of `fun` does not
    have size 1."""
    pullback, value = _make_vjp(fun, x)
    value_space = vspace(value)
    if value_space.size == 1:
        # Seed the backward pass with a cotangent of ones.
        gradient = pullback(value_space.ones())
        return value, gradient
    raise TypeError("value_and_grad only applies to real scalar-output "
                    "functions. Try jacobian, elementwise_grad or "
                    "holomorphic_grad.")
|
|
|
@unary_to_nary
def grad_and_aux(fun, x):
    """Builds a function that returns the gradient of the first output and the
    (unmodified) second output of a function that returns two outputs."""
    def fun_as_tuple(z):
        return atuple(fun(z))

    pullback, outputs = _make_vjp(fun_as_tuple, x)
    ans, aux = outputs
    # Ones for the differentiated output, zeros for the auxiliary output so
    # that it contributes nothing to the gradient.
    cotangent = (vspace(ans).ones(), vspace(aux).zeros())
    return pullback(cotangent), aux
|
|
|
def multigrad_dict(fun):
    """Takes gradients wrt all arguments simultaneously,
    returns a dict mapping 'argname' to 'gradval'.

    The returned function accepts the same arguments as `fun`. Arguments are
    bound to the signature of `fun`, omitted parameters are filled in from
    their defaults, and the gradient is taken with respect to the resulting
    name -> value dict. The result is an OrderedDict in signature order.

    An omitted `*args` parameter is treated as `()` and an omitted `**kwargs`
    parameter as `{}`, mirroring an ordinary call. Keyword-only parameters are
    forwarded to `fun` by keyword.
    """
    # Prefer the standard library; `funcsigs` is only needed as a backport on
    # interpreters whose `inspect` has no `signature`.
    try:
        from inspect import signature
    except ImportError:
        from funcsigs import signature
    sig = signature(fun)

    # Partition parameter names by how they must be passed back to `fun`.
    var_pos, var_kwd, keyword_only, positional = [], [], [], []
    for name, param in sig.parameters.items():
        if param.kind == param.VAR_POSITIONAL:
            var_pos.append(name)
        elif param.kind == param.VAR_KEYWORD:
            var_kwd.append(name)
        elif param.kind == param.KEYWORD_ONLY:
            keyword_only.append(name)
        else:
            positional.append(name)

    def apply_defaults(arguments):
        # `sig.bind` has already succeeded, so any parameter missing here is
        # either variadic (defaults to empty) or has a declared default.
        filled = OrderedDict()
        for name, param in sig.parameters.items():
            if name in arguments:
                filled[name] = arguments[name]
            elif param.kind == param.VAR_POSITIONAL:
                filled[name] = ()
            elif param.kind == param.VAR_KEYWORD:
                filled[name] = {}
            else:
                filled[name] = param.default
        return filled

    def gradfun(*args, **kwargs):
        bindings = sig.bind(*args, **kwargs)
        argdict = apply_defaults(bindings.arguments)

        def newfun(dct):
            # Rebuild the call to `fun` from the single dict being
            # differentiated.
            call_args = tuple(dct[name] for name in positional)
            if var_pos:
                call_args += tuple(dct[var_pos[0]])
            call_kwargs = {name: dct[name] for name in keyword_only}
            if var_kwd:
                extra = dct[var_kwd[0]]
                call_kwargs.update({key: extra[key] for key in extra})
            return fun(*call_args, **call_kwargs)

        grad_dict = grad(newfun)(dict(argdict))
        return OrderedDict((argname, grad_dict[argname]) for argname in argdict)

    return gradfun
|
|
|
def checkpoint(fun):
    """Returns a checkpointed version of `fun`, where intermediate values
    computed during the forward pass of `fun` are discarded and then recomputed
    for the backward pass. Useful to save memory, effectively trading off time
    and memory. See e.g. arxiv.org/abs/1604.06174.
    """
    checkpointed = primitive(fun)

    def wrapped_grad(argnum, ans, args, kwargs):
        # Re-run `fun` on the saved inputs and keep element 0 of what
        # `make_vjp` returns.
        vjp_result = make_vjp(fun, argnum)(*args, **kwargs)
        return vjp_result[0]

    defvjp_argnum(checkpointed, wrapped_grad)
    return checkpointed
|
|