|
from itertools import count |
|
from functools import reduce |
|
from .tracer import trace, primitive, toposort, Node, Box, isbox, getval |
|
from .util import func, subval |
|
|
|
|
|
|
|
def make_vjp(fun, x):
    """Trace ``fun`` at ``x`` and return ``(vjp, end_value)``.

    ``end_value`` is the value handed back by ``trace``. ``vjp`` is a
    one-argument function of ``g``:

    * when ``trace`` reports an end node, ``vjp(g)`` runs
      ``backward_pass(g, end_node)``;
    * when ``trace`` reports no end node (``None``), ``vjp(g)`` ignores ``g``
      and returns ``vspace(x).zeros()``.
    """
    root = VJPNode.new_root()
    end_value, end_node = trace(root, fun, x)

    if end_node is not None:
        def vjp(g):
            return backward_pass(g, end_node)
        return vjp, end_value

    # No end node: presumably the output does not depend on ``x``
    # (NOTE(review): confirm against ``trace``), so the result is a zero
    # element of the vector space of ``x``, computed lazily on each call.
    def vjp(g):
        return vspace(x).zeros()
    return vjp, end_value
|
|
|
def backward_pass(g, end_node):
    """Propagate ``g`` from ``end_node`` through the nodes given by ``toposort``.

    A dict maps each pending node to a ``(gradient, mutable_flag)`` pair as
    produced by ``add_outgrads``. Every visited node has its pair removed,
    its ``vjp`` applied to the gradient, and the results accumulated onto its
    parents. The gradient of the last node visited is returned.
    """
    pending = {end_node: (g, False)}
    for current in toposort(end_node):
        # Only the gradient part of the (gradient, flag) pair is needed here.
        accumulated = pending.pop(current)[0]
        parent_grads = current.vjp(accumulated)
        for parent, parent_grad in zip(current.parents, parent_grads):
            previous = pending.get(parent)
            pending[parent] = add_outgrads(previous, parent_grad)
    return accumulated
|
|
|
class VJPNode(Node):
    """Trace node that stores its parents and a VJP closure."""

    __slots__ = ['parents', 'vjp']

    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        """Look up the VJP maker registered for ``fun`` and build ``self.vjp``.

        Raises:
            NotImplementedError: if ``fun`` has no entry in ``primitive_vjps``.
        """
        self.parents = parents
        try:
            make_vjp_for_fun = primitive_vjps[fun]
        except KeyError:
            # Fall back to the object itself when it has no ``__name__``.
            label = getattr(fun, '__name__', fun)
            message = "VJP of {} wrt argnums {} not defined".format(
                label, parent_argnums)
            raise NotImplementedError(message)
        else:
            self.vjp = make_vjp_for_fun(parent_argnums, value, args, kwargs)

    def initialize_root(self):
        """Set up a root node: no parents and a VJP returning an empty tuple."""
        self.parents = []
        self.vjp = lambda g: ()
|
|
|
# Registry mapping a function to its VJP maker; read by VJPNode.__init__.
primitive_vjps = {}


def defvjp_argnums(fun, vjpmaker):
    """Register ``vjpmaker`` as the VJP maker for ``fun``.

    Any earlier registration for the same ``fun`` is overwritten.
    """
    primitive_vjps[fun] = vjpmaker
|
|
|
def defvjp_argnum(fun, vjpmaker):
    """Register a VJP for ``fun`` from a maker that takes one argnum at a time.

    ``vjpmaker(argnum, *args)`` is called once per requested argnum; the
    registered VJP applies each resulting function to ``g`` and yields the
    results lazily, in argnum order.
    """
    def vjp_argnums(argnums, *args):
        per_argnum = [vjpmaker(n, *args) for n in argnums]

        def apply_all(g):
            return (single_vjp(g) for single_vjp in per_argnum)
        return apply_all

    defvjp_argnums(fun, vjp_argnums)
