import operator
import sys
import warnings


def subvals(x, ivs):
    """Return a tuple copy of ``x`` with several positions replaced.

    Args:
        x: Any iterable; it is copied into a list and never mutated.
        ivs: Iterable of ``(index, value)`` pairs. Pairs are applied in
            order, so a later pair wins if an index is repeated.

    Returns:
        A new tuple holding the substituted values.
    """
    x_ = list(x)
    for i, v in ivs:
        x_[i] = v
    return tuple(x_)


def subval(x, i, v):
    """Return a tuple copy of ``x`` with position ``i`` replaced by ``v``."""
    x_ = list(x)
    x_[i] = v
    return tuple(x_)


# Python 2/3 compatibility shim: on Python 3 the object is returned as-is,
# on Python 2 the underlying function is read from its ``__func__`` attribute.
if sys.version_info >= (3,):
    def func(f):
        """Return ``f`` unchanged."""
        return f
else:
    def func(f):
        """Return ``f.__func__``."""
        return f.__func__


def toposort(end_node, parents=operator.attrgetter('parents')):
    """Yield ``end_node`` and its ancestors, children before parents.

    A node is yielded only after every node that lists it as a parent (among
    those reachable from ``end_node``) has already been yielded.

    Args:
        end_node: Node at which the traversal starts. Nodes are used as
            dictionary keys, so they must be hashable.
        parents: Callable mapping a node to an iterable of its parent nodes.
            Defaults to reading the node's ``parents`` attribute.

    Yields:
        Nodes in topological order, starting with ``end_node``.
    """
    # Pass 1: count, for every reachable node, the number of child edges
    # pointing at it. A node's parents are expanded only on its first visit.
    child_counts = {}
    stack = [end_node]
    while stack:
        node = stack.pop()
        if node in child_counts:
            child_counts[node] += 1
        else:
            child_counts[node] = 1
            stack.extend(parents(node))

    # Pass 2: emit a node once all of its children have been emitted. A count
    # of 1 means the edge being followed is the last outstanding one.
    childless_nodes = [end_node]
    while childless_nodes:
        node = childless_nodes.pop()
        yield node
        for parent in parents(node):
            if child_counts[parent] == 1:
                childless_nodes.append(parent)
            else:
                child_counts[parent] -= 1


# -------------------- deprecation warnings -----------------------

deprecation_msg = """
The quick_grad_check function is deprecated. See the update guide:
https://github.com/HIPS/autograd/blob/master/docs/updateguide.md"""


def quick_grad_check(fun, arg0, extra_args=(), kwargs=None, verbose=True,
                     eps=1e-4, rtol=1e-4, atol=1e-6, rs=None):
    """Deprecated: warn, then check the first-order reverse-mode gradient.

    Emits ``deprecation_msg`` through ``warnings.warn`` and then calls
    ``autograd.test_util.check_grads`` with ``modes=['rev']`` and ``order=1``
    on ``fun`` viewed as a function of its first argument.

    Args:
        fun: Function to check; called as ``fun(arg0, *extra_args, **kwargs)``.
        arg0: Value of the first argument, at which the check is performed.
        extra_args: Additional positional arguments forwarded to ``fun``.
        kwargs: Keyword arguments forwarded to ``fun``. ``None`` (the default)
            means no keyword arguments.
        verbose, eps, rtol, atol, rs: Accepted for backward compatibility
            only; this function does not use them.
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    if kwargs is None:
        kwargs = {}
    warnings.warn(deprecation_msg)
    # Imported lazily, at call time rather than at module import time.
    from autograd.test_util import check_grads

    def fun_(arg0):
        return fun(arg0, *extra_args, **kwargs)

    check_grads(fun_, modes=['rev'], order=1)(arg0)