from __future__ import absolute_import | |
import scipy.stats | |
import autograd.numpy as np | |
from autograd.scipy.special import digamma | |
from autograd.extend import primitive, defvjp | |
# Expose scipy's Dirichlet functions as autograd primitives. Autograd treats
# each call as one opaque operation rather than tracing scipy's internals, so
# any gradient for these functions must be registered explicitly via defvjp.
rvs, pdf, logpdf = (primitive(scipy.stats.dirichlet.rvs),
                    primitive(scipy.stats.dirichlet.pdf),
                    primitive(scipy.stats.dirichlet.logpdf))
# Vector-Jacobian products for the Dirichlet density.
#
# For a single point x (shape (K,)) with concentration vector alpha (K,):
#     log p(x | alpha) = lnGamma(sum(alpha)) - sum(lnGamma(alpha))
#                        + sum((alpha - 1) * log(x))
# so
#     d log p / d x_i     = (alpha_i - 1) / x_i
#     d log p / d alpha_i = digamma(sum(alpha)) - digamma(alpha_i) + log(x_i)
# The x-gradient treats every x_i as a free coordinate (the simplex
# constraint sum(x) == 1 is not projected out).
#
# scipy.stats.dirichlet also evaluates a batch of points in one call, with
# the components running along axis 0 of x: x may be (K, N, ...), and the
# result holds one value per point (possibly with size-1 dims squeezed away).
# The helpers below line alpha and the incoming cotangent up with that layout
# so batched calls get correctly shaped gradients; for a single (K,) point
# they reduce exactly to the formulas above.


def _align_alpha(alpha, x):
    """Reshape alpha from (K,) to (K, 1, ..., 1) so it broadcasts against x."""
    return np.reshape(alpha, np.shape(alpha) + (1,) * (np.ndim(x) - 1))


def _check_components(x, alpha):
    """Raise ValueError unless x has one entry per component of alpha."""
    # scipy also accepts x with K-1 entries (the last one implied by the
    # simplex constraint). The formulas here need all K entries; without this
    # check the reduced form either fails with an opaque broadcasting error
    # or, when K == 2, broadcasts silently to a wrong-shaped gradient.
    if np.shape(x)[0] != np.shape(alpha)[0]:
        raise ValueError(
            "Dirichlet gradients require x to have the same number of "
            "entries along axis 0 as alpha; got x.shape = {} and "
            "alpha.shape = {}.".format(np.shape(x), np.shape(alpha)))


def _per_point(v, x):
    """Reshape a per-point value (cotangent or density) to x.shape[1:]."""
    # The density drops the component axis and may squeeze size-1 dims, but
    # it always has prod(x.shape[1:]) elements, so this reshape restores the
    # batch layout. For a single point it is a reshape to a 0-d array.
    return np.reshape(v, np.shape(x)[1:])


def _sum_over_points(v):
    """Sum a (K, ...) array over every axis but the first, giving (K,)."""
    if np.ndim(v) > 1:
        return np.sum(v, axis=tuple(range(1, np.ndim(v))))
    return v


def _dlogpdf_dx(x, alpha):
    """Elementwise d log p / d x, with the same shape as x."""
    _check_components(x, alpha)
    return (_align_alpha(alpha, x) - 1) / x


def _dlogpdf_dalpha(x, alpha):
    """Per-point d log p / d alpha, with the same shape as x (unsummed)."""
    _check_components(x, alpha)
    return (digamma(np.sum(alpha)) - digamma(_align_alpha(alpha, x))
            + np.log(x))


defvjp(logpdf,
       lambda ans, x, alpha: lambda g:
           _per_point(g, x) * _dlogpdf_dx(x, alpha),
       lambda ans, x, alpha: lambda g:
           _sum_over_points(_per_point(g, x) * _dlogpdf_dalpha(x, alpha)))

# Since d p = p * d log p, the pdf VJPs are the log-pdf ones scaled by the
# density value (ans) at each point.
defvjp(pdf,
       lambda ans, x, alpha: lambda g:
           _per_point(g * ans, x) * _dlogpdf_dx(x, alpha),
       lambda ans, x, alpha: lambda g:
           _sum_over_points(
               _per_point(g * ans, x) * _dlogpdf_dalpha(x, alpha)))