File size: 6,260 Bytes
ab4488b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from .util import subvals
from .extend import (Box, primitive, notrace_primitive, VSpace, vspace,
                     SparseObject, defvjp, defvjp_argnum, defjvp, defjvp_argnum)

# Keep handles on the real builtins before this module shadows their names.
isinstance_, type_ = isinstance, type
tuple_, list_, dict_ = tuple, list, dict

# Module-wide replacements wrapped with notrace_primitive (presumably so they
# look through Box wrappers to the underlying value -- see autograd.extend).
isinstance = notrace_primitive(isinstance_)
type = notrace_primitive(type_)

@primitive
def container_take(A, idx):
    """Differentiable ``A[idx]``; used as ``__getitem__`` of the container boxes."""
    return A[idx]
def grad_container_take(ans, A, idx):
    """Build the VJP of ``container_take`` with respect to the container ``A``."""
    def vjp(g):
        # Scatter g back into slot idx of a container living in A's vspace.
        return container_untake(g, idx, vspace(A))
    return vjp
defvjp(container_take, grad_container_take)
# 'same': presumably the tangent is obtained by applying container_take itself.
defjvp(container_take, 'same')

class SequenceBox(Box):
    """Box registered for builtin tuple and list values (see registrations below)."""
    __slots__ = []
    __getitem__ = container_take

    def __len__(self):
        return len(self._value)

    def __contains__(self, elt):
        return elt in self._value

    def index(self, elt):
        return self._value.index(elt)

    def __add__(self, other):
        # self + other: other's elements are appended after ours.
        return sequence_extend_right(self, *other)

    def __radd__(self, other):
        # other + self: other's elements are prepended before ours.
        return sequence_extend_left(self, *other)
SequenceBox.register(tuple_)
SequenceBox.register(list_)

class DictBox(Box):
    """Read-only mapping interface over a boxed dict; lookups go through
    ``container_take`` so each retrieved value is traced."""
    __slots__ = []
    __getitem__ = container_take

    def __len__(self):
        return len(self._value)

    def __iter__(self):
        return iter(self._value)

    def __contains__(self, elt):
        return elt in self._value

    def iterkeys(self):
        return iter(self)

    def iteritems(self):
        return ((key, self[key]) for key in self)

    def itervalues(self):
        return (self[key] for key in self)

    # ``list`` here is this module's differentiable list constructor.
    def keys(self):
        return list(self.iterkeys())

    def items(self):
        return list(self.iteritems())

    def values(self):
        return list(self.itervalues())

    def get(self, k, d=None):
        if k in self:
            return self[k]
        return d
DictBox.register(dict_)

@primitive
def container_untake(x, idx, vs):
    """Adjoint of ``container_take``: a SparseObject in vspace ``vs`` whose
    ``mut_add`` adds ``x`` into slot ``idx`` (a key, index, or slice)."""
    if isinstance(idx, slice):
        def accum(current):
            # A slice spans several elements, each with its own vspace.
            return [elt_vs._mut_add(cur, new)
                    for elt_vs, cur, new in zip(vs.shape[idx], current, x)]
    else:
        def accum(current):
            return vs.shape[idx]._mut_add(current, x)

    def mut_add(A):
        return vs._subval(A, idx, accum(A[idx]))
    return SparseObject(vs, mut_add)
defvjp(container_untake,
       lambda ans, x, idx, _: (lambda g: container_take(g, idx)))
defjvp(container_untake, 'same')

@primitive
def sequence_extend_right(seq, *elts):
    """Return ``seq`` followed by ``elts``, keeping ``seq``'s sequence type."""
    tail = type(seq)(elts)
    return seq + tail
def grad_sequence_extend_right(argnum, ans, args, kwargs):
    """VJP maker: argnum 0 is ``seq`` (the leading slice of the output);
    argnum k > 0 is ``elts[k-1]``, found at ``len(seq) + k - 1``."""
    seq = args[0]
    if argnum == 0:
        return lambda g: g[:len(seq)]
    return lambda g: g[len(seq) + argnum - 1]
defvjp_argnum(sequence_extend_right, grad_sequence_extend_right)