|
|
|
def defvjp(fun, *vjpmakers, **kwargs):
    """Register per-argument VJP makers for ``fun``.

    Args:
        fun: the function whose VJP is being defined.
        *vjpmakers: one entry per argnum; each is passed through
            ``translate_vjp`` (so ``None`` and callables are accepted).
        **kwargs: may contain ``argnums``, an iterable giving the argument
            position for each entry of ``vjpmakers``. Defaults to
            ``0, 1, 2, ...``.

    The registered maker raises ``NotImplementedError`` when asked for an
    argnum that has no VJP defined, naming the argnum that is missing.
    """
    argnums = kwargs.get('argnums', count())
    vjps_dict = {argnum: translate_vjp(vjpmaker, fun, argnum)
                 for argnum, vjpmaker in zip(argnums, vjpmakers)}
    # Same fallback as VJPNode: not every callable has a ``__name__``.
    fun_name = getattr(fun, '__name__', fun)

    def lookup(argnum):
        """Return the VJP maker for ``argnum`` or raise NotImplementedError."""
        try:
            return vjps_dict[argnum]
        except KeyError:
            raise NotImplementedError(
                "VJP of {} wrt argnum {} not defined".format(fun_name, argnum))

    def vjp_argnums(argnums, ans, args, kwargs):
        L = len(argnums)
        # The one- and two-argnum cases are unrolled so they return plain
        # tuples; the general case returns a generator.
        if L == 1:
            vjp = lookup(argnums[0])(ans, *args, **kwargs)
            return lambda g: (vjp(g),)
        elif L == 2:
            argnum_0, argnum_1 = argnums
            # Resolve both makers before calling either one.
            vjp_0_fun = lookup(argnum_0)
            vjp_1_fun = lookup(argnum_1)
            vjp_0 = vjp_0_fun(ans, *args, **kwargs)
            vjp_1 = vjp_1_fun(ans, *args, **kwargs)
            return lambda g: (vjp_0(g), vjp_1(g))
        else:
            vjps = [lookup(argnum)(ans, *args, **kwargs) for argnum in argnums]
            return lambda g: (vjp(g) for vjp in vjps)

    defvjp_argnums(fun, vjp_argnums)
|
|
|
def translate_vjp(vjpfun, fun, argnum):
    """Normalise a user-supplied VJP spec into a VJP maker.

    * ``None`` becomes a maker whose VJP ignores ``g`` and returns
      ``vspace(args[argnum]).zeros()``.
    * A callable is returned unchanged.
    * Anything else raises ``Exception``.
    """
    if vjpfun is None:
        def zero_vjp_maker(ans, *args, **kwargs):
            def zero_vjp(g):
                return vspace(args[argnum]).zeros()
            return zero_vjp
        return zero_vjp_maker

    if callable(vjpfun):
        return vjpfun

    raise Exception("Bad VJP '{}' for '{}'".format(vjpfun, fun.__name__))
|
|
|
|
|
|
|
def make_jvp(fun, x):
    """Return a function ``jvp(g)`` that traces ``fun`` at ``x``.

    Each call to ``jvp(g)`` starts a fresh trace from a ``JVPNode`` root
    seeded with ``g`` and returns ``(end_value, tangent)``. The tangent is
    ``end_node.g`` when the trace yields an end node, and
    ``vspace(end_value).zeros()`` when it yields ``None``.
    """
    def jvp(g):
        root = JVPNode.new_root(g)
        end_value, end_node = trace(root, fun, x)
        if end_node is not None:
            return end_value, end_node.g
        return end_value, vspace(end_value).zeros()

    return jvp
