"""Container support: boxes, primitives with registered VJP/JVP rules, and
vector spaces for tuples, lists and dicts.

NOTE: this module deliberately rebinds the names ``isinstance``, ``type``,
``tuple``, ``list`` and ``dict`` at module level.  The untouched builtins are
kept under the trailing-underscore aliases (``isinstance_``, ``type_``,
``tuple_``, ``list_``, ``dict_``).  Function bodies look names up at call
time, so an unqualified ``list(...)`` or ``dict(...)`` inside any function
below refers to the classes defined in this file, not to the builtins.
"""
from .util import subvals
from .extend import (Box, primitive, notrace_primitive, VSpace, vspace,
                     SparseObject, defvjp, defvjp_argnum, defjvp, defjvp_argnum)

# Keep the builtin under an alias, then shadow it with a notrace_primitive
# wrapper (notrace_primitive comes from .extend; its exact semantics are not
# visible here).
isinstance_ = isinstance
isinstance = notrace_primitive(isinstance)

type_ = type
type = notrace_primitive(type)

# Aliases for the builtin container types; the bare names are rebound to
# classes further down in this module.
tuple_, list_, dict_ = tuple, list, dict


@primitive
def container_take(A, idx):
    """Return ``A[idx]``."""
    return A[idx]


def grad_container_take(ans, A, idx):
    """Build the VJP of ``container_take`` with respect to ``A``.

    The returned function maps ``g`` to ``container_untake(g, idx, vspace(A))``.
    """
    return lambda g: container_untake(g, idx, vspace(A))


defvjp(container_take, grad_container_take)
# 'same' is interpreted by defjvp in .extend -- presumably "apply the same
# function to the tangent"; confirm there.
defjvp(container_take, 'same')


class SequenceBox(Box):
    """Box registered for builtin tuples and lists (see ``register`` calls below)."""
    __slots__ = []
    # Indexing goes through the container_take primitive.
    __getitem__ = container_take

    def __len__(self):
        """Return the length of the wrapped value."""
        return len(self._value)

    def __add__(self, other):
        """Return ``sequence_extend_right(self, *other)`` (i.e. ``self + other``)."""
        return sequence_extend_right(self, *other)

    def __radd__(self, other):
        """Return ``sequence_extend_left(self, *other)`` (i.e. ``other + self``)."""
        return sequence_extend_left(self, *other)

    def __contains__(self, elt):
        """Membership test delegated to the wrapped value."""
        return elt in self._value

    def index(self, elt):
        """Return ``self._value.index(elt)``."""
        return self._value.index(elt)


SequenceBox.register(tuple_)
SequenceBox.register(list_)


class DictBox(Box):
    """Box registered for builtin dicts (see ``register`` call below)."""
    __slots__ = []
    # Indexing goes through the container_take primitive.
    __getitem__= container_take

    def __len__(self):
        """Return the length of the wrapped value."""
        return len(self._value)

    def __iter__(self):
        """Iterate over the wrapped value (for a dict: its keys)."""
        return self._value.__iter__()

    def __contains__(self, elt):
        """Membership test delegated to the wrapped value."""
        return elt in self._value

    # In items/keys/values, `list` resolves at call time to the module-level
    # `list` class defined later in this file, not to the builtin.
    def items(self):
        """Return ``list(self.iteritems())``."""
        return list(self.iteritems())

    def keys(self):
        """Return ``list(self.iterkeys())``."""
        return list(self.iterkeys())

    def values(self):
        """Return ``list(self.itervalues())``."""
        return list(self.itervalues())

    def iteritems(self):
        """Return a generator of ``(k, self[k])`` pairs; values come from ``__getitem__``."""
        return ((k, self[k]) for k in self)

    def iterkeys(self):
        """Return ``iter(self)``."""
        return iter(self)

    def itervalues(self):
        """Return a generator of ``self[k]`` for each key ``k``."""
        return (self[k] for k in self)

    def get(self, k, d=None):
        """Return ``self[k]`` if ``k in self``, otherwise ``d``."""
        return self[k] if k in self else d


DictBox.register(dict_)


@primitive
def container_untake(x, idx, vs):
    """Return ``SparseObject(vs, mut_add)`` describing ``x`` placed at ``idx``.

    ``mut_add(A)`` returns ``vs._subval(A, idx, accum(A[idx]))``, where
    ``accum`` combines the existing entry (or slice of entries) with ``x``
    using the ``_mut_add`` of the matching element vspace(s) in ``vs.shape``.
    """
    # `isinstance` here is the notrace_primitive wrapper bound at the top of
    # this module.
    if isinstance(idx, slice):
        # Slice: pair each element vspace with the existing entry and the
        # corresponding element of x; the result is a builtin list literal.
        accum = lambda result: [elt_vs._mut_add(a, b)
                                for elt_vs, a, b in zip(vs.shape[idx], result, x)]
    else:
        accum = lambda result: vs.shape[idx]._mut_add(result, x)

    def mut_add(A):
        return vs._subval(A, idx, accum(A[idx]))

    return SparseObject(vs, mut_add)


