|
from __future__ import division |
|
from functools import partial |
|
import scipy.linalg |
|
|
|
import autograd.numpy as anp |
|
from autograd.numpy.numpy_wrapper import wrap_namespace |
|
from autograd.extend import defvjp, defvjp_argnums, defjvp, defjvp_argnums |
|
|
|
wrap_namespace(scipy.linalg.__dict__, globals()) |
|
|
|
def _vjp_sqrtm(ans, A, disp=True, blocksize=64): |
|
assert disp, "sqrtm vjp not implemented for disp=False" |
|
ans_transp = anp.transpose(ans) |
|
def vjp(g): |
|
return anp.real(solve_sylvester(ans_transp, ans_transp, g)) |
|
return vjp |
|
defvjp(sqrtm, _vjp_sqrtm) |
|
|
|
def _flip(a, trans): |
|
if anp.iscomplexobj(a): |
|
return 'H' if trans in ('N', 0) else 'N' |
|
else: |
|
return 'T' if trans in ('N', 0) else 'N' |
|
|
|
def grad_solve_triangular(ans, a, b, trans=0, lower=False, **kwargs): |
|
tri = anp.tril if (lower ^ (_flip(a, trans) == 'N')) else anp.triu |
|
transpose = lambda x: x if _flip(a, trans) != 'N' else x.T |
|
al2d = lambda x: x if x.ndim > 1 else x[...,None] |
|
def vjp(g): |
|
v = al2d(solve_triangular(a, g, trans=_flip(a, trans), lower=lower)) |
|
return -transpose(tri(anp.dot(v, al2d(ans).T))) |
|
return vjp |
|
|
|
defvjp(solve_triangular, |
|
grad_solve_triangular, |
|
lambda ans, a, b, trans=0, lower=False, **kwargs: |
|
lambda g: solve_triangular(a, g, trans=_flip(a, trans), lower=lower)) |
|
|
|
def grad_solve_banded(argnum, ans, l_and_u, a, b): |
|
|
|
updim = lambda x: x if x.ndim == a.ndim else x[...,None] |
|
|
|
def transpose_banded(l_and_u, a): |
|
|
|
|
|
|
|
|
|
num_rows = a.shape[0] |
|
|
|
shifts = anp.arange(-l_and_u[1], l_and_u[0]+1) |
|
|
|
T_a = anp.roll(a[:1, :], shifts[0]) |
|
for rr in range(1, num_rows): |
|
T_a = anp.vstack([T_a, anp.flipud(anp.roll(a[rr:rr+1, :], shifts[rr]))]) |
|
T_a = anp.flipud(T_a) |
|
|
|
T_l_and_u = anp.flip(l_and_u) |
|
|
|
return T_l_and_u, T_a |
|
|
|
def banded_dot(l_and_u, uu, vv): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
banded_uv = anp.ravel(uu)*anp.ravel(vv) |
|
|
|
|
|
for rr in range(1, l_and_u[0]+1): |
|
banded_uv_rr = anp.hstack([anp.ravel(uu)[rr:]*anp.ravel(vv)[:-rr], anp.zeros(rr)]) |
|
banded_uv = anp.vstack([banded_uv, banded_uv_rr]) |
|
|
|
|
|
for rr in range(1, l_and_u[1]+1): |
|
banded_uv_rr = anp.hstack([anp.zeros(rr), anp.ravel(uu)[:-rr]*anp.ravel(vv)[rr:]]) |
|
banded_uv = anp.vstack([banded_uv_rr, banded_uv]) |
|
|
|
return(banded_uv) |
|
|
|
T_l_and_u, T_a = transpose_banded(l_and_u, a) |
|
|
|
if argnum == 1: |
|
return lambda g: -banded_dot(l_and_u, updim(solve_banded(T_l_and_u, T_a, g)), anp.transpose(updim(ans))) |
|
elif argnum == 2: |
|
return lambda g: solve_banded(T_l_and_u, T_a, g) |
|
|
|
defvjp(solve_banded, |
|
partial(grad_solve_banded, 1), |
|
partial(grad_solve_banded, 2), |
|
argnums=[1, 2]) |
|
|
|
def _jvp_sqrtm(dA, ans, A, disp=True, blocksize=64): |
|
assert disp, "sqrtm jvp not implemented for disp=False" |
|
return solve_sylvester(ans, ans, dA) |
|
defjvp(sqrtm, _jvp_sqrtm) |
|
|
|
def _jvp_sylvester(argnums, dms, ans, args, _): |
|
a, b, q = args |
|
if 0 in argnums: |
|
da = dms[0] |
|
db = dms[1] if 1 in argnums else 0 |
|
else: |
|
da = 0 |
|
db = dms[0] if 1 in argnums else 0 |
|
dq = dms[-1] if 2 in argnums else 0 |
|
rhs = dq - anp.dot(da, ans) - anp.dot(ans, db) |
|
return solve_sylvester(a, b, rhs) |
|
defjvp_argnums(solve_sylvester, _jvp_sylvester) |
|
|
|
def _vjp_sylvester(argnums, ans, args, _): |
|
a, b, q = args |
|
def vjp(g): |
|
vjps = [] |
|
q_vjp = solve_sylvester(anp.transpose(a), anp.transpose(b), g) |
|
if 0 in argnums: vjps.append(-anp.dot(q_vjp, anp.transpose(ans))) |
|
if 1 in argnums: vjps.append(-anp.dot(anp.transpose(ans), q_vjp)) |
|
if 2 in argnums: vjps.append(q_vjp) |
|
return tuple(vjps) |
|
return vjp |
|
defvjp_argnums(solve_sylvester, _vjp_sylvester) |
|
|