@primitive
def sequence_extend_left(seq, *elts):
    """Return ``elts`` followed by ``seq``, keeping ``seq``'s sequence type."""
    head = type(seq)(elts)
    return head + seq
def grad_sequence_extend_left(argnum, ans, args, kwargs):
    """VJP maker: argnum 0 is ``seq`` (everything after the prepended
    elements); argnum k > 0 is ``elts[k-1]`` at output position ``k - 1``."""
    seq, prepended = args[0], args[1:]
    if argnum == 0:
        return lambda g: g[len(prepended):]
    return lambda g: g[argnum - 1]
defvjp_argnum(sequence_extend_left, grad_sequence_extend_left)

@primitive
def make_sequence(seq_type, *args):
    """Build ``seq_type(args)``; each element is a separate primitive argument."""
    return seq_type(args)
# Reverse mode: argument k (k >= 1) is element k-1 of the result.
defvjp_argnum(make_sequence, lambda argnum, *args: (lambda g: g[argnum - 1]))

def fwd_grad_make_sequence(argnum, g, ans, seq_type, *args, **kwargs):
    """Forward mode: embed tangent ``g`` of element ``argnum - 1`` into a
    sparse sequence in the vspace of ``ans``."""
    position = argnum - 1
    return container_untake(g, position, vspace(ans))

defjvp_argnum(make_sequence, fwd_grad_make_sequence)


class TupleMeta(type_):
    """Metaclass so ``isinstance(x, tuple)`` defers to the builtin tuple check.

    The check uses this module's wrapped ``isinstance``, which presumably
    also accepts boxed tuples.
    """
    # Base is the builtin ``type_`` (what ``type(tuple_)`` evaluated to),
    # matching ListMeta/DictMeta instead of calling the wrapped ``type``.
    def __instancecheck__(self, instance):
        return isinstance(instance, tuple_)

class tuple(tuple_, metaclass=TupleMeta):
    """Differentiable replacement for the builtin ``tuple`` constructor.

    ``tuple(xs)`` builds the result via ``make_sequence`` so every element
    is traced separately; it returns whatever ``make_sequence`` returns
    (a builtin tuple, or presumably a Box when elements are traced).
    """
    def __new__(cls, xs=()):
        # Default ``()`` mirrors the builtin: ``tuple()`` is the empty tuple.
        return make_sequence(tuple_, *xs)

class ListMeta(type_):
    """Metaclass so ``isinstance(x, list)`` defers to the builtin list check
    (performed with this module's wrapped ``isinstance``)."""
    def __instancecheck__(self, instance):
        return isinstance(instance, list_)
class list(list_, metaclass=ListMeta):
    """Differentiable replacement for the builtin ``list`` constructor.

    ``list(xs)`` builds the result via ``make_sequence`` so every element
    is traced separately.
    """
    def __new__(cls, xs=()):
        # Default ``()`` mirrors the builtin: ``list()`` is the empty list.
        return make_sequence(list_, *xs)

class DictMeta(type_):
    """Metaclass so ``isinstance(x, dict)`` defers to the builtin dict check
    (performed with this module's wrapped ``isinstance``)."""
    def __instancecheck__(self, instance):
        return isinstance(instance, dict_)
class dict(dict_, metaclass=DictMeta):
    """Differentiable replacement for the builtin ``dict`` constructor.

    Accepts the builtin's arguments; a non-empty result is rebuilt through
    ``_make_dict`` so its values are traced.
    """
    def __new__(cls, *args, **kwargs):
        plain = dict_(*args, **kwargs)
        if not plain:
            return plain
        # ``list`` is this module's differentiable list constructor.
        return _make_dict(plain.keys(), list(plain.values()))

@primitive
def _make_dict(keys, vals):
    """Zip ``keys`` with ``vals`` into a builtin dict (differentiable in ``vals``)."""
    return dict_(zip(keys, vals))
# Only ``vals`` (argnum 1) gets a VJP: the gradient values, in key order.
defvjp(_make_dict,
       lambda ans, keys, vals: (lambda g: list(g[k] for k in keys)),
       argnums=(1,))