# VJP of container_untake: take the entry at idx back out of g.
defvjp(container_untake, lambda ans, x, idx, _:
       lambda g: container_take(g, idx))
defjvp(container_untake, 'same')


@primitive
def sequence_extend_right(seq, *elts):
    """Return ``seq + type(seq)(elts)``.

    ``type`` is the notrace_primitive wrapper bound at the top of this module.
    """
    return seq + type(seq)(elts)


def grad_sequence_extend_right(argnum, ans, args, kwargs):
    """Build the VJP of ``sequence_extend_right`` for argument ``argnum``.

    For ``argnum == 0`` (``seq``) the returned function gives ``g[:len(seq)]``;
    otherwise it gives ``g[len(seq) + argnum - 1]``, the position of
    ``elts[argnum - 1]`` in the forward result.
    """
    seq, elts = args[0], args[1:]  # elts is not used below
    # The conditional expression is the whole lambda body.
    return lambda g: g[:len(seq)] if argnum == 0 else g[len(seq) + argnum - 1]


defvjp_argnum(sequence_extend_right, grad_sequence_extend_right)


@primitive
def sequence_extend_left(seq, *elts):
    """Return ``type(seq)(elts) + seq``.

    ``type`` is the notrace_primitive wrapper bound at the top of this module.
    """
    return type(seq)(elts) + seq


def grad_sequence_extend_left(argnum, ans, args, kwargs):
    """Build the VJP of ``sequence_extend_left`` for argument ``argnum``.

    For ``argnum == 0`` (``seq``) the returned function gives ``g[len(elts):]``;
    otherwise it gives ``g[argnum - 1]``, the position of ``elts[argnum - 1]``
    in the forward result.
    """
    seq, elts = args[0], args[1:]
    # The conditional expression is the whole lambda body.
    return lambda g: g[len(elts):] if argnum == 0 else g[argnum - 1]


defvjp_argnum(sequence_extend_left, grad_sequence_extend_left)


@primitive
def make_sequence(seq_type, *args):
    """Return ``seq_type(args)``, where ``args`` is the tuple of elements."""
    return seq_type(args)


# argnum 0 is seq_type, so element k of the sequence is argument k + 1;
# hence the `argnum - 1` index into g.
defvjp_argnum(make_sequence, lambda argnum, *args: lambda g: g[argnum - 1])


def fwd_grad_make_sequence(argnum, g, ans, seq_type, *args, **kwargs):
    """JVP of ``make_sequence``: ``container_untake(g, argnum - 1, vspace(ans))``."""
    return container_untake(g, argnum-1, vspace(ans))


defjvp_argnum(make_sequence, fwd_grad_make_sequence)


# The base class is whatever the wrapped `type` returns for tuple_.
class TupleMeta(type(tuple_)):
    """Metaclass whose instance check defers to the builtin tuple."""

    def __instancecheck__(self, instance):
        # Uses the wrapped `isinstance` against the builtin tuple type.
        return isinstance(instance, tuple_)


class tuple(tuple_, metaclass=TupleMeta):
    """Replacement for ``tuple`` whose constructor goes through ``make_sequence``."""

    def __new__(cls, xs):
        # Returns the result of the make_sequence primitive, not an instance
        # constructed via cls.
        return make_sequence(tuple_, *xs)


class ListMeta(type_):
    """Metaclass whose instance check defers to the builtin list."""

    def __instancecheck__(self, instance):
        return isinstance(instance, list_)


class list(list_, metaclass=ListMeta):
    """Replacement for ``list`` whose constructor goes through ``make_sequence``."""

    def __new__(cls, xs):
        # Returns the result of the make_sequence primitive, not an instance
        # constructed via cls.
        return make_sequence(list_, *xs)


class DictMeta(type_):
    """Metaclass whose instance check defers to the builtin dict."""

    def __instancecheck__(self, instance):
        return isinstance(instance, dict_)


class dict(dict_, metaclass=DictMeta):
    """Replacement for ``dict`` whose constructor goes through ``_make_dict``."""

    def __new__(cls, *args, **kwargs):
        # Build a builtin dict first to normalise the constructor arguments.
        result = dict_(*args, **kwargs)
        if result:
            # `list` is the class defined above, so the values are gathered
            # via make_sequence.
            return _make_dict(result.keys(), list(result.values()))
        # Empty case: the plain builtin dict is returned as is.
        return result


@primitive
def _make_dict(keys, vals):
    """Return ``dict_(zip(keys, vals))`` (a builtin dict)."""
    return dict_(zip(keys, vals))


# A VJP is registered only for argument 1 (vals).  `list` is the class defined
# above, not the builtin.
defvjp(_make_dict,
       lambda ans, keys, vals: lambda g:
       list(g[key] for key in keys), argnums=(1,))


