from __future__ import absolute_import, division import autograd.numpy as np import scipy.stats from autograd.extend import primitive, defvjp from autograd.numpy.numpy_vjps import unbroadcast_f from autograd.scipy.special import gamma cdf = primitive(scipy.stats.chi2.cdf) logpdf = primitive(scipy.stats.chi2.logpdf) pdf = primitive(scipy.stats.chi2.pdf) def grad_chi2_logpdf(x, df): return np.where(df % 1 == 0, (df - x - 2) / (2 * x), 0) defvjp(cdf, lambda ans, x, df: unbroadcast_f(x, lambda g: g * np.power(2., -df/2) * np.exp(-x/2) * np.power(x, df/2 - 1) / gamma(df/2)), argnums=[0]) defvjp(logpdf, lambda ans, x, df: unbroadcast_f(x, lambda g: g * grad_chi2_logpdf(x, df)), argnums=[0]) defvjp(pdf, lambda ans, x, df: unbroadcast_f(x, lambda g: g * ans * grad_chi2_logpdf(x, df)), argnums=[0])