from __future__ import absolute_import import scipy.stats import autograd.numpy as np from autograd.scipy.special import digamma from autograd.extend import primitive, defvjp rvs = primitive(scipy.stats.dirichlet.rvs) pdf = primitive(scipy.stats.dirichlet.pdf) logpdf = primitive(scipy.stats.dirichlet.logpdf) defvjp(logpdf,lambda ans, x, alpha: lambda g: g * (alpha - 1) / x, lambda ans, x, alpha: lambda g: g * (digamma(np.sum(alpha)) - digamma(alpha) + np.log(x))) # Same as log pdf, but multiplied by the pdf (ans). defvjp(pdf,lambda ans, x, alpha: lambda g: g * ans * (alpha - 1) / x, lambda ans, x, alpha: lambda g: g * ans * (digamma(np.sum(alpha)) - digamma(alpha) + np.log(x)))