File size: 1,332 Bytes
ab4488b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from __future__ import absolute_import

import autograd.numpy as np
import scipy.stats
from autograd.extend import primitive, defvjp
from autograd.numpy.numpy_vjps import unbroadcast_f
from autograd.scipy.special import beta, psi

cdf = primitive(scipy.stats.beta.cdf)
logpdf = primitive(scipy.stats.beta.logpdf)
pdf = primitive(scipy.stats.beta.pdf)

def grad_beta_logpdf_arg0(x, a, b):
    return (1 + a * (x-1) + x * (b-2)) / (x * (x-1))

def grad_beta_logpdf_arg1(x, a, b):
    return np.log(x) - psi(a) + psi(a + b)

def grad_beta_logpdf_arg2(x, a, b):
    return np.log1p(-x) - psi(b) + psi(a + b)

defvjp(cdf, lambda ans, x, a, b: unbroadcast_f(x, lambda g: g * np.power(x, a-1) * np.power(1-x, b-1) / beta(a, b)), argnums=[0])
defvjp(logpdf,
       lambda ans, x, a, b: unbroadcast_f(x, lambda g: g * grad_beta_logpdf_arg0(x, a, b)),
       lambda ans, x, a, b: unbroadcast_f(a, lambda g: g * grad_beta_logpdf_arg1(x, a, b)),
       lambda ans, x, a, b: unbroadcast_f(b, lambda g: g * grad_beta_logpdf_arg2(x, a, b)))
defvjp(pdf,
       lambda ans, x, a, b: unbroadcast_f(x, lambda g: g * ans * grad_beta_logpdf_arg0(x, a, b)),
       lambda ans, x, a, b: unbroadcast_f(a, lambda g: g * ans * grad_beta_logpdf_arg1(x, a, b)),
       lambda ans, x, a, b: unbroadcast_f(b, lambda g: g * ans * grad_beta_logpdf_arg2(x, a, b)))