# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
"""Autograd for NDArray."""
from __future__ import absolute_import
from __future__ import division
from array import array
import ctypes
import functools
from ..base import _LIB, check_call, string_types
from ..base import mx_uint, NDArrayHandle, c_array, c_array_buf, c_handle_array
# pylint: disable= unused-import
from ..ndarray import NDArray, zeros_like, _GRAD_REQ_MAP


def set_is_training(is_train):
    """Set status to training/not training. When training, a graph will be constructed
    for gradient computation. Operators will also run with ctx.is_train=True. For example,
    Dropout will drop inputs randomly when is_train=True and simply pass them through
    when is_train=False.

    Parameters
    ----------
    is_train: bool

    Returns
    -------
    bool
        The previous state before this set.
    """
prev = ctypes.c_int()
check_call(_LIB.MXAutogradSetIsTraining(
ctypes.c_int(is_train), ctypes.byref(prev)))
check_call(_LIB.MXAutogradSetIsRecording(
ctypes.c_int(is_train), ctypes.byref(prev)))
return bool(prev.value)
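
# Illustrative usage sketch (not part of the original module): toggling the global
# training state manually around a forward pass. `model` and `x` are placeholders
# for an existing network and input NDArray.
#
#     prev = set_is_training(True)       # enable training mode (e.g. Dropout is active)
#     try:
#         y = model(x)
#     finally:
#         set_is_training(prev)          # restore whatever state was set before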


class TrainingStateScope(object):
    """Scope for managing training state.

    Example::

        with TrainingStateScope(True):
            y = model(x)
            compute_gradient([y])
    """
    def __init__(self, enter_state):
        self._enter_state = enter_state
        self._prev = None

    def __enter__(self):
        self._prev = set_is_training(self._enter_state)

    def __exit__(self, ptype, value, trace):
        if self._prev != self._enter_state:
            set_is_training(self._prev)


def train_section():
    """Returns a training scope context to be used in a 'with' statement
    that captures training code.

    Example::

        with autograd.train_section():
            y = model(x)
            compute_gradient([y])
        metric.update(...)
        optim.step(...)
    """
return TrainingStateScope(True)


def test_section():
    """Returns a testing scope context to be used in a 'with' statement
    that captures testing code.

    Example::

        with autograd.train_section():
            y = model(x)
            compute_gradient([y])
        with autograd.test_section():
            # testing, IO, gradient updates...
    """
return TrainingStateScope(False)


def mark_variables(variables, gradients, grad_reqs='write'):
    """Mark NDArrays as variables for which autograd should compute gradients.

    Parameters
    ----------
    variables: list of NDArray
    gradients: list of NDArray
        Buffers that will hold the gradients of the corresponding variables.
    grad_reqs: str or list of str, default 'write'
        Gradient request type for each variable: 'write', 'add', or 'null'.
    """
if isinstance(grad_reqs, string_types):
grad_reqs = [_GRAD_REQ_MAP[grad_reqs]]*len(variables)
else:
grad_reqs = [_GRAD_REQ_MAP[i] for i in grad_reqs]
check_call(_LIB.MXAutogradMarkVariables(
len(variables),
c_handle_array(variables),
c_array_buf(mx_uint, array('I', grad_reqs)),
c_handle_array(gradients)))
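
# Illustrative usage sketch (not part of the original module), assuming
# ``import mxnet as mx``: attach a gradient buffer, record a computation,
# then run backward.
#
#     x = mx.nd.array([[1, 2], [3, 4]])
#     dx = mx.nd.zeros_like(x)            # gradient of x will be written here
#     mark_variables([x], [dx])           # grad_reqs defaults to 'write'
#     with train_section():
#         y = x * 2
#     backward([y])
#     print(dx.asnumpy())                 # d(2*x)/dx -> every entry is 2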


def backward(outputs, out_grads=None, retain_graph=False):
    """Compute the gradients of outputs w.r.t. the marked variables.

    Parameters
    ----------
    outputs: list of NDArray
    out_grads: list of NDArray or None
        Head gradients for the outputs.
    retain_graph: bool, optional
        Whether to retain the recorded graph so backward can be called again.
    """
assert isinstance(outputs, (list, tuple)), \
"outputs must be a list or tuple of NDArrays"
if out_grads is None:
check_call(_LIB.MXAutogradBackward(
len(outputs),
c_handle_array(outputs),
ctypes.c_void_p(0),
ctypes.c_int(retain_graph)))
return
ograd_handles = []
for arr in out_grads:
if arr is not None:
ograd_handles.append(arr.handle)
else:
ograd_handles.append(NDArrayHandle(0))
assert len(ograd_handles) == len(outputs), \
"outputs and out_grads must have the same length"
check_call(_LIB.MXAutogradBackward(
len(outputs),
c_handle_array(outputs),
c_array(NDArrayHandle, ograd_handles),
ctypes.c_int(retain_graph)))
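
# Illustrative sketch (not part of the original module), assuming ``import mxnet as mx``:
# `backward` with explicit head gradients.
#
#     x = mx.nd.array([1.0, 2.0, 3.0])
#     dx = mx.nd.zeros_like(x)
#     mark_variables([x], [dx])
#     with train_section():
#         y = x * x
#     backward([y], out_grads=[mx.nd.ones_like(y) * 0.5])    # scale the head gradient
#     # dx now holds 0.5 * 2 * x -> [1. 2. 3.]


# Note: ``compute_gradient`` is referenced by the docstring examples above and by
# ``grad_and_loss`` below, but its definition is missing from this listing. A minimal
# definition consistent with that usage (it simply forwards to ``backward``) would be:
def compute_gradient(outputs):
    """Compute the gradients of outputs w.r.t. the marked variables."""
    backward(outputs)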


def grad_and_loss(func, argnum=None):
    """Return a function that computes both the gradient of the arguments and the loss value.

    Parameters
    ----------
    func: a python function
        The forward (loss) function.
    argnum: an int or a list of int
        The index of the argument(s) to calculate the gradient for.

    Returns
    -------
    grad_and_loss_func: a python function
        A function that computes both the gradient of the arguments and the loss value.
    """
@functools.wraps(func)
def wrapped(*args):
"""Wrapped function."""
variables = args
if argnum is not None:
argnum_ = argnum if isinstance(argnum, list) else [argnum]
variables = [args[i] for i in argnum_]
for x in variables:
            assert isinstance(x, NDArray), "type of autograd input should be NDArray."
grads = [zeros_like(x) for x in variables]
mark_variables(variables, grads)
with train_section():
outputs = func(*args)
compute_gradient([outputs] if isinstance(outputs, NDArray) else outputs)
return grads, outputs
return wrapped
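
# Illustrative usage sketch (not part of the original module), assuming
# ``import mxnet as mx``: wrapping a simple loss function.
#
#     def loss(x):
#         return (x * x).sum()
#
#     grad_and_loss_func = grad_and_loss(loss)
#     x = mx.nd.array([1.0, 2.0, 3.0])
#     grads, value = grad_and_loss_func(x)
#     print(grads[0].asnumpy())           # d(sum(x^2))/dx = 2*x -> [2. 4. 6.]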


def grad(func, argnum=None):
    """Return a function that computes the gradient of the arguments.

    Parameters
    ----------
    func: a python function
        The forward (loss) function.
    argnum: an int or a list of int
        The index of the argument(s) to calculate the gradient for.

    Returns
    -------
    grad_func: a python function
        A function that computes the gradient of the arguments.

    Examples
    --------
    >>> # autograd supports dynamic graphs, which can change
    >>> # on every run
    >>> import random
    >>> from mxnet import nd
    >>> def func(x):
    ...     r = random.randint(0, 1)
    ...     if r % 2:
    ...         return x**2
    ...     else:
    ...         return x/3
    >>> # use `grad(func)` to get the gradient function
    >>> for x in range(10):
    ...     grad_func = grad(func)
    ...     inputs = nd.array([[1, 2, 3], [4, 5, 6]])
    ...     grad_vals = grad_func(inputs)
    """
grad_with_loss_func = grad_and_loss(func, argnum)
@functools.wraps(grad_with_loss_func)
def wrapped(*args):
return grad_with_loss_func(*args)[0]
return wrapped