File size: 6,260 Bytes
ab4488b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
from .util import subvals
from .extend import (Box, primitive, notrace_primitive, VSpace, vspace,
SparseObject, defvjp, defvjp_argnum, defjvp, defjvp_argnum)
# Keep handles on the genuine builtins before this module shadows them.
isinstance_, type_ = isinstance, type
tuple_, list_, dict_ = tuple, list, dict

# Rebind module-level ``isinstance``/``type`` to non-traced primitive wrappers
# (presumably so they look through Boxes to the wrapped value -- confirm
# against ``notrace_primitive`` in .extend).
isinstance = notrace_primitive(isinstance_)
type = notrace_primitive(type_)
@primitive
def container_take(A, idx):
    """Differentiable ``A[idx]`` for tuples, lists and dicts.

    ``idx`` is whatever the container's ``__getitem__`` accepts: an integer,
    a slice, or a dict key.
    """
    return A[idx]


def grad_container_take(ans, A, idx):
    """Build the VJP of ``container_take``.

    The cotangent ``g`` of ``A[idx]`` is scattered back into a container
    shaped like ``A`` by ``container_untake``.
    """
    def vjp(g):
        return container_untake(g, idx, vspace(A))
    return vjp


defvjp(container_take, grad_container_take)
# Forward mode registered as 'same' (presumably: apply container_take to the
# tangent with the same idx -- confirm against defjvp in .extend).
defjvp(container_take, 'same')
class SequenceBox(Box):
    """Box around a tuple or list whose indexing and ``+`` are traced."""
    __slots__ = []

    # Indexing is routed through the differentiable container_take primitive.
    __getitem__ = container_take

    def __len__(self):
        return len(self._value)

    def __contains__(self, elt):
        return elt in self._value

    def index(self, elt):
        return self._value.index(elt)

    def __add__(self, other):
        # box + other: the boxed items first, then the items of ``other``.
        return sequence_extend_right(self, *other)

    def __radd__(self, other):
        # other + box: the items of ``other`` first, then the boxed items.
        return sequence_extend_left(self, *other)


SequenceBox.register(tuple_)
SequenceBox.register(list_)
class DictBox(Box):
    """Box around a dict; every value lookup goes through container_take."""
    __slots__ = []
    __getitem__ = container_take

    def __len__(self):
        return len(self._value)

    def __iter__(self):
        return iter(self._value)

    def __contains__(self, elt):
        return elt in self._value

    # Lazy accessors: each value is fetched via self[k], i.e. container_take.
    def iterkeys(self):
        return iter(self)

    def itervalues(self):
        return (self[k] for k in self)

    def iteritems(self):
        return ((k, self[k]) for k in self)

    # Eager accessors materialise the lazy ones with this module's ``list``
    # class (resolved at call time), not the builtin list.
    def keys(self):
        return list(self.iterkeys())

    def values(self):
        return list(self.itervalues())

    def items(self):
        return list(self.iteritems())

    def get(self, k, d=None):
        if k in self:
            return self[k]
        return d


DictBox.register(dict_)
@primitive
def container_untake(x, idx, vs):
    """Scatter ``x`` into slot ``idx`` of a container described by vspace ``vs``.

    Inverse of ``container_take``. Returns a ``SparseObject`` over ``vs``
    whose mutating-add callback adds ``x`` into ``A[idx]`` of an existing
    container ``A``; when ``idx`` is a slice the addition is element-wise.
    """
    idx_is_slice = isinstance(idx, slice)

    def accumulate(current):
        if idx_is_slice:
            # One element vspace per sliced position; add pairwise.
            return [elt_vs._mut_add(a, b)
                    for elt_vs, a, b in zip(vs.shape[idx], current, x)]
        return vs.shape[idx]._mut_add(current, x)

    def mut_add(A):
        return vs._subval(A, idx, accumulate(A[idx]))

    return SparseObject(vs, mut_add)


# VJP of untake is take: read slot ``idx`` back out of the cotangent.
defvjp(container_untake,
       lambda ans, x, idx, _: lambda g: container_take(g, idx))
defjvp(container_untake, 'same')
@primitive
def sequence_extend_right(seq, *elts):
    """Return ``seq`` followed by ``elts``, keeping the type of ``seq``."""
    return seq + type(seq)(elts)


