|
from __future__ import absolute_import |
|
import scipy.special |
|
import autograd.numpy as np |
|
from autograd.extend import primitive, defvjp, defjvp |
|
from autograd.numpy.numpy_vjps import unbroadcast_f, repeat_to_match_shape |
|
|
|
|
|
# Beta-function family: wrap the scipy routines as autograd primitives so
# their gradients can be registered below.
beta, betainc, betaln = map(
    primitive, (scipy.special.beta, scipy.special.betainc, scipy.special.betaln))

# d/da B(a, b) = B(a, b) * (psi(a) - psi(a + b)); the b-derivative is symmetric.
defvjp(
    beta,
    lambda ans, a, b: unbroadcast_f(
        a, lambda grad: grad * ans * (psi(a) - psi(a + b))),
    lambda ans, a, b: unbroadcast_f(
        b, lambda grad: grad * ans * (psi(b) - psi(a + b))))

# Only the x-derivative (argument index 2) is registered:
# d/dx I_x(a, b) = x**(a - 1) * (1 - x)**(b - 1) / B(a, b).
defvjp(
    betainc,
    lambda ans, a, b, x: unbroadcast_f(
        x, lambda grad: grad * np.power(x, a - 1) * np.power(1 - x, b - 1) / beta(a, b)),
    argnums=[2])

# d/da log B(a, b) = psi(a) - psi(a + b); the b-derivative is symmetric.
defvjp(
    betaln,
    lambda ans, a, b: unbroadcast_f(
        a, lambda grad: grad * (psi(a) - psi(a + b))),
    lambda ans, a, b: unbroadcast_f(
        b, lambda grad: grad * (psi(b) - psi(a + b))))
|
|
|
|
|
# Gamma-function family, wrapped as autograd primitives.
(polygamma, psi, digamma, gamma, gammaln, gammainc, gammaincc,
 gammasgn, rgamma, multigammaln) = map(primitive, (
     scipy.special.polygamma,
     scipy.special.psi,
     scipy.special.digamma,
     scipy.special.gamma,
     scipy.special.gammaln,
     scipy.special.gammainc,
     scipy.special.gammaincc,
     scipy.special.gammasgn,
     scipy.special.rgamma,
     scipy.special.multigammaln,
 ))

# gammasgn: no VJP formula is supplied for its argument (None).
defvjp(gammasgn, None)

# The order n is an integer index (None); d/dx polygamma(n, x) = polygamma(n + 1, x).
defvjp(polygamma, None,
       lambda ans, n, x: lambda grad: grad * polygamma(n + 1, x))

# psi and digamma share the derivative polygamma(1, x) (trigamma).
defvjp(psi, lambda ans, x: lambda grad: grad * polygamma(1, x))
defvjp(digamma, lambda ans, x: lambda grad: grad * polygamma(1, x))

# d/dx Gamma(x) = Gamma(x) * psi(x);  d/dx log Gamma(x) = psi(x).
defvjp(gamma, lambda ans, x: lambda grad: grad * ans * psi(x))
defvjp(gammaln, lambda ans, x: lambda grad: grad * psi(x))

# rgamma = 1 / Gamma, hence d/dx rgamma(x) = -psi(x) / Gamma(x).
defvjp(rgamma, lambda ans, x: lambda grad: grad * psi(x) / -gamma(x))

# Multivariate log-gamma: d/da log Gamma_d(a) = sum_{i=0}^{d-1} psi(a - i/2).
# The dimension d is an integer and gets no gradient (None).
defvjp(multigammaln,
       lambda ans, a, d: lambda grad: grad * np.sum(
           digamma(np.expand_dims(a, -1) - np.arange(d) / 2.), -1),
       None)
|
|
|
def make_gammainc_vjp_arg1(sign):
    """Return a VJP maker for the ``x`` argument of gammainc / gammaincc.

    The VJP scales the incoming gradient by
    ``sign * exp(-x) * x**(a - 1) / gamma(a)``. This file uses ``sign=1``
    for the regularized lower incomplete gamma (gammainc) and ``sign=-1``
    for the upper one (gammaincc).
    """
    def gammainc_vjp_arg1(ans, a, x):
        # Computed once per forward evaluation and reused by the closure.
        scale = sign * np.exp(-x) * np.power(x, a - 1) / gamma(a)

        def vjp(grad):
            return grad * scale

        return unbroadcast_f(x, vjp)

    return gammainc_vjp_arg1