|
|
|
class JVPNode(Node):
    """Trace node that carries a single value ``g``."""

    __slots__ = ['g']

    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        """Compute ``self.g`` from the parents' ``g`` via the JVP maker of ``fun``.

        Raises:
            NotImplementedError: if ``fun`` has no entry in ``primitive_jvps``.
        """
        # Gather the parents' values first, before the registry lookup.
        parent_gs = [p.g for p in parents]
        try:
            make_jvp_for_fun = primitive_jvps[fun]
        except KeyError:
            # Fall back to the object itself when it has no ``__name__``.
            label = getattr(fun, '__name__', fun)
            message = "JVP of {} wrt argnums {} not defined".format(
                label, parent_argnums)
            raise NotImplementedError(message)
        else:
            self.g = make_jvp_for_fun(
                parent_argnums, parent_gs, value, args, kwargs)

    def initialize_root(self, g):
        """Seed a root node with ``g``."""
        self.g = g
|
|
|
# Registry mapping a function to its JVP maker; read by JVPNode.__init__.
primitive_jvps = {}


def defjvp_argnums(fun, jvpmaker):
    """Register ``jvpmaker`` as the JVP maker for ``fun``.

    Any earlier registration for the same ``fun`` is overwritten.
    """
    primitive_jvps[fun] = jvpmaker
|
|
|
def defjvp_argnum(fun, jvpmaker):
    """Register a JVP for ``fun`` from a maker that takes one argnum at a time.

    ``jvpmaker(argnum, g, ans, args, kwargs)`` is evaluated for every
    ``(argnum, g)`` pair and the results are combined with ``sum_outgrads``.
    """
    def jvp_argnums(argnums, gs, ans, args, kwargs):
        contributions = (jvpmaker(n, tangent, ans, args, kwargs)
                         for n, tangent in zip(argnums, gs))
        return sum_outgrads(contributions)

    defjvp_argnums(fun, jvp_argnums)
|
|
|
def defjvp(fun, *jvpfuns, **kwargs):
    """Register per-argument JVP functions for ``fun``.

    Each entry of ``jvpfuns`` is passed through ``translate_jvp`` and stored
    under its argnum. Argnums come from ``kwargs['argnums']`` when given,
    otherwise they are ``0, 1, 2, ...``. The registered JVP calls
    ``jvp(g, ans, *args, **kwargs)`` for each requested argnum and combines
    the results with ``sum_outgrads``.
    """
    positions = kwargs.get('argnums', count())
    jvps_dict = {}
    for position, jvpfun in zip(positions, jvpfuns):
        jvps_dict[position] = translate_jvp(jvpfun, fun, position)

    def jvp_argnums(argnums, gs, ans, args, kwargs):
        contributions = (jvps_dict[n](tangent, ans, *args, **kwargs)
                         for n, tangent in zip(argnums, gs))
        return sum_outgrads(contributions)

    defjvp_argnums(fun, jvp_argnums)
|
|
|
def translate_jvp(jvpfun, fun, argnum):
    """Normalise a user-supplied JVP spec into a JVP function.

    * ``None`` becomes a function returning ``vspace(ans).zeros()``.
    * ``'same'`` becomes a function that calls ``fun`` again with argument
      ``argnum`` replaced by ``g`` (via ``subval``).
    * A callable is returned unchanged.
    * Anything else raises ``Exception``.
    """
    if jvpfun is None:
        def zero_jvp(g, ans, *a, **k):
            return vspace(ans).zeros()
        return zero_jvp

    if jvpfun == 'same':
        def same_jvp(g, ans, *args, **kwargs):
            return fun(*subval(args, argnum, g), **kwargs)
        return same_jvp

    if callable(jvpfun):
        return jvpfun

    raise Exception("Bad JVP '{}' for '{}'".format(jvpfun, fun.__name__))
|
|
|
def def_linear(fun):
    """Mark ``fun`` as linear in every argument.

    The JVP registered for each argnum re-applies ``fun`` with that argument
    replaced by ``g`` (via ``subval``).
    """
    def linear_jvp(argnum, g, ans, args, kwargs):
        return fun(*subval(args, argnum, g), **kwargs)

    defjvp_argnum(fun, linear_jvp)