def grad_sequence_extend_right(argnum, ans, args, kwargs):
    """VJP maker for ``sequence_extend_right``.

    The output holds ``seq`` first and then the extra elements, so argument
    ``argnum >= 1`` lands at output position ``len(seq) + argnum - 1``.
    """
    seq = args[0]

    def vjp(g):
        if argnum == 0:
            return g[:len(seq)]
        return g[len(seq) + argnum - 1]
    return vjp


defvjp_argnum(sequence_extend_right, grad_sequence_extend_right)
@primitive
def sequence_extend_left(seq, *elts):
    """Return ``elts`` followed by ``seq``, keeping the type of ``seq``."""
    return type(seq)(elts) + seq


def grad_sequence_extend_left(argnum, ans, args, kwargs):
    """VJP maker for ``sequence_extend_left``.

    The output holds the extra elements first, so argument ``argnum >= 1``
    sits at position ``argnum - 1`` and ``seq`` occupies the tail.
    """
    elts = args[1:]

    def vjp(g):
        if argnum == 0:
            return g[len(elts):]
        return g[argnum - 1]
    return vjp


defvjp_argnum(sequence_extend_left, grad_sequence_extend_left)
@primitive
def make_sequence(seq_type, *args):
    """Build ``seq_type(args)`` as a traced primitive."""
    return seq_type(args)


# Positional argument ``argnum`` (0 is seq_type) became element argnum - 1.
defvjp_argnum(make_sequence, lambda argnum, *_: lambda g: g[argnum - 1])


def fwd_grad_make_sequence(argnum, g, ans, seq_type, *args, **kwargs):
    """JVP maker: place tangent ``g`` at slot ``argnum - 1`` of the output."""
    return container_untake(g, argnum - 1, vspace(ans))


defjvp_argnum(make_sequence, fwd_grad_make_sequence)
class TupleMeta(type_):
    """Metaclass making ``isinstance(x, tuple)`` accept any builtin tuple.

    The check uses this module's ``isinstance`` wrapper (presumably so a Box
    around a tuple also passes -- confirm against notrace_primitive).
    The base is ``type_`` (builtin ``type``), matching ListMeta/DictMeta.
    """
    def __instancecheck__(self, instance):
        return isinstance(instance, tuple_)


class tuple(tuple_, metaclass=TupleMeta):
    """Traced drop-in for the builtin ``tuple`` constructor.

    ``__new__`` returns the result of ``make_sequence`` rather than an
    instance of this class, so construction is recorded by the primitive.
    """
    def __new__(cls, xs=()):
        # Default mirrors the builtin: ``tuple()`` gives ``()`` instead of
        # raising TypeError for a missing argument.
        return make_sequence(tuple_, *xs)
class ListMeta(type_):
    """Metaclass making ``isinstance(x, list)`` accept any builtin list.

    The check uses this module's ``isinstance`` wrapper (presumably so a Box
    around a list also passes -- confirm against notrace_primitive).
    """
    def __instancecheck__(self, instance):
        return isinstance(instance, list_)


class list(list_, metaclass=ListMeta):
    """Traced drop-in for the builtin ``list`` constructor.

    ``__new__`` returns the result of ``make_sequence`` rather than an
    instance of this class, so construction is recorded by the primitive.
    """
    def __new__(cls, xs=()):
        # Default mirrors the builtin: ``list()`` gives ``[]`` instead of
        # raising TypeError for a missing argument.
        return make_sequence(list_, *xs)
class DictMeta(type_):
    """Metaclass making ``isinstance(x, dict)`` accept any builtin dict."""
    def __instancecheck__(self, instance):
        return isinstance(instance, dict_)


class dict(dict_, metaclass=DictMeta):
    """Traced drop-in for the builtin ``dict`` constructor.

    Accepts the builtin's arguments. A non-empty result is rebuilt through
    ``_make_dict`` so its values go through a differentiable primitive; an
    empty one is returned as the plain builtin dict.
    """
    def __new__(cls, *args, **kwargs):
        plain = dict_(*args, **kwargs)
        if not plain:
            return plain
        # ``list`` here is this module's traced list class, not the builtin.
        return _make_dict(plain.keys(), list(plain.values()))
@primitive
def _make_dict(keys, vals):
    """Pair ``keys`` with ``vals`` into a builtin dict."""
    return dict_(zip(keys, vals))


