from __future__ import absolute_import
from functools import partial
import numpy.linalg as npla
from .numpy_wrapper import wrap_namespace
from . import numpy_wrapper as anp
from autograd.extend import defvjp, defjvp
wrap_namespace(npla.__dict__, globals())
# Some formulas are from
# "An extended collection of matrix derivative results
# for forward and reverse mode algorithmic differentiation"
# by Mike Giles
# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
# transpose by swapping last two dimensions
def T(x): return anp.swapaxes(x, -1, -2)
_dot = partial(anp.einsum, '...ij,...jk->...ik')
# batched diag
_diag = lambda a: anp.eye(a.shape[-1])*a
# batched diagonal, similar to matrix_diag in tensorflow
def _matrix_diag(a):
    reps = anp.array(a.shape)
    reps[:-1] = 1
    reps[-1] = a.shape[-1]
    newshape = list(a.shape) + [a.shape[-1]]
    return _diag(anp.tile(a, reps).reshape(newshape))
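# Illustrative sketch (not executed as part of this module): a batch of two
# 2-vectors becomes a stack of two 2x2 diagonal matrices,
#     _matrix_diag(anp.array([[1., 2.], [3., 4.]]))
#     # -> [[[1., 0.], [0., 2.]],
#     #     [[3., 0.], [0., 4.]]]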
# add two dimensions to the end of x
def add2d(x): return anp.reshape(x, anp.shape(x) + (1, 1))
defvjp(det, lambda ans, x: lambda g: add2d(g) * add2d(ans) * T(inv(x)))
defvjp(slogdet, lambda ans, x: lambda g: add2d(g[1]) * T(inv(x)))
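# Illustrative check of the two VJPs above using autograd's public `grad`
# (a sketch, not executed as part of this module): d det(A)/dA = det(A) * inv(A).T
# and d logdet(A)/dA = inv(A).T.
#     import autograd.numpy as np
#     from autograd import grad
#     A = np.array([[2., 1.], [1., 3.]])
#     grad(np.linalg.det)(A)                       # == det(A) * inv(A).T
#     grad(lambda A: np.linalg.slogdet(A)[1])(A)   # == inv(A).T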
def grad_inv(ans, x):
    return lambda g: -_dot(_dot(T(ans), g), T(ans))
defvjp(inv, grad_inv)
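# Sketch of the inverse rule (illustrative): for f(A) = sum(inv(A)) the gradient
# is -inv(A).T @ ones @ inv(A).T, which the VJP above reproduces.
#     import autograd.numpy as np
#     from autograd import grad
#     A = np.array([[3., 1.], [1., 2.]])
#     Ainv = np.linalg.inv(A)
#     g = grad(lambda A: np.sum(np.linalg.inv(A)))(A)
#     np.allclose(g, -Ainv.T @ np.ones((2, 2)) @ Ainv.T)   # True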
def grad_pinv(ans, x):
    # https://mathoverflow.net/questions/25778/analytical-formula-for-numerical-derivative-of-the-matrix-pseudo-inverse
    return lambda g: T(
        -_dot(_dot(ans, T(g)), ans)
        + _dot(_dot(_dot(ans, T(ans)), g), anp.eye(x.shape[-2]) - _dot(x, ans))
        + _dot(_dot(_dot(anp.eye(ans.shape[-2]) - _dot(ans, x), g), T(ans)), ans)
    )
defvjp(pinv, grad_pinv)
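# Pseudo-inverse sketch (shape check only, illustrative): the gradient has the
# shape of the input, including non-square input.
#     import autograd.numpy as np
#     from autograd import grad
#     A = np.array([[1., 2., 0.], [0., 1., 1.]])
#     grad(lambda A: np.sum(np.linalg.pinv(A)))(A).shape   # (2, 3)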
def grad_solve(argnum, ans, a, b):
    updim = lambda x: x if x.ndim == a.ndim else x[..., None]
    if argnum == 0:
        return lambda g: -_dot(updim(solve(T(a), g)), T(updim(ans)))
    else:
        return lambda g: solve(T(a), g)
defvjp(solve, partial(grad_solve, 0), partial(grad_solve, 1))
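# Solve sketch (illustrative): gradients flow to both the matrix and the
# right-hand side via the two VJPs registered above.
#     import autograd.numpy as np
#     from autograd import grad
#     a = np.array([[4., 1.], [1., 3.]])
#     b = np.array([1., 2.])
#     f = lambda a, b: np.sum(np.linalg.solve(a, b))
#     grad(f, 0)(a, b)   # gradient w.r.t. a, shape (2, 2)
#     grad(f, 1)(a, b)   # gradient w.r.t. b, shape (2,)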
def norm_vjp(ans, x, ord=None, axis=None):
    def check_implemented():
        matrix_norm = (x.ndim == 2 and axis is None) or isinstance(axis, tuple)

        if matrix_norm:
            if not (ord is None or ord == 'fro' or ord == 'nuc'):
                raise NotImplementedError('Gradient of matrix norm not '
                                          'implemented for ord={}'.format(ord))
        elif not (ord is None or ord > 1):
            raise NotImplementedError('Gradient of norm not '
                                      'implemented for ord={}'.format(ord))

    if axis is None:
        expand = lambda a: a
    elif isinstance(axis, tuple):
        row_axis, col_axis = axis
        if row_axis > col_axis:
            row_axis = row_axis - 1
        expand = lambda a: anp.expand_dims(anp.expand_dims(a,
                                                           row_axis), col_axis)
    else:
        expand = lambda a: anp.expand_dims(a, axis=axis)

    if ord == 'nuc':
        if axis is None:
            roll = lambda a: a
            unroll = lambda a: a
        else:
            row_axis, col_axis = axis
            if row_axis > col_axis:
                row_axis = row_axis - 1
            # Roll matrix axes to the back
            roll = lambda a: anp.rollaxis(anp.rollaxis(a, col_axis, a.ndim),
                                          row_axis, a.ndim-1)
            # Roll matrix axes to their original position
            unroll = lambda a: anp.rollaxis(anp.rollaxis(a, a.ndim-2, row_axis),
                                            a.ndim-1, col_axis)

    check_implemented()
    def vjp(g):
        if ord in (None, 2, 'fro'):
            return expand(g / ans) * x
        elif ord == 'nuc':
            x_rolled = roll(x)
            u, s, vt = svd(x_rolled, full_matrices=False)
            uvt_rolled = _dot(u, vt)
            # Roll the matrix axes back to their correct positions
            uvt = unroll(uvt_rolled)
            g = expand(g)
            return g * uvt
        else:
            # see https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
            return expand(g / ans**(ord-1)) * x * anp.abs(x)**(ord-2)
    return vjp
defvjp(norm, norm_vjp)
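# Nuclear-norm sketch (illustrative): for a matrix with distinct singular values
# the gradient of norm(A, 'nuc') is u @ vt, matching the svd-based branch above.
#     import autograd.numpy as np
#     from autograd import grad
#     A = np.array([[1., 2.], [3., 4.]])
#     u, s, vt = np.linalg.svd(A, full_matrices=False)
#     g = grad(lambda A: np.linalg.norm(A, ord='nuc'))(A)
#     np.allclose(g, u @ vt)   # True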
def norm_jvp(g, ans, x, ord=None, axis=None):
    def check_implemented():
        matrix_norm = (x.ndim == 2 and axis is None) or isinstance(axis, tuple)

        if matrix_norm:
            if not (ord is None or ord == 'fro' or ord == 'nuc'):
                raise NotImplementedError('Gradient of matrix norm not '
                                          'implemented for ord={}'.format(ord))
        elif not (ord is None or ord > 1):
            raise NotImplementedError('Gradient of norm not '
                                      'implemented for ord={}'.format(ord))

    if axis is None:
        contract = lambda a: anp.sum(a)
    else:
        contract = partial(anp.sum, axis=axis)

    if ord == 'nuc':
        if axis is None:
            roll = lambda a: a
            unroll = lambda a: a
        else:
            row_axis, col_axis = axis
            if row_axis > col_axis:
                row_axis = row_axis - 1
            # Roll matrix axes to the back
            roll = lambda a: anp.rollaxis(anp.rollaxis(a, col_axis, a.ndim),
                                          row_axis, a.ndim-1)
            # Roll matrix axes to their original position
            unroll = lambda a: anp.rollaxis(anp.rollaxis(a, a.ndim-2, row_axis),
                                            a.ndim-1, col_axis)

    check_implemented()
    if ord in (None, 2, 'fro'):
        return contract(g * x) / ans
    elif ord == 'nuc':
        x_rolled = roll(x)
        u, s, vt = svd(x_rolled, full_matrices=False)
        uvt_rolled = _dot(u, vt)
        # Roll the matrix axes back to their correct positions
        uvt = unroll(uvt_rolled)
        return contract(g * uvt)
    else:
        # see https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
        return contract(g * x * anp.abs(x)**(ord-2)) / ans**(ord-1)
defjvp(norm, norm_jvp)
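# The reverse-mode (defvjp) and forward-mode (defjvp) rules above encode the
# same derivative; a minimal sketch for the default vector 2-norm:
#     import autograd.numpy as np
#     from autograd import grad
#     x = np.array([3., 4.])
#     grad(np.linalg.norm)(x)   # x / ||x|| = [0.6, 0.8]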
def grad_eigh(ans, x, UPLO='L'):
    """Gradient for eigenvalues and vectors of a symmetric matrix."""
    N = x.shape[-1]
    w, v = ans              # Eigenvalues, eigenvectors.
    vc = anp.conj(v)
    def vjp(g):
        wg, vg = g          # Gradient w.r.t. eigenvalues, eigenvectors.
        w_repeated = anp.repeat(w[..., anp.newaxis], N, axis=-1)

        # Eigenvalue part
        vjp_temp = _dot(vc * wg[..., anp.newaxis, :], T(v))

        # Add eigenvector part only if non-zero backward signal is present.
        # This can avoid NaN results for degenerate cases if the function depends
        # on the eigenvalues only.
        if anp.any(vg):
            off_diag = anp.ones((N, N)) - anp.eye(N)
            F = off_diag / (T(w_repeated) - w_repeated + anp.eye(N))
            vjp_temp += _dot(_dot(vc, F * _dot(T(v), vg)), T(v))

        # eigh always uses only the lower or the upper part of the matrix;
        # we also have to make sure broadcasting works.
        reps = anp.array(x.shape)
        reps[-2:] = 1
        if UPLO == 'L':
            tri = anp.tile(anp.tril(anp.ones(N), -1), reps)
        elif UPLO == 'U':
            tri = anp.tile(anp.triu(anp.ones(N), 1), reps)

        return anp.real(vjp_temp)*anp.eye(vjp_temp.shape[-1]) + \
            (vjp_temp + anp.conj(T(vjp_temp))) * tri
    return vjp
defvjp(eigh, grad_eigh)
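# Illustrative eigh check (sketch): the sum of the eigenvalues equals trace(A),
# so its gradient with respect to a symmetric A is the identity.
#     import autograd.numpy as np
#     from autograd import grad
#     A = np.array([[2., 1.], [1., 3.]])
#     g = grad(lambda A: np.sum(np.linalg.eigh(A)[0]))(A)
#     np.allclose(g, np.eye(2))   # True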
# https://arxiv.org/pdf/1701.00392.pdf Eq. (4.77)
# Note: the formula from Section 3.1 of https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf is incomplete.
def grad_eig(ans, x):
    """Gradient for the eigendecomposition of a general square (possibly complex) matrix."""
    e, u = ans   # eigenvalues as a 1d array, eigenvectors in columns
    n = e.shape[-1]
    def vjp(g):
        ge, gu = g
        ge = _matrix_diag(ge)
        f = 1/(e[..., anp.newaxis, :] - e[..., :, anp.newaxis] + 1.e-20)
        f -= _diag(f)
        ut = anp.swapaxes(u, -1, -2)
        r1 = f * _dot(ut, gu)
        r2 = -f * (_dot(_dot(ut, anp.conj(u)), anp.real(_dot(ut, gu)) * anp.eye(n)))
        r = _dot(_dot(inv(ut), ge + r1 + r2), ut)
        if not anp.iscomplexobj(x):
            r = anp.real(r)
            # The derivative is in general complex even for a real input with a
            # real output (a complex perturbation of the input is allowed); when
            # only real perturbations are considered, the derivative should be
            # real, so we keep just the real part.
        return r
    return vjp
defvjp(eig, grad_eig)
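# Illustrative eig check (sketch; the example matrix has real eigenvalues): the
# sum of the eigenvalues is trace(A), so the gradient is the identity.
#     import autograd.numpy as np
#     from autograd import grad
#     A = np.array([[1., 2.], [0., 3.]])
#     g = grad(lambda A: np.real(np.sum(np.linalg.eig(A)[0])))(A)
#     np.allclose(g, np.eye(2))   # True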
def grad_cholesky(L, A):
    # Based on Iain Murray's note http://arxiv.org/abs/1602.07527
    # scipy's dtrtrs wrapper, solve_triangular, doesn't broadcast along leading
    # dimensions, so we just call a generic LU solve instead of directly using
    # backsubstitution (also, we factor twice...)
    solve_trans = lambda a, b: solve(T(a), b)
    phi = lambda X: anp.tril(X) / (1. + anp.eye(X.shape[-1]))
    def conjugate_solve(L, X):
        # X -> L^{-T} X L^{-1}
        return solve_trans(L, T(solve_trans(L, T(X))))
    def vjp(g):
        S = conjugate_solve(L, phi(anp.einsum('...ki,...kj->...ij', L, g)))
        return (S + T(S)) / 2.
    return vjp
defvjp(cholesky, grad_cholesky)
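# Cholesky sketch (illustrative): the VJP above returns a gradient with respect
# to the full matrix A that is symmetric by construction, (S + S.T) / 2.
#     import autograd.numpy as np
#     from autograd import grad
#     A = np.array([[4., 2.], [2., 3.]])
#     g = grad(lambda A: np.sum(np.linalg.cholesky(A)))(A)
#     np.allclose(g, g.T)   # True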
# https://j-towns.github.io/papers/svd-derivative.pdf
# https://arxiv.org/abs/1909.02659
def grad_svd(usv_, a, full_matrices=True, compute_uv=True):
    def vjp(g):
        usv = usv_

        if not compute_uv:
            s = usv
            # Need U and V so do the whole svd anyway...
            usv = svd(a, full_matrices=False)
            u = usv[0]
            v = anp.conj(T(usv[2]))
            return _dot(anp.conj(u) * g[..., anp.newaxis, :], T(v))

        elif full_matrices:
            raise NotImplementedError(
                "Gradient of svd not implemented for full_matrices=True")

        else:
            u = usv[0]
            s = usv[1]
            v = anp.conj(T(usv[2]))

            m, n = a.shape[-2:]
            k = anp.min((m, n))
            # broadcastable identity array with shape (1, 1, ..., 1, k, k)
            i = anp.reshape(anp.eye(k), anp.concatenate((anp.ones(a.ndim - 2, dtype=int), (k, k))))
            f = 1 / (s[..., anp.newaxis, :]**2 - s[..., :, anp.newaxis]**2 + i)

            gu = g[0]
            gs = g[1]
            gv = anp.conj(T(g[2]))

            utgu = _dot(T(u), gu)
            vtgv = _dot(T(v), gv)
            t1 = (f * (utgu - anp.conj(T(utgu)))) * s[..., anp.newaxis, :]
            t1 = t1 + i * gs[..., :, anp.newaxis]
            t1 = t1 + s[..., :, anp.newaxis] * (f * (vtgv - anp.conj(T(vtgv))))

            if anp.iscomplexobj(u):
                t1 = t1 + 1j*anp.imag(_diag(utgu)) / s[..., anp.newaxis, :]

            t1 = _dot(_dot(anp.conj(u), t1), T(v))

            if m < n:
                i_minus_vvt = (anp.reshape(anp.eye(n), anp.concatenate((anp.ones(a.ndim - 2, dtype=int), (n, n)))) -
                               _dot(v, anp.conj(T(v))))
                t1 = t1 + anp.conj(_dot(_dot(u / s[..., anp.newaxis, :], T(gv)), i_minus_vvt))
                return t1

            elif m == n:
                return t1

            elif m > n:
                i_minus_uut = (anp.reshape(anp.eye(m), anp.concatenate((anp.ones(a.ndim - 2, dtype=int), (m, m)))) -
                               _dot(u, anp.conj(T(u))))
                t1 = t1 + T(_dot(_dot(v / s[..., anp.newaxis, :], T(gu)), i_minus_uut))
                return t1

    return vjp
defvjp(svd, grad_svd)
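# SVD sketch (illustrative, compute_uv=False branch): for distinct singular
# values, the gradient of the largest singular value is the rank-one matrix
# u_1 v_1^T.
#     import autograd.numpy as np
#     from autograd import grad
#     A = np.array([[1., 0.], [1., 3.]])
#     g = grad(lambda A: np.linalg.svd(A, compute_uv=False)[0])(A)
#     u, s, vt = np.linalg.svd(A, full_matrices=False)
#     np.allclose(g, np.outer(u[:, 0], vt[0]))   # True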