|
|
|
|
|
|
|
def add_outgrads(prev_g_flagged, g):
    """Accumulate ``g`` onto a previously accumulated, flagged gradient.

    Args:
        prev_g_flagged: ``None`` (nothing accumulated yet) or a pair
            ``(prev_g, mutable)``.
        g: the new gradient contribution.

    Returns:
        A pair ``(sum, mutable)``. The flag is ``False`` only when ``g`` is
        the first, non-sparse contribution and is returned as-is.
    """
    is_sparse = type(g) in sparse_object_types

    # First contribution: nothing to add to.
    if not prev_g_flagged:
        if is_sparse:
            return sparse_add(vspace(g), None, g), True
        return g, False

    vs = vspace(g)
    prev_g, mutable = prev_g_flagged

    if is_sparse:
        if not mutable:
            # Run the unflagged previous value through mut_add(None, ...)
            # first, and hand that result to sparse_add instead.
            prev_g = vs.mut_add(None, prev_g)
        return sparse_add(vs, prev_g, g), True

    if mutable:
        return vs.mut_add(prev_g, g), True
    return vs.add(prev_g, g), True
|
|
|
def sum_outgrads(gs):
    """Fold ``gs`` together with ``add_outgrads`` and return the summed value.

    The fold starts from ``None``; the mutability flag of the final pair is
    discarded.
    """
    flagged_total = None
    for contribution in gs:
        flagged_total = add_outgrads(flagged_total, contribution)
    return flagged_total[0]
|
|
|
@primitive
def sparse_add(vs, x_prev, x_new):
    """Add the sparse object ``x_new`` onto ``x_prev`` via ``x_new.mut_add``.

    When ``x_prev`` is ``None`` it is replaced by ``vs.zeros()`` first.
    """
    if x_prev is None:
        x_prev = vs.zeros()
    return x_new.mut_add(x_prev)
|
|
|
class VSpace(object):
    """Base class describing the vector space a value lives in.

    ``mappings`` is a class-level registry from value type to a callable
    that builds the vector space for a value of that type (see ``register``).
    """

    __slots__ = []
    mappings = {}
    iscomplex = False

    def __init__(self, value):
        pass

    # Constructors: the base class provides none of these.
    def zeros(self):
        assert False, repr(self)

    def ones(self):
        assert False, repr(self)

    def standard_basis(self):
        assert False, repr(self)

    def randn(self):
        assert False, repr(self)

    # Public operations, wrapped with ``primitive``; each forwards to the
    # underscore-prefixed implementation below.
    @primitive
    def mut_add(self, x_prev, x_new):
        """Add ``x_new`` onto ``x_prev`` via ``_mut_add``; ``None`` means zeros."""
        if x_prev is None:
            x_prev = self.zeros()
        return self._mut_add(x_prev, x_new)

    @primitive
    def add(self, x_prev, x_new):
        return self._add(x_prev, x_new)

    @primitive
    def scalar_mul(self, x, a):
        return self._scalar_mul(x, a)

    @primitive
    def inner_prod(self, x, y):
        return self._inner_prod(x, y)

    @primitive
    def covector(self, x):
        return self._covector(x)

    # Default implementations.
    def _add(self, x, y):
        return x + y

    def _mut_add(self, x, y):
        x += y
        return x

    def _scalar_mul(self, x, a):
        return x * a

    def _inner_prod(self, x, y):
        assert False

    def _covector(self, x):
        return x

    def __eq__(self, other):
        """Equal when the types match exactly and the instance dicts agree."""
        if type(self) == type(other):
            return self.__dict__ == other.__dict__
        return False

    def __repr__(self):
        class_name = type(self).__name__
        return "{}_{}".format(class_name, self.__dict__)

    @classmethod
    def register(cls, value_type, vspace_maker=None):
        """Map ``value_type`` to ``vspace_maker``, or to ``cls`` if none given."""
        VSpace.mappings[value_type] = vspace_maker if vspace_maker else cls
