File size: 4,240 Bytes
ab4488b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from __future__ import division
from functools import partial
import scipy.linalg

import autograd.numpy as anp
from autograd.numpy.numpy_wrapper import wrap_namespace
from autograd.extend import defvjp, defvjp_argnums, defjvp, defjvp_argnums

# Copy the names from scipy.linalg into this module's globals through autograd's
# wrap_namespace (presumably wrapping them as autograd primitives -- see
# autograd.numpy.numpy_wrapper to confirm).  sqrtm, solve_triangular,
# solve_banded and solve_sylvester used below all come from this call.
wrap_namespace(scipy.linalg.__dict__, globals())  # populates module namespace

def _vjp_sqrtm(ans, A, disp=True, blocksize=64):
    """Reverse-mode rule for ``X = sqrtm(A)``.

    Differentiating ``X X = A`` gives ``X dX + dX X = dA``; the adjoint of that
    Sylvester operator maps a cotangent ``g`` for ``X`` to the solution ``Y`` of
    ``X^T Y + Y X^T = g``.  Only the real part of ``Y`` is returned.
    Only ``disp=True`` is supported.
    """
    assert disp, "sqrtm vjp not implemented for disp=False"
    root_t = anp.transpose(ans)  # shared by both Sylvester coefficient slots
    return lambda g: anp.real(solve_sylvester(root_t, root_t, g))
# Register the reverse-mode (VJP) rule for sqrtm.
defvjp(sqrtm, _vjp_sqrtm)

def _flip(a, trans):
    """Return the ``trans`` code for the adjoint of a triangular solve.

    An untransposed solve (``trans`` of ``'N'`` or ``0``) flips to ``'H'`` for
    complex ``a`` and to ``'T'`` otherwise; every other ``trans`` value flips
    back to ``'N'``.
    """
    if trans not in ('N', 0):
        return 'N'
    return 'H' if anp.iscomplexobj(a) else 'T'

def grad_solve_triangular(ans, a, b, trans=0, lower=False, unit_diagonal=False,
                          **kwargs):
    """Build the VJP of ``solve_triangular`` with respect to the matrix ``a``.

    ``ans = solve_triangular(a, b, trans, lower, unit_diagonal)`` solves
    ``op(a) x = b``.  The cotangent w.r.t. ``op(a)`` is ``-v x^T``, where ``v``
    solves the adjoint triangular system (transpose mode chosen by ``_flip``).
    It is masked to the triangle that ``solve_triangular`` reads, and it is
    transposed back to ``a``'s layout when ``op`` transposes.

    When ``unit_diagonal`` is true, ``solve_triangular`` assumes ones on the
    diagonal and never reads it.  The adjoint solve must make the same
    assumption, and the diagonal receives zero gradient.  Remaining keywords
    (``overwrite_b``, ``check_finite``, ...) are accepted but deliberately not
    forwarded; in particular the adjoint solve must never overwrite ``g``.
    ``b``/``ans`` may be 1-D or 2-D.
    """
    adj_trans = _flip(a, trans)
    op_transposes = adj_trans == 'N'
    # The cotangent w.r.t. op(a) is lower triangular iff ``a`` is lower and op
    # keeps it, or ``a`` is upper and op transposes it.
    grad_is_lower = bool(lower) ^ op_transposes
    # With an implicit unit diagonal only the strict triangle is ever read.
    k = (-1 if grad_is_lower else 1) if unit_diagonal else 0
    as_2d = lambda x: x if x.ndim > 1 else x[..., None]

    def mask(x):
        return anp.tril(x, k) if grad_is_lower else anp.triu(x, k)

    def vjp(g):
        v = as_2d(solve_triangular(a, g, trans=adj_trans, lower=lower,
                                   unit_diagonal=unit_diagonal))
        grad_op_a = -mask(anp.dot(v, as_2d(ans).T))
        return grad_op_a.T if op_transposes else grad_op_a
    return vjp

def _grad_solve_triangular_b(ans, a, b, trans=0, lower=False,
                             unit_diagonal=False, **kwargs):
    """Build the VJP of ``solve_triangular`` with respect to ``b``.

    The VJP is a single adjoint triangular solve (transpose mode from
    ``_flip``) that honours ``unit_diagonal`` exactly like the forward solve.
    Other keywords are not forwarded.
    """
    adj_trans = _flip(a, trans)
    return lambda g: solve_triangular(a, g, trans=adj_trans, lower=lower,
                                      unit_diagonal=unit_diagonal)

defvjp(solve_triangular, grad_solve_triangular, _grad_solve_triangular_b)

