# Source code for mxnet.contrib.autograd
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
"""Autograd for NDArray."""
from array import array
import ctypes
import functools
from ..base import _LIB, check_call, string_types
from ..base import mx_uint, NDArrayHandle, c_array, c_array_buf, c_handle_array
# pylint: disable= unused-import
from ..ndarray import NDArray, zeros_like, _GRAD_REQ_MAP
def set_is_training(is_train):
    """Set status to training/not training.

    When training, graph will be constructed for gradient computation.
    Operators will also run with ctx.is_train=True. For example, Dropout will
    drop inputs randomly when is_train=True while simply passing through if
    is_train=False.

    Both the training flag and the recording flag of the backend are set to
    ``is_train`` by this call.

    Parameters
    ----------
    is_train: bool
        The new state to switch to.

    Returns
    -------
    bool
        Previous state before this set. The same output slot is passed to
        both backend calls, so the value returned is whatever the second
        call (``MXAutogradSetIsRecording``) stored in it.
    """
    new_state = ctypes.c_int(is_train)
    prev = ctypes.c_int()
    check_call(_LIB.MXAutogradSetIsTraining(new_state, ctypes.byref(prev)))
    # `prev` is reused on purpose: the recording call overwrites the value
    # written by the training call above.
    check_call(_LIB.MXAutogradSetIsRecording(new_state, ctypes.byref(prev)))
    return bool(prev.value)
class TrainingStateScope(object):
    """Scope for managing training state.

    On entry the state is switched to ``enter_state`` via
    ``set_is_training``; on exit the state that was returned by that call is
    restored, unless it already equals ``enter_state``.

    Example::

        with TrainingStateScope(True):
            y = model(x)
            backward([y])
    """

    def __init__(self, enter_state):
        # State to switch to when the scope is entered.
        self._enter_state = enter_state
        # State reported by set_is_training on entry; None until entered.
        self._prev = None

    def __enter__(self):
        self._prev = set_is_training(self._enter_state)

    def __exit__(self, ptype, value, trace):
        # Only touch the backend when entering actually changed the state.
        if self._prev != self._enter_state:
            set_is_training(self._prev)
def train_section():
    """Return a training scope context to be used in a 'with' statement.

    The returned object is a ``TrainingStateScope(True)``; code inside the
    ``with`` block runs with the training state switched on.

    Example::

        with autograd.train_section():
            y = model(x)
            backward([y])
        metric.update(...)
        optim.step(...)
    """
    return TrainingStateScope(True)
def test_section():
    """Return a testing scope context to be used in a 'with' statement.

    The returned object is a ``TrainingStateScope(False)``; code inside the
    ``with`` block runs with the training state switched off.

    Example::

        with autograd.train_section():
            y = model(x)
            backward([y])
            with autograd.test_section():
                # testing, IO, gradient updates...
    """
    return TrainingStateScope(False)
def mark_variables(variables, gradients, grad_reqs='write'):
    """Mark NDArrays as variables to compute gradient for autograd.

    Parameters
    ----------
    variables: list of NDArray
        Arrays to be marked as variables.
    gradients: list of NDArray
        Gradient buffers, one per entry of ``variables``.
    grad_reqs: string or list of string
        Keys of ``_GRAD_REQ_MAP``. A single string is applied to every
        variable; a list must supply one entry per variable.
    """
    if isinstance(grad_reqs, string_types):
        grad_reqs = [_GRAD_REQ_MAP[grad_reqs]]*len(variables)
    else:
        grad_reqs = [_GRAD_REQ_MAP[i] for i in grad_reqs]
    # The backend is told there are len(variables) entries in each array, so
    # all three sequences must really have that many elements.
    assert len(gradients) == len(variables), \
        "variables and gradients must have the same length"
    assert len(grad_reqs) == len(variables), \
        "variables and grad_reqs must have the same length"
    check_call(_LIB.MXAutogradMarkVariables(
        len(variables),
        c_handle_array(variables),
        c_array_buf(mx_uint, array('I', grad_reqs)),
        c_handle_array(gradients)))