|
|
|
def vspace(value):
    """Return the vector space registered for ``type(value)``.

    If the type is not registered and ``value`` is a box, the lookup is
    retried on the unboxed value (``getval(value)``).

    Raises:
        TypeError: if no vector space is registered for the value's type and
            the value is not a box.
    """
    # Only the registry lookup is guarded. A KeyError raised inside the
    # vspace maker itself must propagate rather than be mistaken for an
    # unregistered type.
    try:
        vspace_maker = VSpace.mappings[type(value)]
    except KeyError:
        if isbox(value):
            return vspace(getval(value))
        else:
            raise TypeError("Can't find vector space for value {} of type {}. "
                            "Valid types are {}".format(
                                value, type(value), VSpace.mappings.keys()))
    return vspace_maker(value)
|
|
|
class SparseBox(Box):
    """Box type registered for ``SparseObject`` values."""

    __slots__ = []


class SparseObject(object):
    """Pairs a vector space ``vs`` with a ``mut_add`` callable.

    ``sparse_add`` calls ``mut_add`` with the previously accumulated value.
    """

    __slots__ = ['vs', 'mut_add']

    def __init__(self, vs, mut_add):
        self.vs, self.mut_add = vs, mut_add


# A sparse object's vector space is the one it carries.
VSpace.register(SparseObject, lambda sparse: sparse.vs)
SparseBox.register(SparseObject)

# Types that add_outgrads routes through sparse_add.
sparse_object_types = {SparseObject, SparseBox}
|
|
|
|
|
|
|
def identity_vjp(argnums, *args):
    """VJP maker whose VJP returns ``g`` unchanged.

    defvjp invokes makers as ``maker(ans, *args, **kwargs)``, so the first
    positional slot receives ``ans`` despite the parameter's name.
    """
    return lambda g: g


# In every registration below the leading ``None`` covers argument 0, the
# vector space itself.
defvjp(sparse_add, None, identity_vjp, identity_vjp)
defvjp(func(VSpace.add), None, identity_vjp, identity_vjp)
defvjp(func(VSpace.mut_add), None, identity_vjp, identity_vjp)

defvjp(
    func(VSpace.inner_prod),
    None,
    lambda ans, vs, x, y: lambda g: vs.covector(vs.scalar_mul(y, g)),
    lambda ans, vs, x, y: lambda g: vs.covector(vs.scalar_mul(x, g)),
)

defvjp(
    func(VSpace.covector),
    None,
    lambda ans, vs, x: lambda g: vs.covector(g),
)

defvjp(
    func(VSpace.scalar_mul),
    None,
    lambda ans, vs, x, a: lambda g: vs.covector(vs.scalar_mul(vs.covector(g), a)),
    lambda ans, vs, x, a: lambda g: vs.inner_prod(g, vs.covector(x)),
)
|
|
|
|
|
|
|
def identity_jvp(g, *args, **kwargs):
    """JVP that returns ``g`` unchanged."""
    return g


# In every registration below the leading ``None`` covers argument 0, the
# vector space itself.
defjvp(sparse_add, None, identity_jvp, identity_jvp)
defjvp(func(VSpace.mut_add), None, identity_jvp, identity_jvp)
defjvp(func(VSpace.add), None, identity_jvp, identity_jvp)

