File size: 1,482 Bytes
ab4488b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import operator
import sys
def subvals(x, ivs):
    """Return a tuple copy of ``x`` with several positions overwritten.

    ``ivs`` is an iterable of ``(index, value)`` pairs, applied in order, so a
    later pair wins when two pairs address the same position. ``x`` itself is
    never modified; an out-of-range index raises ``IndexError``.
    """
    items = list(x)
    for position, replacement in ivs:
        items[position] = replacement
    return tuple(items)
def subval(x, i, v):
    """Return a tuple copy of ``x`` whose element at index ``i`` is ``v``.

    ``x`` may be any iterable; it is not modified. Negative indices count
    from the end, and an out-of-range index raises ``IndexError``.
    """
    items = list(x)
    items[i] = v
    return tuple(items)
if sys.version_info < (3,):
    def func(f):
        """Return ``f.__func__``: the plain function wrapped by a Python 2 method object."""
        return f.__func__
else:
    def func(f):
        """Return ``f`` unchanged; on Python 3 no method wrapper needs stripping."""
        return f
def toposort(end_node, parents=operator.attrgetter('parents')):
    """Yield the nodes reachable from ``end_node`` in topological order.

    ``end_node`` comes out first, and every other node is yielded only after
    all of its children (the reachable nodes that list it as a parent) have
    been yielded. ``parents`` maps a node to an iterable of its parent nodes;
    by default it reads the node's ``parents`` attribute. Nodes are used as
    dict keys, so they must be hashable. This is a generator, so no work
    happens until iteration starts.
    """
    # Pass 1: walk the graph once and record, for every node, how many
    # parent-links point at it (end_node starts at 1 as its own "reference").
    pending = {}
    to_visit = [end_node]
    while to_visit:
        current = to_visit.pop()
        first_visit = current not in pending
        pending[current] = pending.get(current, 0) + 1
        if first_visit:
            # Expand each node only once, on its first visit.
            to_visit.extend(parents(current))

    # Pass 2: release a parent only when the last child referencing it has
    # been emitted (its remaining count is exactly 1).
    ready = [end_node]
    while ready:
        current = ready.pop()
        yield current
        for parent in parents(current):
            remaining = pending[parent]
            if remaining == 1:
                ready.append(parent)
            else:
                pending[parent] = remaining - 1
# -------------------- deprecation warnings -----------------------
import warnings
# Text passed to warnings.warn when quick_grad_check is called.
deprecation_msg = (
    "\n"
    "The quick_grad_check function is deprecated. See the update guide:\n"
    "https://github.com/HIPS/autograd/blob/master/docs/updateguide.md"
)
def quick_grad_check(fun, arg0, extra_args=(), kwargs=None, verbose=True,
                     eps=1e-4, rtol=1e-4, atol=1e-6, rs=None):
    """Deprecated gradient check; delegates to ``check_grads``.

    Emits a warning carrying ``deprecation_msg`` and then calls
    ``autograd.test_util.check_grads`` with ``modes=['rev']`` and
    ``order=1`` on a one-argument wrapper of ``fun``. Only ``arg0`` is
    treated as an input; ``extra_args`` and ``kwargs`` stay fixed.

    Args:
        fun: function called as ``fun(arg0, *extra_args, **kwargs)``.
        arg0: value at which the gradient check is run.
        extra_args: extra positional arguments passed through unchanged.
        kwargs: extra keyword arguments passed through unchanged; ``None``
            (the default) means no extra keyword arguments.
        verbose, eps, rtol, atol, rs: kept only so old call sites keep
            working; they are ignored by this shim.

    Returns:
        None.
    """
    # stacklevel=2 attributes the warning to the caller's line (the code
    # that needs migrating) instead of to this shim.
    warnings.warn(deprecation_msg, stacklevel=2)
    # Imported lazily, presumably to avoid loading test utilities (or an
    # import cycle) at module import time -- NOTE(review): confirm.
    from autograd.test_util import check_grads
    # A None sentinel replaces the shared mutable ``{}`` default.
    if kwargs is None:
        kwargs = {}

    def fun_(x):
        # Close over the extra arguments so check_grads sees a unary function.
        return fun(x, *extra_args, **kwargs)

    check_grads(fun_, modes=['rev'], order=1)(arg0)
|