"""Convenience functions built on top of `make_vjp`."""
from __future__ import absolute_import
from functools import partial
from collections import OrderedDict
try:
    from inspect import getfullargspec as _getargspec  # Python 3
except ImportError:
    from inspect import getargspec as _getargspec      # Python 2
import warnings
from .wrap_util import unary_to_nary
from .builtins import tuple as atuple
from .core import make_vjp as _make_vjp, make_jvp as _make_jvp
from .extend import primitive, defvjp_argnum, vspace
import autograd.numpy as np
make_vjp = unary_to_nary(_make_vjp)
make_jvp = unary_to_nary(_make_jvp)
@unary_to_nary
def grad(fun, x):
    """
    Returns a function which computes the gradient of `fun` with respect to
    positional argument number `argnum`. The returned function takes the same
    arguments as `fun`, but returns the gradient instead. The function `fun`
    should be scalar-valued. The gradient has the same type as the argument."""
    vjp, ans = _make_vjp(fun, x)
    if not vspace(ans).size == 1:
        raise TypeError("Grad only applies to real scalar-output functions. "
                        "Try jacobian, elementwise_grad or holomorphic_grad.")
    return vjp(vspace(ans).ones())
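# Usage sketch for `grad` (kept as a comment so importing this module is
# unaffected; the function and input values below are made up for illustration):
#
#     def loss(w):
#         return np.sum(np.tanh(w) ** 2)
#
#     grad_loss = grad(loss)                   # same signature as `loss`
#     grad_loss(np.array([0.1, 0.2, 0.3]))     # gradient, same shape as the input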
@unary_to_nary
def elementwise_grad(fun, x):
    """
    Returns a function that computes the sum of each column of the Jacobian of
    `fun`, in one pass. If the Jacobian is diagonal, then this is the diagonal
    of the Jacobian.
    """
    vjp, ans = _make_vjp(fun, x)
    if vspace(ans).iscomplex:
        raise TypeError("Elementwise_grad only applies to real-output functions.")
    return vjp(vspace(ans).ones())
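# Usage sketch for `elementwise_grad` (comment only; illustrative values):
#
#     d_tanh = elementwise_grad(np.tanh)       # diagonal of the Jacobian of tanh
#     d_tanh(np.linspace(-2.0, 2.0, 5))        # equals 1 - np.tanh(x) ** 2 elementwise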
@unary_to_nary
def deriv(fun, x):
    """Returns a function that computes the forward-mode derivative of `fun`:
    the Jacobian-vector product with an all-ones tangent vector. If the
    Jacobian is diagonal, this is the diagonal of the Jacobian."""
    return _make_jvp(fun, x)(vspace(x).ones())[1]
@unary_to_nary
def jacobian(fun, x):
    """
    Returns a function which computes the Jacobian of `fun` with respect to
    positional argument number `argnum`, which must be a scalar or array. Unlike
    `grad` it is not restricted to scalar-output functions, but also it cannot
    take derivatives with respect to some argument types (like lists or dicts).
    If the input to `fun` has shape (in1, in2, ...) and the output has shape
    (out1, out2, ...) then the Jacobian has shape (out1, out2, ..., in1, in2, ...).
    """
    vjp, ans = _make_vjp(fun, x)
    ans_vspace = vspace(ans)
    jacobian_shape = ans_vspace.shape + vspace(x).shape
    grads = map(vjp, ans_vspace.standard_basis())
    return np.reshape(np.stack(grads), jacobian_shape)
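# Usage sketch for `jacobian` (comment only; illustrative function and values):
#
#     A = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
#     def f(x):
#         return np.sin(np.dot(A, x))          # maps R^2 -> R^3
#
#     jacobian(f)(np.array([1.0, 2.0]))        # shape (3, 2): output dims then input dims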
@unary_to_nary
def holomorphic_grad(fun, x):
    """Gradient intended for holomorphic functions of a complex argument,
    computed as the gradient of the real part of `fun`."""
    if not vspace(x).iscomplex:
        warnings.warn("Input to holomorphic_grad is not complex")
    return grad(lambda x: np.real(fun(x)))(x)
def grad_named(fun, argname):
    '''Takes gradients with respect to a named argument.
    Doesn't work on *args or **kwargs.'''
    arg_index = _getargspec(fun).args.index(argname)
    return grad(fun, arg_index)
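# Usage sketch for `grad_named` (comment only; illustrative function and values):
#
#     def f(x, y):
#         return np.sum(x * np.sin(y))
#
#     dfdy = grad_named(f, 'y')
#     dfdy(np.ones(3), np.zeros(3))            # gradient with respect to `y`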
@unary_to_nary
def hessian(fun, x):
    "Returns a function that computes the exact Hessian."
    return jacobian(jacobian(fun))(x)
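# Usage sketch for `hessian` (comment only; illustrative values):
#
#     def f(x):
#         return np.sum(x ** 3)
#
#     hessian(f)(np.array([1.0, 2.0]))         # shape (2, 2), diagonal equal to 6 * x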
@unary_to_nary
def make_hvp(fun, x):
    """Builds a function for evaluating the Hessian-vector product at a point,
    which may be useful when evaluating many Hessian-vector products at the same
    point while caching the results of the forward pass."""
    return _make_vjp(grad(fun), x)
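# Usage sketch for `make_hvp` (comment only; illustrative values). It returns a
# pair: the product function and the gradient at the evaluation point:
#
#     def f(x):
#         return np.sum(np.sin(x) * x)
#
#     x0 = np.linspace(0.0, 1.0, 4)
#     hvp, grad_x0 = make_hvp(f)(x0)           # forward pass over grad(f) happens once
#     hvp(np.ones(4))                          # Hessian-vector product at x0
#     hvp(np.arange(4.0))                      # reuses the cached forward pass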
def hessian_tensor_product(fun, argnum=0):
    """Builds a function that returns the exact Hessian-tensor product.
    The returned function has arguments (*args, tensor, **kwargs), and for
    vectors takes roughly 4x as long to evaluate as the original function."""
    fun_grad = grad(fun, argnum)
    def vector_dot_grad(*args, **kwargs):
        args, vector = args[:-1], args[-1]
        return np.tensordot(fun_grad(*args, **kwargs), vector, np.ndim(vector))
    return grad(vector_dot_grad, argnum)
hessian_vector_product = hessian_tensor_product
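# Usage sketch for `hessian_vector_product` (comment only; illustrative values).
# The vector is passed as an extra trailing positional argument:
#
#     def f(x):
#         return np.sum(x ** 2 * np.sin(x))
#
#     hvp = hessian_vector_product(f)
#     hvp(np.ones(3), np.array([1.0, 0.0, 0.0]))   # H(x) dotted with the vector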
def tensor_jacobian_product(fun, argnum=0):
    """Builds a function that returns the exact tensor-Jacobian product, that
    is the Jacobian matrix left-multiplied by tensor. The returned function
    has arguments (*args, tensor, **kwargs)."""
    def vector_dot_fun(*args, **kwargs):
        args, vector = args[:-1], args[-1]
        return np.tensordot(vector, fun(*args, **kwargs), axes=np.ndim(vector))
    return jacobian(vector_dot_fun, argnum)
vector_jacobian_product = tensor_jacobian_product
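# Usage sketch for `vector_jacobian_product` (comment only; illustrative values):
#
#     def f(x):
#         return np.sin(x)                     # elementwise, so the Jacobian is diagonal
#
#     vjp_fun = vector_jacobian_product(f)
#     vjp_fun(np.zeros(3), np.ones(3))         # vector times J(x), here cos(x)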
@unary_to_nary
def make_jvp_reversemode(fun, x):
    """Builds a function for evaluating the Jacobian-vector product at a
    point. Roughly 1.5x more FLOPs than forward-mode, plus memory requirements
    that scale with the number of primitives applied in the evaluation of f, as
    well as other overheads. See j-towns.github.io/2017/06/12/A-new-trick.html."""
    vjp, y = _make_vjp(fun, x)
    vjp_vjp, _ = _make_vjp(vjp, vspace(y).zeros())
    return vjp_vjp  # vjp_vjp is just jvp by linearity
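# Usage sketch for `make_jvp_reversemode` (comment only; illustrative values):
#
#     def f(x):
#         return np.sin(x) * x
#
#     jvp = make_jvp_reversemode(f)(np.ones(3))
#     jvp(np.array([1.0, 0.0, 0.0]))           # J(x) dotted with the tangent vector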
# TODO(mattjj): update this function using make_jvp and const_graph
def make_ggnvp(f, g=lambda x: 1./2*np.sum(x**2, axis=-1), f_argnum=0):
    """Builds a function for evaluating generalized-Gauss-Newton-vector products
    at a point. Slightly more expensive than mixed-mode."""
    @unary_to_nary
    def _make_ggnvp(f, x):
        f_vjp, f_x = _make_vjp(f, x)
        g_hvp, grad_g_x = _make_vjp(grad(g), f_x)
        f_jvp, _ = _make_vjp(f_vjp, vspace(grad_g_x).zeros())
        def ggnvp(v): return f_vjp(g_hvp(f_jvp(v)))
        return ggnvp
    return _make_ggnvp(f, f_argnum)
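# Usage sketch for `make_ggnvp` (comment only; illustrative values). With the
# default g, the generalized Gauss-Newton matrix reduces to J^T J:
#
#     W = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
#     def f(x):
#         return np.tanh(np.dot(W, x))         # maps R^2 -> R^3
#
#     ggnvp = make_ggnvp(f)(np.array([0.1, -0.2]))
#     ggnvp(np.array([1.0, 0.0]))              # J^T J dotted with the vector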
@unary_to_nary
def value_and_grad(fun, x):
    """Returns a function that returns both value and gradient. Suitable for use
    in scipy.optimize"""
    vjp, ans = _make_vjp(fun, x)
    if not vspace(ans).size == 1:
        raise TypeError("value_and_grad only applies to real scalar-output "
                        "functions. Try jacobian, elementwise_grad or "
                        "holomorphic_grad.")
    return ans, vjp(vspace(ans).ones())
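# Usage sketch for `value_and_grad` with scipy.optimize (comment only; assumes
# scipy is installed, which this module does not require):
#
#     from scipy.optimize import minimize
#
#     def loss(w):
#         return np.sum((w - 3.0) ** 2)
#
#     result = minimize(value_and_grad(loss), x0=np.zeros(3), jac=True)
#     result.x                                 # approximately 3.0 in every coordinate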
@unary_to_nary
def grad_and_aux(fun, x):
    """Builds a function that returns the gradient of the first output and the
    (unmodified) second output of a function that returns two outputs."""
    vjp, (ans, aux) = _make_vjp(lambda x: atuple(fun(x)), x)
    return vjp((vspace(ans).ones(), vspace(aux).zeros())), aux
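# Usage sketch for `grad_and_aux` (comment only; illustrative values):
#
#     def loss_with_predictions(w):
#         preds = np.tanh(w)
#         return np.sum(preds ** 2), preds     # (scalar loss, auxiliary output)
#
#     g, preds = grad_and_aux(loss_with_predictions)(np.array([0.1, -0.2]))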
def multigrad_dict(fun):
    """Takes gradients with respect to all arguments simultaneously and
    returns a dict mapping 'argname' to 'gradval'."""
    import funcsigs
    sig = funcsigs.signature(fun)

    def select(preds, lst):
        idx = lambda item: next(
            (i for i, pred in enumerate(preds) if pred(item)), len(preds))
        results = [[] for _ in preds] + [[]]
        for item in lst:
            results[idx(item)].append(item)
        return results

    is_var_pos = lambda name: sig.parameters[name].kind == sig.parameters[name].VAR_POSITIONAL
    is_var_kwd = lambda name: sig.parameters[name].kind == sig.parameters[name].VAR_KEYWORD
    var_pos, var_kwd, argnames = select([is_var_pos, is_var_kwd], sig.parameters)

    todict = lambda dct: {key: dct[key] for key in dct}

    def apply_defaults(arguments):
        defaults = {name: param.default for name, param in sig.parameters.items()
                    if param.default is not param.empty}
        return OrderedDict((name, arguments[name] if name in arguments else defaults[name])
                           for name in sig.parameters)

    def gradfun(*args, **kwargs):
        bindings = sig.bind(*args, **kwargs)

        args = lambda dct: tuple(dct[var_pos[0]]) if var_pos else ()
        kwargs = lambda dct: todict(dct[var_kwd[0]]) if var_kwd else {}
        others = lambda dct: tuple(dct[argname] for argname in argnames
                                   if argname not in var_kwd + var_pos)

        newfun = lambda dct: fun(*(others(dct) + args(dct)), **kwargs(dct))

        argdict = apply_defaults(bindings.arguments)
        grad_dict = grad(newfun)(dict(argdict))
        return OrderedDict((argname, grad_dict[argname]) for argname in argdict)

    return gradfun
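# Usage sketch for `multigrad_dict` (comment only; requires the `funcsigs`
# package, which is imported above only when this function is used):
#
#     def f(x, y):
#         return np.sum(x * y ** 2)
#
#     grads = multigrad_dict(f)(np.ones(3), 2.0 * np.ones(3))
#     grads['x'], grads['y']                   # gradients keyed by argument name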
def checkpoint(fun):
    """Returns a checkpointed version of `fun`, where intermediate values
    computed during the forward pass of `fun` are discarded and then recomputed
    for the backward pass. Useful to save memory, effectively trading off time
    and memory. See e.g. arxiv.org/abs/1604.06174.
    """
    def wrapped_grad(argnum, ans, args, kwargs):
        return make_vjp(fun, argnum)(*args, **kwargs)[0]
    wrapped = primitive(fun)
    defvjp_argnum(wrapped, wrapped_grad)
    return wrapped
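# Usage sketch for `checkpoint` (comment only; illustrative values):
#
#     def expensive_block(x):
#         for _ in range(10):
#             x = np.sin(x) + x                # intermediates normally kept for backprop
#         return x
#
#     cheap_block = checkpoint(expensive_block)
#
#     def loss(x):
#         return np.sum(cheap_block(x) ** 2)
#
#     grad(loss)(np.ones(100))                 # block intermediates are recomputed instead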