import warnings
from contextlib import contextmanager
from collections import defaultdict

from .util import subvals, toposort
from .wrap_util import wraps


def trace(start_node, fun, x):
    """Run ``fun`` on a boxed copy of ``x`` and return its value and node.

    A new trace level is opened, ``x`` is boxed at that level with
    ``start_node`` as its node, and ``fun`` is called on the box.

    Returns:
        ``(value, node)`` taken from the result when the result is a box that
        belongs to the trace level opened here. Otherwise a warning is issued
        and ``(result, None)`` is returned, where ``result`` is whatever
        ``fun`` returned.
    """
    with trace_stack.new_trace() as t:
        start_box = new_box(x, t, start_node)
        end_box = fun(start_box)
        if isbox(end_box) and end_box._trace == start_box._trace:
            return end_box._value, end_box._node
        else:
            warnings.warn("Output seems independent of input.")
            return end_box, None


class Node(object):
    """Base class for nodes attached to boxes.

    The base ``__init__`` and ``initialize_root`` always fail their
    assertion; concrete node types are expected to override both.
    """

    __slots__ = []

    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        assert False

    def initialize_root(self, *args, **kwargs):
        assert False

    @classmethod
    def new_root(cls, *args, **kwargs):
        """Create an instance without calling ``__init__`` and initialize it
        through ``initialize_root(*args, **kwargs)``."""
        root = cls.__new__(cls)
        root.initialize_root(*args, **kwargs)
        return root


def primitive(f_raw):
    """
    Wraps a function so that its gradient can be specified and its invocation
    can be recorded. For examples, see the docs."""
    @wraps(f_raw)
    def f_wrapped(*args, **kwargs):
        boxed_args, trace_id, node_constructor = find_top_boxed_args(args)
        if boxed_args:
            # Replace only the boxes of the highest trace level by their
            # underlying values; boxes of lower levels stay in place and are
            # handled by the nested f_wrapped call below.
            argvals = subvals(args, [(argnum, box._value) for argnum, box in boxed_args])
            if f_wrapped in notrace_primitives[node_constructor]:
                # Registered as not traced for this node type: return the
                # result without recording a node.
                return f_wrapped(*argvals, **kwargs)
            parents = tuple(box._node for _, box in boxed_args)
            argnums = tuple(argnum for argnum, _ in boxed_args)
            ans = f_wrapped(*argvals, **kwargs)
            node = node_constructor(ans, f_wrapped, argvals, kwargs, argnums, parents)
            return new_box(ans, trace_id, node)
        else:
            # No boxed arguments at all: plain call of the raw function.
            return f_raw(*args, **kwargs)
    f_wrapped.fun = f_raw
    f_wrapped._is_autograd_primitive = True
    return f_wrapped


# Maps a node type to the set of primitives that are not recorded for it.
notrace_primitives = defaultdict(set)


def register_notrace(trace_type, primitive_fun):
    """Mark ``primitive_fun`` as not recorded for nodes of ``trace_type``."""
    notrace_primitives[trace_type].add(primitive_fun)


def notrace_primitive(f_raw):
    """Wrap ``f_raw`` so that every positional argument is unboxed with
    ``getval`` before the call. Keyword arguments are passed through as is."""
    @wraps(f_raw)
    def f_wrapped(*args, **kwargs):
        argvals = [getval(arg) for arg in args]
        return f_raw(*argvals, **kwargs)
    f_wrapped._is_primitive = True
    return f_wrapped


def find_top_boxed_args(args):
    """Find the boxed arguments that belong to the highest trace level.

    Returns:
        ``(top_boxes, top_trace, top_node_type)`` where ``top_boxes`` is a
        list of ``(argnum, box)`` pairs, ``top_trace`` is their trace level
        and ``top_node_type`` is the type of the node of the first box found
        at that level. With no boxed argument the result is
        ``([], -1, None)``.
    """
    top_trace = -1
    top_boxes = []
    top_node_type = None
    for argnum, arg in enumerate(args):
        if isbox(arg):
            trace_id = arg._trace
            if trace_id > top_trace:
                # A higher level was found: discard what was collected so far.
                top_boxes = [(argnum, arg)]
                top_trace = trace_id
                top_node_type = type(arg._node)
            elif trace_id == top_trace:
                top_boxes.append((argnum, arg))
    return top_boxes, top_trace, top_node_type


class TraceStack(object):
    """Counter of the currently open trace levels; ``top`` is -1 when no
    trace is open."""

    def __init__(self):
        self.top = -1

    @contextmanager
    def new_trace(self):
        """Open a new trace level for the duration of the ``with`` block and
        yield its number."""
        self.top += 1
        try:
            yield self.top
        finally:
            # Restore the level even when the body of the with block raises;
            # otherwise a failed trace would leave ``top`` incremented.
            self.top -= 1


trace_stack = TraceStack()


class Box(object):
    """Wrapper that carries a value together with its trace level and node."""

    # Maps value types (and the box classes themselves) to box classes.
    type_mappings = {}
    # Set of all registered box classes; used by isbox.
    types = set()

    __slots__ = ['_value', '_trace', '_node']

    def __init__(self, value, trace, node):
        self._value = value
        self._node = node
        self._trace = trace

    def __bool__(self):
        return bool(self._value)

    # Python 2 name of __bool__.
    __nonzero__ = __bool__

    def __str__(self):
        return "Autograd {0} with value {1}".format(
            type(self).__name__, str(self._value))

    @classmethod
    def register(cls, value_type):
        """Register ``cls`` as the box class for values of ``value_type``.

        ``cls`` is also mapped to itself, so that a box of this class can be
        boxed again.
        """
        Box.types.add(cls)
        Box.type_mappings[value_type] = cls
        Box.type_mappings[cls] = cls


box_type_mappings = Box.type_mappings


def new_box(value, trace, node):
    """Box ``value`` using the box class registered for its exact type.

    Raises:
        TypeError: if no box class is registered for ``type(value)``.
    """
    # Only the lookup is guarded, so that a KeyError raised inside a box
    # constructor is not misreported as an unsupported type.
    try:
        box_cls = box_type_mappings[type(value)]
    except KeyError:
        raise TypeError("Can't differentiate w.r.t. type {}".format(type(value)))
    return box_cls(value, trace, node)


box_types = Box.types


def isbox(x):
    """Return True if the exact type of ``x`` is a registered box class."""
    return type(x) in box_types  # almost 3X faster than isinstance(x, Box)


def getval(x):
    """Unwrap ``x`` repeatedly until the result is no longer a box."""
    return getval(x._value) if isbox(x) else x