# 'same' re-applies the function with the argument replaced by ``g``.
defjvp(func(VSpace.scalar_mul), None, 'same', 'same')
defjvp(func(VSpace.inner_prod), None, 'same', 'same')
defjvp(func(VSpace.covector), None, 'same')
|
|
|
|
|
|
|
import warnings |
|
# Template for the deprecation warnings below; ``{}`` receives the name of
# the deprecated method. The text starts with a newline.
deprecated_defvjp_message = (
    "\n"
    "The {} method is deprecated. See the update guide and tutorial:\n"
    "https://github.com/HIPS/autograd/blob/master/docs/updateguide.md\n"
    "https://github.com/HIPS/autograd/blob/master/docs/tutorial.md"
)
|
|
|
def deprecated_defvjp(primitive_fun):
    """Build the legacy ``defvjp`` method attached to ``primitive_fun``.

    The returned function warns on every call, wraps the old-style
    ``vjpmaker(g, ans, vs, gvs, *args, **kwargs)`` into the staged form that
    ``defvjp`` expects, and re-registers all VJPs collected so far for this
    primitive.
    """
    deprecation_msg = deprecated_defvjp_message.format('defvjp')
    vjpfuns = {}

    def defvjp_unstaged(vjpmaker, argnum=0):
        warnings.warn(deprecation_msg)

        def staged_vjpmaker(ans, *args, **kwargs):
            def vjp(g):
                vs = vspace(args[argnum])
                gvs = vspace(g)
                return vjpmaker(g, ans, vs, gvs, *args, **kwargs)
            return vjp

        vjpfuns[argnum] = staged_vjpmaker

        # Re-register everything seen so far, in ascending argnum order.
        ordered_argnums = tuple(sorted(vjpfuns))
        ordered_makers = [vjpfuns[n] for n in ordered_argnums]
        defvjp(primitive_fun, *ordered_makers, argnums=ordered_argnums)

    return defvjp_unstaged
|
|
|
def deprecated_defvjp_is_zero(primitive_fun):
    """Build the legacy ``defvjp_is_zero`` method attached to ``primitive_fun``.

    The returned function warns on every call, records the given argnums,
    and re-registers a ``None`` VJP for every argnum recorded so far.
    """
    deprecation_msg = deprecated_defvjp_message.format('defvjp_is_zero')
    zero_argnums = set()

    def defvjp_is_zero(argnums=(0,)):
        warnings.warn(deprecation_msg)
        # Build the new set first so a bad ``argnums`` leaves the record intact.
        zero_argnums.update(set(argnums))
        placeholders = [None] * len(zero_argnums)
        defvjp(primitive_fun, *placeholders, argnums=sorted(zero_argnums))

    return defvjp_is_zero
|
|
|
def deprecated_defgrad(primitive_fun):
    """Build the legacy ``defgrad`` method attached to ``primitive_fun``.

    The returned function warns on every call, stores ``gradfun`` under its
    argnum, and re-registers all gradient functions collected so far through
    ``defvjp``.
    """
    deprecation_msg = deprecated_defvjp_message.format('defgrad')
    gradfuns = {}

    def defgrad(gradfun, argnum=0):
        warnings.warn(deprecation_msg)
        gradfuns[argnum] = gradfun

        # Re-register everything seen so far, in ascending argnum order.
        ordered_argnums = tuple(sorted(gradfuns))
        ordered_makers = [gradfuns[n] for n in ordered_argnums]
        defvjp(primitive_fun, *ordered_makers, argnums=ordered_argnums)

    return defgrad
|
|
|
# Keep a handle on the primitive wrapper in scope before ``primitive`` is
# rebound at the bottom of this section.
primitive_ = primitive


def primitive_with_deprecation_warnings(f_raw):
    """Wrap ``f_raw`` with ``primitive_`` and attach the deprecated methods.

    The wrapped function gains ``defvjp``, ``defvjp_is_zero`` and ``defgrad``
    attributes, each of which warns when used.
    """
    f_wrapped = primitive_(f_raw)
    legacy_hooks = (
        ('defvjp', deprecated_defvjp),
        ('defvjp_is_zero', deprecated_defvjp_is_zero),
        ('defgrad', deprecated_defgrad),
    )
    for attr_name, make_hook in legacy_hooks:
        setattr(f_wrapped, attr_name, make_hook(f_wrapped))
    return f_wrapped


# From here on, ``primitive`` in this module is the deprecation-aware wrapper.
primitive = primitive_with_deprecation_warnings
|
|