def grad_solve_banded(argnum, ans, l_and_u, a, b, **kwargs):
    """Build the VJP of ``solve_banded(l_and_u, a, b)`` for argument ``argnum``.

    ``a`` is the banded storage of a square matrix ``A``, where
    ``l_and_u = (l, u)`` gives the number of sub- and super-diagonals.  The
    layout is ``a[u + i - j, j] == A[i, j]``, with one row per diagonal.
    With ``ans = x`` solving ``A x = b``:

    * ``argnum == 2`` (``b``): returns ``g -> A^{-T} g``.
    * ``argnum == 1`` (``a``): returns ``g -> -(A^{-T} g) x^T`` restricted to
      the band, in the same banded layout as ``a`` (unused padding slots are 0).

    ``b`` may be 1-D ``(n,)`` or 2-D ``(n, k)``.  In the 2-D case the outer
    product sums over the ``k`` columns.  Extra ``solve_banded`` keywords
    (e.g. ``check_finite``) are accepted and ignored.
    """
    n_lower, n_upper = l_and_u[0], l_and_u[1]
    as_2d = lambda x: x if x.ndim > 1 else x[..., None]

    def transpose_banded(ab):
        # Row r of ``ab`` holds diagonal offset n_upper - r.  In A^T that
        # diagonal becomes offset r - n_upper, which lives in row
        # n_lower + n_upper - r, so the row order reverses.  Its entries are
        # also moved by r - n_upper columns.  anp.roll wraps padding entries
        # into padding slots of the new row, and solve_banded never reads those.
        rows = [anp.roll(ab[r], r - n_upper) for r in range(ab.shape[0])]
        return anp.vstack(rows[::-1])

    def banded_outer(uu, ww):
        # Band-restricted ``uu @ ww.T`` in banded storage, for uu and ww of
        # shape (n, k).  Diagonal offset d, i.e. entries (i, i + d), goes to
        # row n_upper - d.
        n = uu.shape[0]

        def diagonal(d):
            m = max(n - abs(d), 0)  # number of entries on this diagonal
            if d >= 0:
                # Entries (i, i + d) for i < m, stored at columns d .. n - 1.
                vals = anp.sum(uu[:m] * ww[d:d + m], axis=1)
                return anp.concatenate([anp.zeros(n - m), vals])
            # Entries (j - d, j) for j < m, stored at columns 0 .. m - 1.
            vals = anp.sum(uu[-d:m - d] * ww[:m], axis=1)
            return anp.concatenate([vals, anp.zeros(n - m)])

        return anp.vstack([diagonal(d) for d in range(n_upper, -n_lower - 1, -1)])

    T_l_and_u = (n_upper, n_lower)
    T_a = transpose_banded(a)

    if argnum == 1:
        return lambda g: -banded_outer(as_2d(solve_banded(T_l_and_u, T_a, g)),
                                       as_2d(ans))
    elif argnum == 2:
        return lambda g: solve_banded(T_l_and_u, T_a, g)

# Register VJPs for solve_banded(l_and_u, a, b): argnum 1 is the banded matrix
# storage ``a`` and argnum 2 is the right-hand side ``b``.  No VJP is registered
# for ``l_and_u`` (argnum 0).
defvjp(solve_banded,
       partial(grad_solve_banded, 1),
       partial(grad_solve_banded, 2),
       argnums=[1, 2])

def _jvp_sqrtm(dA, ans, A, disp=True, blocksize=64):
    """Forward-mode rule for ``X = sqrtm(A)``.

    Differentiating ``X X = A`` gives ``X dX + dX X = dA``, so the tangent is
    the solution of that Sylvester equation.  Only ``disp=True`` is supported.
    """
    assert disp, "sqrtm jvp not implemented for disp=False"
    root = ans
    d_root = solve_sylvester(root, root, dA)
    return d_root
# Register the forward-mode (JVP) rule for sqrtm.
defjvp(sqrtm, _jvp_sqrtm)

def _jvp_sylvester(argnums, dms, ans, args, _):
    """Forward-mode rule for ``x = solve_sylvester(a, b, q)`` (``a x + x b = q``).

    Differentiating gives ``a dx + dx b = dq - da x - x db``: one more Sylvester
    solve with the same coefficients.  ``dms`` holds the tangents of the
    arguments listed in ``argnums`` (in that order).  Arguments that are not
    being differentiated contribute a zero tangent.
    """
    a, b, _q = args
    tangent_of = dict(zip(argnums, dms))
    da = tangent_of.get(0, 0)
    db = tangent_of.get(1, 0)
    dq = tangent_of.get(2, 0)
    rhs = dq - anp.dot(da, ans) - anp.dot(ans, db)
    return solve_sylvester(a, b, rhs)
# Register the forward-mode rule for solve_sylvester; _jvp_sylvester handles any
# subset of (a, b, q) being differentiated via its ``argnums`` argument.
defjvp_argnums(solve_sylvester, _jvp_sylvester)

def _vjp_sylvester(argnums, ans, args, _):
    """Reverse-mode rule for ``x = solve_sylvester(a, b, q)`` (``a x + x b = q``).

    A single adjoint solve ``a^T y + y b^T = g`` yields every cotangent:
    ``q -> y``, ``a -> -y x^T``, ``b -> -x^T y``.  The returned function
    produces a tuple with one entry per index in ``argnums``, ordered (a, b, q).
    """
    a, b, _q = args
    x_t = anp.transpose(ans)

    def vjp(g):
        y = solve_sylvester(anp.transpose(a), anp.transpose(b), g)
        # Build cotangents lazily so only the requested ones are computed.
        builders = {
            0: lambda: -anp.dot(y, x_t),
            1: lambda: -anp.dot(x_t, y),
            2: lambda: y,
        }
        return tuple(builders[idx]() for idx in (0, 1, 2) if idx in argnums)
    return vjp
# Register the reverse-mode rule for solve_sylvester; _vjp_sylvester returns one
# cotangent per differentiated argument among (a, b, q).
defvjp_argnums(solve_sylvester, _vjp_sylvester)