class ContainerVSpace(VSpace):
    """VSpace for a container, built from the vspaces of its elements.

    The methods here call ``_values``, ``_kv_pairs``, ``_map`` and ``_subval``,
    none of which this class defines; the subclasses in this module provide
    them.
    """

    def __init__(self, value):
        # The subclass _map implementations read self.shape, so it is first
        # set to the raw value and then replaced by the per-element vspaces.
        self.shape = value
        self.shape = self._map(vspace)

    @property
    def size(self):
        """Sum of the ``size`` of every element vspace."""
        return sum(self._values(self._map(lambda vs: vs.size)))

    def zeros(self):
        """Container holding ``vs.zeros()`` for each element vspace."""
        return self._map(lambda vs: vs.zeros())

    def ones(self):
        """Container holding ``vs.ones()`` for each element vspace."""
        return self._map(lambda vs: vs.ones())

    def randn(self):
        """Container holding ``vs.randn()`` for each element vspace."""
        return self._map(lambda vs: vs.randn())

    def standard_basis(self):
        """Yield ``zeros()`` with one slot replaced by each basis item of that
        slot's element vspace, for every slot in turn."""
        zero = self.zeros()
        for i, vs in self._kv_pairs(self.shape):
            for x in vs.standard_basis():
                yield self._subval(zero, i, x)

    def _add(self, xs, ys):
        """Elementwise ``vs._add(x, y)``."""
        return self._map(lambda vs, x, y: vs._add(x, y), xs, ys)

    def _mut_add(self, xs, ys):
        """Elementwise ``vs._mut_add(x, y)``, collected into the container ``_map`` returns."""
        return self._map(lambda vs, x, y: vs._mut_add(x, y), xs, ys)

    def _scalar_mul(self, xs, a):
        """Elementwise ``vs._scalar_mul(x, a)``."""
        return self._map(lambda vs, x: vs._scalar_mul(x, a), xs)

    def _inner_prod(self, xs, ys):
        """Sum of the elementwise ``vs._inner_prod(x, y)``."""
        return sum(self._values(self._map(lambda vs, x, y: vs._inner_prod(x, y), xs, ys)))

    def _covector(self, xs):
        """Elementwise ``vs._covector(x)``."""
        return self._map(lambda vs, x: vs._covector(x), xs)


class SequenceVSpace(ContainerVSpace):
    """ContainerVSpace for sequences; ``seq_type`` is set by the subclasses."""

    def _values(self, x):
        """A sequence is its own collection of values."""
        return x

    def _kv_pairs(self, x):
        """Return ``enumerate(x)`` (index, element pairs)."""
        return enumerate(x)

    def _map(self, f, *args):
        """Apply ``f`` across ``self.shape`` and ``args`` in lockstep; wrap in ``seq_type``."""
        return self.seq_type(map(f, self.shape, *args))

    def _subval(self, xs, idx, x):
        """Return ``seq_type(subvals(xs, [(idx, x)]))``.

        ``subvals`` comes from .util -- presumably it returns ``xs`` with
        position ``idx`` replaced by ``x``; confirm there.
        """
        return self.seq_type(subvals(xs, [(idx, x)]))


class ListVSpace(SequenceVSpace):
    # Results are built as builtin lists.
    seq_type = list_


class TupleVSpace(SequenceVSpace):
    # Results are built as builtin tuples.
    seq_type = tuple_


class DictVSpace(ContainerVSpace):
    """ContainerVSpace for dicts; ``self.shape`` maps keys to element vspaces."""

    def _values(self, x):
        """Return ``x.values()``."""
        return x.values()

    def _kv_pairs(self, x):
        """Return ``x.items()``."""
        return x.items()

    def _map(self, f, *args):
        """Return a builtin dict of ``f(vs, x[k], ...)`` for each key ``k`` in ``self.shape``."""
        return {k: f(vs, *[x[k] for x in args])
                for k, vs in self.shape.items()}

    def _subval(self, xs, idx, x):
        """Return a copy of ``xs`` with ``idx`` set to ``x``."""
        # `dict` is the class defined earlier in this module, not the builtin.
        d = dict(xs.items())
        d[idx] = x
        return d


ListVSpace.register(list_)
TupleVSpace.register(tuple_)
DictVSpace.register(dict_)


class NamedTupleVSpace(SequenceVSpace):
    """SequenceVSpace variant that passes elements to ``seq_type`` as separate
    positional arguments instead of as one iterable.

    ``seq_type`` is not set here -- presumably supplied by a subclass;
    verify against users of this class.
    """

    def _map(self, f, *args):
        """Like ``SequenceVSpace._map`` but unpacks the mapped elements into ``seq_type``."""
        return self.seq_type(*map(f, self.shape, *args))

    def _subval(self, xs, idx, x):
        """Like ``SequenceVSpace._subval`` but unpacks the elements into ``seq_type``."""
        return self.seq_type(*subvals(xs, [(idx, x)]))