# Only the x-derivative (argument index 1) is registered for both functions.
defvjp(gammainc, make_gammainc_vjp_arg1(1), argnums=[1])
defvjp(gammaincc, make_gammainc_vjp_arg1(-1), argnums=[1])
|
|
|
|
|
|
|
# Bessel functions of the first (j*) and second (y*) kind.
j0, y0, j1, y1, jn, yn = map(primitive, (
    scipy.special.j0, scipy.special.y0,
    scipy.special.j1, scipy.special.y1,
    scipy.special.jn, scipy.special.yn))

# Derivatives follow the standard recurrences, for C in {J, Y}:
#   C_0'(x) = -C_1(x)
#   C_n'(x) = (C_{n-1}(x) - C_{n+1}(x)) / 2
# The integer order n of jn / yn gets no gradient (None).
defvjp(j0, lambda ans, x: lambda grad: -grad * j1(x))
defvjp(y0, lambda ans, x: lambda grad: -grad * y1(x))
defvjp(j1, lambda ans, x: lambda grad: grad * (j0(x) - jn(2, x)) / 2.0)
defvjp(y1, lambda ans, x: lambda grad: grad * (y0(x) - yn(2, x)) / 2.0)
defvjp(jn, None,
       lambda ans, n, x: lambda grad: grad * (jn(n - 1, x) - jn(n + 1, x)) / 2.0)
defvjp(yn, None,
       lambda ans, n, x: lambda grad: grad * (yn(n - 1, x) - yn(n + 1, x)) / 2.0)
|
|
|
|
|
|
|
# Modified Bessel functions of the first kind, plus the exponentially scaled
# variant ive(n, x) = iv(n, x) * exp(-|x|).
i0 = primitive(scipy.special.i0)
i1 = primitive(scipy.special.i1)
iv = primitive(scipy.special.iv)
ive = primitive(scipy.special.ive)

# I_0'(x) = I_1(x);  I_n'(x) = (I_{n-1}(x) + I_{n+1}(x)) / 2.
# The order n of iv / ive gets no gradient (None).
defvjp(i0, lambda ans, x: lambda g: g * i1(x))
defvjp(i1, lambda ans, x: lambda g: g * (i0(x) + iv(2, x)) / 2.0)
defvjp(iv, None, lambda ans, n, x: lambda g: g * (iv(n - 1, x) + iv(n + 1, x)) / 2.0)

# Differentiating ive(n, x) = iv(n, x) * exp(-|x|) with the same recurrence:
#   d/dx ive(n, x) = (ive(n - 1, x) + ive(n + 1, x)) / 2 - sign(x) * ive(n, x).
# The previous form used the identity I_n' = I_{n+1} + (n / x) I_n, whose n / x
# term yields nan at x == 0 (0/0 for n == 0, 0 * inf for n >= 1) even though
# the derivative there is finite (e.g. 1/2 for n == 1).
defvjp(ive, None,
       lambda ans, n, x: lambda g: g * ((ive(n - 1, x) + ive(n + 1, x)) / 2.0
                                        - np.sign(x) * ans))
|
|
|
|
|
# 1 / sqrt(pi), the scale in the derivative of erf / erfc.
inv_root_pi = 0.56418958354775627928

erf, erfc = map(primitive, (scipy.special.erf, scipy.special.erfc))

# d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2); erfc has the negated derivative.
defvjp(erf, lambda ans, x: lambda grad: 2. * grad * inv_root_pi * np.exp(-x ** 2))
defvjp(erfc, lambda ans, x: lambda grad: -2. * grad * inv_root_pi * np.exp(-x ** 2))
|
|
|
|
|
|
|
# sqrt(pi), the scale in the derivatives of the inverse error functions.
root_pi = 1.7724538509055159

erfinv, erfcinv = map(primitive, (scipy.special.erfinv, scipy.special.erfcinv))