# Only ``vals`` (argnum 1) is differentiated: its cotangent is g's entries
# read back in the original key order, collected by this module's ``list``.
defvjp(_make_dict,
       lambda ans, keys, vals: lambda g: list([g[key] for key in keys]),
       argnums=(1,))
class ContainerVSpace(VSpace):
    """VSpace for a container whose elements each carry their own vspace.

    ``self.shape`` mirrors the container's structure but holds per-element
    vspaces. Subclasses provide ``_map``, ``_values``, ``_kv_pairs`` and
    ``_subval`` for their concrete container type.
    """
    def __init__(self, value):
        # _map walks self.shape, so seed it with the raw value and then
        # replace every element by its vspace.
        self.shape = value
        self.shape = self._map(vspace)

    @property
    def size(self):
        per_element = self._map(lambda vs: vs.size)
        return sum(self._values(per_element))

    def zeros(self):
        return self._map(lambda vs: vs.zeros())

    def ones(self):
        return self._map(lambda vs: vs.ones())

    def randn(self):
        return self._map(lambda vs: vs.randn())

    def standard_basis(self):
        """Yield, for every element slot, each basis vector of that slot
        embedded in an otherwise-zero container."""
        zero = self.zeros()
        for key, elt_vs in self._kv_pairs(self.shape):
            for basis_vec in elt_vs.standard_basis():
                yield self._subval(zero, key, basis_vec)

    def _add(self, xs, ys):
        return self._map(lambda vs, x, y: vs._add(x, y), xs, ys)

    def _mut_add(self, xs, ys):
        return self._map(lambda vs, x, y: vs._mut_add(x, y), xs, ys)

    def _scalar_mul(self, xs, a):
        return self._map(lambda vs, x: vs._scalar_mul(x, a), xs)

    def _inner_prod(self, xs, ys):
        per_element = self._map(lambda vs, x, y: vs._inner_prod(x, y), xs, ys)
        return sum(self._values(per_element))

    def _covector(self, xs):
        return self._map(lambda vs, x: vs._covector(x), xs)
class SequenceVSpace(ContainerVSpace):
    """ContainerVSpace for positional sequences; subclasses set ``seq_type``."""
    def _values(self, x):
        return x

    def _kv_pairs(self, x):
        return enumerate(x)

    def _map(self, f, *args):
        # f gets each element vspace plus the matching element of every
        # container in ``args``; zip stops at the shortest input.
        return self.seq_type(f(vs, *elts) for vs, *elts in zip(self.shape, *args))

    def _subval(self, xs, idx, x):
        replaced = subvals(xs, [(idx, x)])
        return self.seq_type(replaced)
class ListVSpace(SequenceVSpace):
    """SequenceVSpace whose containers are builtin lists."""
    seq_type = list_


class TupleVSpace(SequenceVSpace):
    """SequenceVSpace whose containers are builtin tuples."""
    seq_type = tuple_
class DictVSpace(ContainerVSpace):
    """ContainerVSpace for dicts; ``self.shape`` maps each key to a vspace."""
    def _values(self, x):
        return x.values()

    def _kv_pairs(self, x):
        return x.items()

    def _map(self, f, *args):
        out = {}
        for key, vs in self.shape.items():
            out[key] = f(vs, *(x[key] for x in args))
        return out

    def _subval(self, xs, idx, x):
        # Copy via this module's ``dict`` class, then overwrite one entry.
        updated = dict(xs.items())
        updated[idx] = x
        return updated
# Associate each builtin container type with its vspace class (same order as
# before); the loop temporaries are removed from the module namespace.
for _vs_cls, _value_type in ((ListVSpace, list_),
                             (TupleVSpace, tuple_),
                             (DictVSpace, dict_)):
    _vs_cls.register(_value_type)
del _vs_cls, _value_type
class NamedTupleVSpace(SequenceVSpace):
    """SequenceVSpace for namedtuple-like types built from positional fields.

    NOTE(review): ``seq_type`` is not defined here; presumably a subclass (or
    registration code elsewhere) supplies it per namedtuple type -- confirm.
    """
    def _map(self, f, *args):
        fields = [f(vs, *elts) for vs, *elts in zip(self.shape, *args)]
        return self.seq_type(*fields)

    def _subval(self, xs, idx, x):
        fields = subvals(xs, [(idx, x)])
        return self.seq_type(*fields)
|