def backward(outputs, out_grads=None, retain_graph=False):
    """Compute the gradients of outputs w.r.t variables.

    Parameters
    ----------
    outputs: list of NDArray
        Must be a list or tuple.
    out_grads: list of NDArray or None
        Head gradients, one per output. Individual entries may be None, in
        which case a null handle is passed for that output. When the whole
        argument is None a null pointer is passed instead of an array.
    retain_graph: bool
        Forwarded to the backend as an int flag.
    """
    assert isinstance(outputs, (list, tuple)), \
        "outputs must be a list or tuple of NDArrays"

    if out_grads is None:
        ograd_array = ctypes.c_void_p(0)
    else:
        ograd_handles = [
            NDArrayHandle(0) if arr is None else arr.handle
            for arr in out_grads
        ]
        assert len(ograd_handles) == len(outputs), \
            "outputs and out_grads must have the same length"
        ograd_array = c_array(NDArrayHandle, ograd_handles)

    check_call(_LIB.MXAutogradBackward(
        len(outputs),
        c_handle_array(outputs),
        ograd_array,
        ctypes.c_int(retain_graph)))
def grad_and_loss(func, argnum=None):
    """Return function that computes both gradient of arguments and loss value.

    Parameters
    ----------
    func: a python function
        The forward (loss) function.
    argnum: an int or a list of int
        The index of argument to calculate gradient for. A tuple of int is
        accepted as well. When None, gradients are computed for every
        positional argument.

    Returns
    -------
    grad_and_loss_func: a python function
        A function that would compute both the gradient of arguments and loss value.
        It returns a ``(grads, outputs)`` pair, where ``grads`` holds one
        gradient array per selected argument and ``outputs`` is whatever
        ``func`` returned.
    """
    @functools.wraps(func)
    def wrapped(*args):
        """Wrapped function."""
        variables = args
        if argnum is not None:
            argnum_ = argnum if isinstance(argnum, (list, tuple)) else [argnum]
            variables = [args[i] for i in argnum_]
        for x in variables:
            assert isinstance(x, NDArray), "type of autograd input should NDArray."
        # Fresh gradient buffers, attached to the selected inputs.
        grads = [zeros_like(x) for x in variables]
        mark_variables(variables, grads)
        with train_section():
            outputs = func(*args)
        # backward() requires a list/tuple, so wrap a single NDArray result.
        backward([outputs] if isinstance(outputs, NDArray) else outputs)
        return grads, outputs
    return wrapped
def grad(func, argnum=None):
    """Return function that computes gradient of arguments.

    Parameters
    ----------
    func: a python function
        The forward (loss) function.
    argnum: an int or a list of int
        The index of argument to calculate gradient for.

    Returns
    -------
    grad_func: a python function
        A function that would compute the gradient of arguments.

    Examples
    --------
    >>> # autograd supports dynamic graph which is changed
    >>> # every instance
    >>> import random
    >>> from mxnet import nd
    >>> def func(x):
    ...     r = random.randint(0, 1)
    ...     if r % 2:
    ...         return x**2
    ...     else:
    ...         return x/3
    >>> # use `grad(func)` to get the gradient function
    >>> for x in range(10):
    ...     grad_func = grad(func)
    ...     inputs = nd.array([[1, 2, 3], [4, 5, 6]])
    ...     grad_vals = grad_func(inputs)
    """
    grad_with_loss_func = grad_and_loss(func, argnum)

    @functools.wraps(grad_with_loss_func)
    def wrapped(*args):
        # grad_and_loss returns (grads, outputs); keep only the gradients.
        return grad_with_loss_func(*args)[0]
    return wrapped
# Did this page help you? Yes / No
# Thanks for your feedback!