# Inverse-function rule: d/dx erfinv(x) = sqrt(pi) / 2 * exp(erfinv(x)**2);
# erfcinv has the analogous, negated expression.
defvjp(erfinv,
       lambda ans, x: lambda grad: grad * root_pi / 2 * np.exp(erfinv(x) ** 2))
defvjp(erfcinv,
       lambda ans, x: lambda grad: -grad * root_pi / 2 * np.exp(erfcinv(x) ** 2))
|
|
|
|
|
# Logistic sigmoid (expit) and its inverse (logit).
logit, expit = map(primitive, (scipy.special.logit, scipy.special.expit))

# d/dx logit(x) = 1 / (x * (1 - x));  d/dx expit(x) = expit(x) * (1 - expit(x)).
defvjp(logit, lambda ans, x: lambda grad: grad / (x * (1 - x)))
defvjp(expit, lambda ans, x: lambda grad: grad * ans * (1 - ans))
|
|
|
|
|
logsumexp = primitive(scipy.special.logsumexp)


def make_grad_logsumexp(ans, x, axis=None, b=1.0, keepdims=False):
    """Build the reverse-mode (VJP) rule for ``logsumexp`` w.r.t. ``x``.

    The gradient is ``g * b * exp(x - ans)``, where ``g`` and ``ans`` are
    first brought back to ``x``'s shape.

    Args:
        ans: value returned by ``logsumexp``.
        x: input array passed to ``logsumexp``.
        axis: reduction axis passed to ``logsumexp``.
        b: optional scaling weights; ``None`` (scipy's default) means 1.
        keepdims: ``keepdims`` flag passed to ``logsumexp``.

    Returns:
        A function mapping the output cotangent ``g`` to the cotangent of ``x``.
    """
    if b is None:
        # scipy.special.logsumexp defaults to b=None ("no weights"); an explicit
        # b=None would otherwise hit the multiplication below and raise.
        b = 1.0
    shape, dtype = np.shape(x), np.result_type(x)

    def vjp(g):
        # Bring the reduced cotangent and result back to x's shape
        # (via autograd's repeat_to_match_shape helper).
        g_repeated, _ = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
        ans_repeated, _ = repeat_to_match_shape(ans, shape, dtype, axis, keepdims)
        return g_repeated * b * np.exp(x - ans_repeated)

    return vjp


defvjp(logsumexp, make_grad_logsumexp)
|
|
|
def fwd_grad_logsumexp(g, ans, x, axis=None, b=1.0, keepdims=False):
    """Forward-mode (JVP) rule for ``logsumexp`` with respect to ``x``.

    The output tangent is ``sum(g * b * exp(x - ans))`` over ``axis``.

    Args:
        g: tangent of ``x``.
        ans: value returned by ``logsumexp(x, axis=axis, b=b, keepdims=keepdims)``.
        x: input array.
        axis: reduction axis -- an int (including numpy integers), a tuple of
            ints (negative values allowed), or None for all axes.
        b: optional scaling weights; ``None`` (scipy's default) means 1.
        keepdims: whether ``ans`` retained the reduced axes as size-1 dims.

    Returns:
        The output tangent, reduced over ``axis`` exactly like ``ans``.
    """
    if b is None:
        # scipy's default b=None means "no weights"; None * array would raise.
        b = 1.0
    if not keepdims and axis is not None:
        # Re-insert the reduced axes into ``ans`` so it broadcasts against ``x``.
        axes = axis if isinstance(axis, tuple) else (axis,)
        # Negative axes must be resolved against the rank of the *unreduced*
        # array (rank of ans plus the number of reduced axes). Otherwise
        # expand_dims on the smaller ``ans`` puts them in the wrong place.
        # Inserting in increasing order keeps each target position valid.
        full_ndim = np.ndim(ans) + len(axes)
        positions = sorted(int(ax) % full_ndim for ax in axes)
        for pos in positions:
            ans = np.expand_dims(ans, pos)
    return np.sum(g * b * np.exp(x - ans), axis=axis, keepdims=keepdims)
|
|
|
# Register the forward-mode rule for the first argument (x) of logsumexp.
defjvp(logsumexp, fwd_grad_logsumexp, argnums=[0])
|
|