class ContainerVSpace(VSpace):
    """Vector space of a container whose elements each have their own vspace.

    ``shape`` holds one vspace per element in the container's own layout.
    Concrete subclasses supply ``_values``, ``_kv_pairs``, ``_map`` and
    ``_subval``; every operation below is applied element-wise through them.
    """

    def __init__(self, value):
        # _map walks self.shape, so seed it with the raw value first and
        # then replace it with the mapped per-element vspaces.
        self.shape = value
        self.shape = self._map(vspace)

    @property
    def size(self):
        # Sum of the sizes of all element spaces.
        elt_sizes = self._map(lambda elt_vs: elt_vs.size)
        return sum(self._values(elt_sizes))

    def zeros(self):
        return self._map(lambda elt_vs: elt_vs.zeros())

    def ones(self):
        return self._map(lambda elt_vs: elt_vs.ones())

    def randn(self):
        return self._map(lambda elt_vs: elt_vs.randn())

    def standard_basis(self):
        # Each yielded vector is the zero container with a single element
        # replaced by one basis vector of that element's space.
        zero = self.zeros()
        for key, elt_vs in self._kv_pairs(self.shape):
            for unit in elt_vs.standard_basis():
                yield self._subval(zero, key, unit)

    def _add(self, xs, ys):
        return self._map(lambda elt_vs, x, y: elt_vs._add(x, y), xs, ys)

    def _mut_add(self, xs, ys):
        return self._map(lambda elt_vs, x, y: elt_vs._mut_add(x, y), xs, ys)

    def _scalar_mul(self, xs, a):
        return self._map(lambda elt_vs, x: elt_vs._scalar_mul(x, a), xs)

    def _inner_prod(self, xs, ys):
        elt_prods = self._map(
            lambda elt_vs, x, y: elt_vs._inner_prod(x, y), xs, ys)
        return sum(self._values(elt_prods))

    def _covector(self, xs):
        return self._map(lambda elt_vs, x: elt_vs._covector(x), xs)

class SequenceVSpace(ContainerVSpace):
    """Container vspace for sequences; subclasses set ``seq_type``."""

    def _values(self, x):
        return x

    def _kv_pairs(self, x):
        # Positions serve as keys.
        return enumerate(x)

    def _map(self, f, *args):
        mapped = map(f, self.shape, *args)
        return self.seq_type(mapped)

    def _subval(self, xs, idx, x):
        replaced = subvals(xs, [(idx, x)])
        return self.seq_type(replaced)

class ListVSpace(SequenceVSpace):
    """Sequence vspace whose containers are builtin lists."""
    seq_type = list_
class TupleVSpace(SequenceVSpace):
    """Sequence vspace whose containers are builtin tuples."""
    seq_type = tuple_
class DictVSpace(ContainerVSpace):
    """Container vspace for dicts; ``shape`` maps each key to its value's vspace."""

    def _values(self, x):
        return x.values()

    def _kv_pairs(self, x):
        return x.items()

    def _map(self, f, *args):
        return {k: f(vs, *[x[k] for x in args])
                for k, vs in self.shape.items()}

    def _subval(self, xs, idx, x):
        """Return a shallow copy of ``xs`` with key ``idx`` set to ``x``."""
        # Copy with the builtin ``dict_``: this module's ``dict`` is the
        # differentiable wrapper, which routes through primitives and could
        # presumably return a boxed dict (DictBox has no __setitem__) when
        # the values are traced.
        d = dict_(xs.items())
        d[idx] = x
        return d

# Associate each builtin container type with its vspace class (presumably
# consulted by ``vspace()`` when it is given a value of that type).
for _vs_cls, _py_type in ((ListVSpace, list_),
                          (TupleVSpace, tuple_),
                          (DictVSpace, dict_)):
    _vs_cls.register(_py_type)
del _vs_cls, _py_type

class NamedTupleVSpace(SequenceVSpace):
    """Sequence vspace for types whose constructor takes the fields as
    positional arguments (e.g. namedtuples), not a single iterable.

    ``seq_type`` is not set here; presumably a subclass supplies it.
    """

    def _map(self, f, *args):
        fields = map(f, self.shape, *args)
        return self.seq_type(*fields)

    def _subval(self, xs, idx, x):
        fields = subvals(xs, [(idx, x)])
        return self.seq_type(*fields)