mxnet.autograd

Autograd for NDArray.

Functions

backward(heads[, head_grads, retain_graph, …])

Compute the gradients of heads w.r.t. previously marked variables.

get_symbol(x)

Retrieve recorded computation history as Symbol.

grad(heads, variables[, head_grads, …])

Compute the gradients of heads w.r.t. variables.

is_recording()

Get status on recording/not recording.

is_training()

Get status on training/predicting.

mark_variables(variables, gradients[, grad_reqs])

Mark NDArrays as variables to compute gradient for autograd.

pause([train_mode])

Returns a scope context to be used in a ‘with’ statement for code that does not need gradients to be calculated.

predict_mode()

Returns a scope context to be used in a ‘with’ statement in which the forward pass is set to inference mode, without changing the recording state.

record([train_mode])

Returns an autograd recording scope context to be used in a ‘with’ statement; it captures code for which gradients need to be calculated.

set_recording(is_recording)

Set status to recording/not recording.

set_training(train_mode)

Set status to training/predicting.

train_mode()

Returns a scope context to be used in a ‘with’ statement in which the forward pass is set to training mode, without changing the recording state.

Classes

Function()

Customize differentiation in autograd.

class Function

Bases: object

Customize differentiation in autograd.

If you don’t want to use the gradients computed by the default chain rule, you can subclass Function to customize differentiation. Define the computation in the forward method and the customized gradient in the backward method; during gradient computation, autograd then uses your backward method instead of the default chain rule. In forward and backward you can also convert to NumPy arrays and back for some operations.

For example, a stable sigmoid function can be defined as:

import mxnet as mx

class sigmoid(mx.autograd.Function):
    def forward(self, x):
        y = 1 / (1 + mx.nd.exp(-x))
        self.save_for_backward(y)
        return y

    def backward(self, dy):
        # backward takes as many inputs as forward's return value,
        # and returns as many NDArrays as forward's arguments.
        y, = self.saved_tensors
        return dy * y * (1-y)

Then, the function can be used in the following way:

func = sigmoid()
x = mx.nd.random.uniform(shape=(10,))
x.attach_grad()

with mx.autograd.record():
    m = func(x)
    m.backward()
dx = x.grad.asnumpy()

Methods

backward(*output_grads)

Backward computation.

forward(*inputs)

Forward computation.

backward(*output_grads)

Backward computation.

Takes as many inputs as forward’s outputs, and returns as many NDArrays as forward’s inputs.

forward(*inputs)

Forward computation.

backward(heads, head_grads=None, retain_graph=False, train_mode=True)

Compute the gradients of heads w.r.t. previously marked variables.

Parameters
  • heads (NDArray or list of NDArray) – Output NDArray(s)

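  • head_grads (NDArray or list of NDArray or None) – Gradients with respect to heads.

  • retain_graph (bool, default False) – Whether to keep the computation graph to differentiate again, instead of clearing history and releasing memory.

  • train_mode (bool, optional) – Whether to do backward for training or predicting.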

get_symbol(x)

Retrieve recorded computation history as Symbol.

Parameters

x (NDArray) – Array representing the head of computation graph.

Returns

The retrieved Symbol.

Return type

Symbol

grad(heads, variables, head_grads=None, retain_graph=None, create_graph=False, train_mode=True)

Compute the gradients of heads w.r.t. variables. Gradients are returned as new NDArrays instead of being stored in variable.grad. Supports recording the gradient graph for computing higher-order gradients.

Note

Currently only a very limited set of operators support higher order gradients.

Parameters
  • heads (NDArray or list of NDArray) – Output NDArray(s)

  • variables (NDArray or list of NDArray) – Input variables to compute gradients for.

  • head_grads (NDArray or list of NDArray or None) – Gradients with respect to heads.

  • retain_graph (bool) – Whether to keep the computation graph to differentiate again, instead of clearing history and releasing memory. Defaults to the same value as create_graph.

  • create_graph (bool) – Whether to record the gradient graph for computing higher-order gradients.

  • train_mode (bool, optional) – Whether to do backward for training or prediction.

Returns

Gradients with respect to variables.

Return type

NDArray or list of NDArray

Examples

>>> import mxnet as mx
>>> x = mx.nd.ones((1,))
>>> x.attach_grad()
>>> with mx.autograd.record():
...     z = mx.nd.elemwise_add(mx.nd.exp(x), x)
>>> dx = mx.autograd.grad(z, [x], create_graph=True)
>>> print(dx)
[
[ 3.71828175]
<NDArray 1 @cpu(0)>]

is_recording()

Get status on recording/not recording.

Returns

Current state of recording.

Return type

bool

is_training()

Get status on training/predicting.

Returns

Current state of training/predicting.

Return type

bool

mark_variables(variables, gradients, grad_reqs='write')

Mark NDArrays as variables to compute gradient for autograd.

This is equivalent to calling .attach_grad() on each variable, except that you supply the NDArrays that will hold the gradients yourself.

Parameters
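  • variables (NDArray or list of NDArray) – The arrays to mark for gradient computation.

  • gradients (NDArray or list of NDArray) – The arrays that receive the gradients, one per variable.

  • grad_reqs (str or list of str) – How gradients are written: 'write' (overwrite), 'add' (accumulate) or 'null' (no gradient). A single string applies to all variables.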

pause(train_mode=False)

Returns a scope context to be used in a ‘with’ statement for code that does not need gradients to be calculated.

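Example:

from mxnet import autograd

# model and x are placeholders for your own network and input.
with autograd.record():
    y = model(x)
    autograd.backward([y])
    with autograd.pause():
        # testing, IO, gradient updates...
        pass
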
Parameters

train_mode (bool, default False) – Whether to run the forward pass in training or predicting mode.

predict_mode()

Returns a scope context to be used in a ‘with’ statement in which the forward pass is set to inference mode, without changing the recording state.

Example:

from mxnet import autograd

# model, x and sampling are placeholders for your own code.
with autograd.record():
    y = model(x)
    with autograd.predict_mode():
        y = sampling(y)
    autograd.backward([y])

record(train_mode=True)

Returns an autograd recording scope context to be used in a ‘with’ statement; it captures code for which gradients need to be calculated.

Note

When the forward pass runs with train_mode=False, the corresponding backward pass should also use train_mode=False; otherwise the gradient is undefined.

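Example:

from mxnet import autograd

# model, x, metric and optim are placeholders for your own code.
with autograd.record():
    y = model(x)
    autograd.backward([y])
metric.update(...)
optim.step(...)
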
Parameters

train_mode (bool, default True) – Whether the forward pass runs in training or predicting mode. This controls the behavior of layers such as Dropout and BatchNorm.

set_recording(is_recording)

Set status to recording/not recording. While recording is on, a computation graph is constructed for gradient computation.

Parameters

is_recording (bool) – True to turn recording on, False to turn it off.

Returns

The previous recording state, before this call.

Return type

bool

set_training(train_mode)

Set status to training/predicting. This sets ctx.is_train in the operator running context. For example, Dropout randomly drops inputs when train_mode=True and simply passes them through when train_mode=False.

Parameters

train_mode (bool) – True for training, False for predicting.

Returns

The previous training state, before this call.

Return type

bool

train_mode()

Returns a scope context to be used in a ‘with’ statement in which the forward pass is set to training mode, without changing the recording state.

Example:

from mxnet import autograd

# model, x and dropout are placeholders for your own code.
y = model(x)
with autograd.train_mode():
    y = dropout(y)