mxnet.optimizer

Optimizer API of MXNet.

Classes

AdaDelta([rho, epsilon])

The AdaDelta optimizer.

AdaGrad([eps])

AdaGrad optimizer.

Adam([learning_rate, beta1, beta2, epsilon, …])

The Adam optimizer.

Adamax([learning_rate, beta1, beta2])

The AdaMax optimizer.

DCASGD([momentum, lamda])

The DCASGD optimizer.

FTML([beta1, beta2, epsilon])

The FTML optimizer.

Ftrl([lamda1, learning_rate, beta])

The Ftrl optimizer.

LARS([momentum, lazy_update, eta, eps, …])

The LARS optimizer from ‘Large Batch Training of Convolutional Networks’ (https://arxiv.org/abs/1708.03888).

LBSGD([momentum, multi_precision, …])

The Large Batch SGD optimizer with momentum and weight decay.

NAG([momentum])

Nesterov accelerated gradient.

Nadam([learning_rate, beta1, beta2, …])

The Nesterov Adam optimizer.

Optimizer([rescale_grad, param_idx2name, …])

The base class inherited by all optimizers.

RMSProp([learning_rate, gamma1, gamma2, …])

The RMSProp optimizer.

SGD([momentum, lazy_update])

The SGD optimizer with momentum and weight decay.

SGLD(**kwargs)

Stochastic Gradient Riemannian Langevin Dynamics.

Signum([learning_rate, momentum, wd_lh])

The Signum optimizer that takes the sign of gradient or momentum.

LAMB([learning_rate, beta1, beta2, epsilon, …])

LAMB Optimizer.

Test(**kwargs)

The Test optimizer.

Updater(optimizer)

Updater for kvstore.

ccSGD(*args, **kwargs)

[DEPRECATED] Same as SGD. Left here for backward compatibility.

Functions

NDabs([data, out, name])

Returns element-wise absolute value of the input.

create(name, **kwargs)

Instantiates an optimizer with a given name and kwargs.

get_updater(optimizer)

Returns a closure of the updater needed for kvstore.

register(klass)

Registers a new optimizer.

class mxnet.optimizer.AdaDelta(rho=0.9, epsilon=1e-05, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The AdaDelta optimizer.

This class implements AdaDelta, an optimizer described in ADADELTA: An adaptive learning rate method, available at https://arxiv.org/abs/1212.5701.

This optimizer updates each weight by:

grad = clip(grad * rescale_grad, clip_gradient)
acc_grad = rho * acc_grad + (1. - rho) * grad * grad
delta = sqrt(acc_delta + epsilon) / sqrt(acc_grad + epsilon) * grad
acc_delta = rho * acc_delta + (1. - rho) * delta * delta
weight -= (delta + wd * weight)
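To make the rule concrete, here is a minimal element-wise sketch on a single Python float. It assumes plain scalars stand in for NDArrays and that weight decay is applied once, in the final weight step as in the last line above. The name adadelta_step is hypothetical and not part of mxnet.optimizer.

```python
import math


def adadelta_step(weight, grad, acc_grad, acc_delta, rho=0.9, epsilon=1e-5,
                  rescale_grad=1.0, clip_gradient=None, wd=0.0):
    """Hypothetical scalar AdaDelta step; returns (weight, acc_grad, acc_delta)."""
    g = grad * rescale_grad
    if clip_gradient is not None:
        # Project the gradient onto [-clip_gradient, clip_gradient].
        g = max(-clip_gradient, min(clip_gradient, g))
    # Decaying average of squared gradients.
    acc_grad = rho * acc_grad + (1.0 - rho) * g * g
    # Step scaled by RMS(previous updates) / RMS(gradients); no learning rate.
    delta = math.sqrt(acc_delta + epsilon) / math.sqrt(acc_grad + epsilon) * g
    # Decaying average of squared updates.
    acc_delta = rho * acc_delta + (1.0 - rho) * delta * delta
    # Weight decay is applied here only, once per step.
    weight -= delta + wd * weight
    return weight, acc_grad, acc_delta


# Both accumulators are assumed to start at zero.
w, acc_g, acc_d = 1.0, 0.0, 0.0
for _ in range(3):
    w, acc_g, acc_d = adadelta_step(w, 0.5, acc_g, acc_d)
```

Because acc_delta starts at zero, the first steps are tiny (on the order of sqrt(epsilon)). The effective step size then grows as acc_delta accumulates.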

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • rho (float) – Decay rate for both squared gradients and delta.

  • epsilon (float) – Small value to avoid division by 0.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

class mxnet.optimizer.AdaGrad(eps=1e-07, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

AdaGrad optimizer.

This class implements the AdaGrad optimizer described in Adaptive Subgradient Methods for Online Learning and Stochastic Optimization, available at http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.

This optimizer updates each weight by:

grad = clip(grad * rescale_grad, clip_gradient)
history += square(grad)
div = grad / sqrt(history + float_stable_eps)
weight += (div + weight * wd) * -lr
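A scalar sketch of the rule above follows. Here eps plays the role of float_stable_eps, which is an assumption consistent with the eps parameter. adagrad_step is a hypothetical name, and the history is assumed to start at zero. Running the sketch with a constant gradient shows AdaGrad's defining behavior: each step is smaller than the previous one.

```python
import math


def adagrad_step(weight, grad, history, lr=0.01, eps=1e-7,
                 rescale_grad=1.0, clip_gradient=None, wd=0.0):
    """Hypothetical scalar AdaGrad step; returns (weight, history)."""
    g = grad * rescale_grad
    if clip_gradient is not None:
        g = max(-clip_gradient, min(clip_gradient, g))
    # The history only grows, so the effective step size only shrinks.
    history += g * g
    div = g / math.sqrt(history + eps)
    weight += (div + weight * wd) * -lr
    return weight, history


w, h = 1.0, 0.0  # history assumed to start at zero
steps = []
for _ in range(3):
    w_new, h = adagrad_step(w, 2.0, h, lr=0.1)
    steps.append(w - w_new)  # size of this step
    w = w_new
```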

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters

eps (float, optional) – Small value added to the history inside the square root (float_stable_eps in the update rule above) to avoid division by 0.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

class mxnet.optimizer.Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, lazy_update=True, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The Adam optimizer.

This class implements the optimizer described in Adam: A Method for Stochastic Optimization, available at http://arxiv.org/abs/1412.6980.

If the storage type of grad is row_sparse and lazy_update is True, lazy updates at step t are applied by:

for row in grad.indices:
    rescaled_grad[row] = clip(grad[row] * rescale_grad + wd * weight[row], clip_gradient)
    m[row] = beta1 * m[row] + (1 - beta1) * rescaled_grad[row]
    v[row] = beta2 * v[row] + (1 - beta2) * (rescaled_grad[row]**2)
    lr = learning_rate * sqrt(1 - beta2**t) / (1 - beta1**t)
    w[row] = w[row] - lr * m[row] / (sqrt(v[row]) + epsilon)

The lazy update only updates the first and second moment estimates (m and v) for the rows whose indices appear in the row_sparse gradient of the current batch, rather than for all rows. Compared with the standard update, it can greatly improve model training throughput for some applications. However, its semantics differ slightly from the standard update, and it may lead to different empirical results.

Otherwise, standard updates at step t are applied by:

rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
m = beta1 * m + (1 - beta1) * rescaled_grad
v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
lr = learning_rate * sqrt(1 - beta2**t) / (1 - beta1**t)
w = w - lr * m / (sqrt(v) + epsilon)

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

For details of the update algorithm, see adam_update.

Parameters
  • beta1 (float, optional) – Exponential decay rate for the first moment estimates.

  • beta2 (float, optional) – Exponential decay rate for the second moment estimates.

  • epsilon (float, optional) – Small value to avoid division by 0.

  • lazy_update (bool, optional) – Default is True. If True, lazy updates are applied if the storage types of weight and grad are both row_sparse.
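The sketch below covers both the standard and the lazy updates on scalars. adam_step and lazy_adam are hypothetical names, and row_sparse storage is modeled with plain dicts that map a row index to a value (an assumption for illustration). The step size uses the bias correction sqrt(1 - beta2**t) / (1 - beta1**t) from the Adam paper, so the very first step moves each touched weight by about learning_rate.

```python
import math


def adam_step(w, g, m, v, t, learning_rate=0.001, beta1=0.9, beta2=0.999,
              epsilon=1e-8, rescale_grad=1.0, clip_gradient=None, wd=0.0):
    """Hypothetical scalar Adam step at step t >= 1; returns (w, m, v)."""
    g = g * rescale_grad + wd * w
    if clip_gradient is not None:
        g = max(-clip_gradient, min(clip_gradient, g))
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    # Bias correction of both moments folded into the step size.
    lr = learning_rate * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    return w - lr * m / (math.sqrt(v) + epsilon), m, v


def lazy_adam(weights, grad_rows, m, v, t, **kwargs):
    """Hypothetical lazy update over dict-of-rows tables.

    weights, m and v map every row index to a float; grad_rows holds only
    the rows present in the sparse gradient. All other rows, including
    their m and v, are left untouched.
    """
    for row, g in grad_rows.items():
        weights[row], m[row], v[row] = adam_step(
            weights[row], g, m[row], v[row], t, **kwargs)
```

Rows that do not appear in the gradient keep their old m and v. This is the semantic difference from the standard update, which decays m and v for every row on every step.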

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

class mxnet.optimizer.Adamax(learning_rate=0.002, beta1=0.9, beta2=0.999, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The AdaMax optimizer.

It is a variant of Adam based on the infinity norm, described in Section 7 of http://arxiv.org/abs/1412.6980.

The optimizer updates the weight by:

grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
m = beta1 * m + (1 - beta1) * grad
u = maximum(beta2 * u, abs(grad))
weight -= lr / (1 - beta1**t) * m / u
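A scalar sketch of the AdaMax rule above follows. adamax_step is a hypothetical name, and m and u are assumed to start at zero. The rule has no epsilon, so u must be nonzero when it is divided by. In this sketch, a first gradient of exactly zero would raise ZeroDivisionError.

```python
def adamax_step(w, g, m, u, t, lr=0.002, beta1=0.9, beta2=0.999,
                rescale_grad=1.0, clip_gradient=None, wd=0.0):
    """Hypothetical scalar AdaMax step at step t >= 1; returns (w, m, u)."""
    g = g * rescale_grad + wd * w
    if clip_gradient is not None:
        g = max(-clip_gradient, min(clip_gradient, g))
    m = beta1 * m + (1 - beta1) * g
    # Exponentially weighted infinity norm replaces Adam's second moment.
    u = max(beta2 * u, abs(g))
    # Only m needs bias correction; u is not biased towards zero.
    w -= lr / (1 - beta1 ** t) * m / u
    return w, m, u
```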

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • beta1 (float, optional) – Exponential decay rate for the first moment estimates.

  • beta2 (float, optional) – Exponential decay rate for the second moment estimates.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

class mxnet.optimizer.DCASGD(momentum=0.0, lamda=0.04, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The DCASGD optimizer.

This class implements the optimizer described in Asynchronous Stochastic Gradient Descent with Delay Compensation for Distributed Deep Learning, available at https://arxiv.org/abs/1609.08326.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • momentum (float, optional) – The momentum value.

  • lamda (float, optional) – Scaling factor for the delay compensation (DC) term.
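This page does not write out the DCASGD update. As an illustration only, the sketch below assumes the first-order delay compensation from the paper. In that scheme, the stale gradient g is corrected by lamda * g * g * (w - w_prev), where g * g approximates the Hessian diagonal. dcasgd_step and the w_prev backup argument are hypothetical names, not MXNet API.

```python
def dcasgd_step(w, g, mom, w_prev, lr=0.01, momentum=0.0, lamda=0.04,
                rescale_grad=1.0, clip_gradient=None, wd=0.0):
    """Hypothetical scalar delay-compensated SGD step.

    w_prev is the weight saved at the previous step (assumed to start equal
    to w). Returns the new (w, mom, w_prev).
    """
    g = g * rescale_grad
    if clip_gradient is not None:
        g = max(-clip_gradient, min(clip_gradient, g))
    # Assumed delay-compensation term: g * g stands in for the Hessian diagonal.
    compensated = g + wd * w + lamda * g * g * (w - w_prev)
    mom = momentum * mom - lr * compensated
    w_prev = w
    w = w + mom
    return w, mom, w_prev
```

When w equals w_prev, the compensation vanishes and the step reduces to ordinary SGD with momentum.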

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

class mxnet.optimizer.FTML(beta1=0.6, beta2=0.999, epsilon=1e-08, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The FTML optimizer.

This class implements the optimizer described in FTML - Follow the Moving Leader in Deep Learning, available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.

Denote time step by t. The optimizer updates the weight by:

rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
v = beta2 * v + (1 - beta2) * square(rescaled_grad)
d_t = (1 - power(beta1, t)) / lr * (square_root(v / (1 - power(beta2, t))) + epsilon)
z = beta1 * z + (1 - beta1) * rescaled_grad - (d_t - beta1 * d_(t-1)) * weight
weight = - z / d_t
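A scalar sketch of the FTML rule above follows. ftml_step is a hypothetical name. The caller carries d_(t-1) between steps, starting from 0, and v and z also start at 0. lr defaults to 0.01, the base Optimizer default when no learning rate is given.

```python
import math


def ftml_step(w, g, v, z, d_prev, t, lr=0.01, beta1=0.6, beta2=0.999,
              epsilon=1e-8, rescale_grad=1.0, clip_gradient=None, wd=0.0):
    """Hypothetical scalar FTML step at step t >= 1; returns (w, v, z, d_t)."""
    g = g * rescale_grad + wd * w
    if clip_gradient is not None:
        g = max(-clip_gradient, min(clip_gradient, g))
    v = beta2 * v + (1 - beta2) * g * g
    d_t = (1 - beta1 ** t) / lr * (math.sqrt(v / (1 - beta2 ** t)) + epsilon)
    # sigma_t = d_t - beta1 * d_(t-1), as in the z update above.
    sigma = d_t - beta1 * d_prev
    z = beta1 * z + (1 - beta1) * g - sigma * w
    # The new weight is computed from z directly, not as an increment.
    w = -z / d_t
    return w, v, z, d_t
```

On the first step this reduces to w - lr * g / (sqrt(v_hat) + epsilon), an RMSProp-like step.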

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

For details of the update algorithm, see ftml_update.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • beta1 (float, optional) – 0 < beta1 < 1. Generally close to 0.5.

  • beta2 (float, optional) – 0 < beta2 < 1. Generally close to 1.

  • epsilon (float, optional) – Small value to avoid division by 0.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

class mxnet.optimizer.Ftrl(lamda1=0.01, learning_rate=0.1, beta=1, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The Ftrl optimizer.

Referenced from Ad Click Prediction: a View from the Trenches, available at http://dl.acm.org/citation.cfm?id=2488200.

The per-coordinate learning rate eta is:

\[\eta_{t,i} = \frac{\text{learning\_rate}}{\beta + \sqrt{\sum_{s=1}^{t} g_{s,i}^2}}\]

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

The optimizer updates the weight by:

rescaled_grad = clip(grad * rescale_grad, clip_gradient)
z += rescaled_grad - (sqrt(n + rescaled_grad**2) - sqrt(n)) * weight / learning_rate
n += rescaled_grad**2
w = (sign(z) * lamda1 - z) / ((beta + sqrt(n)) / learning_rate + wd) * (abs(z) > lamda1)
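A scalar sketch of the dense update rule above follows. ftrl_step and sign are hypothetical helpers, and z and n are assumed to start at zero. The final branch shows where FTRL's sparsity comes from: whenever abs(z) <= lamda1, the weight is set exactly to zero.

```python
import math


def sign(x):
    """Return -1, 0 or 1 according to the sign of x."""
    return (x > 0) - (x < 0)


def ftrl_step(w, g, z, n, lamda1=0.01, learning_rate=0.1, beta=1.0,
              rescale_grad=1.0, clip_gradient=None, wd=0.0):
    """Hypothetical scalar FTRL step; returns (w, z, n)."""
    g = g * rescale_grad
    if clip_gradient is not None:
        g = max(-clip_gradient, min(clip_gradient, g))
    z += g - (math.sqrt(n + g * g) - math.sqrt(n)) * w / learning_rate
    n += g * g
    if abs(z) > lamda1:
        w = (sign(z) * lamda1 - z) / ((beta + math.sqrt(n)) / learning_rate + wd)
    else:
        # L1 regularization pins small accumulated gradients to exactly zero.
        w = 0.0
    return w, z, n
```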

If the storage types of weight, state and grad are all row_sparse, sparse updates are applied by:

for row in grad.indices:
    rescaled_grad[row] = clip(grad[row] * rescale_grad, clip_gradient)
    z[row] += rescaled_grad[row] - (sqrt(n[row] + rescaled_grad[row]**2) - sqrt(n[row])) * weight[row] / learning_rate
    n[row] += rescaled_grad[row]**2
    w[row] = (sign(z[row]) * lamda1 - z[row]) / ((beta + sqrt(n[row])) / learning_rate + wd) * (abs(z[row]) > lamda1)

The sparse update only updates z and n for the rows whose indices appear in the row_sparse gradient of the current batch, rather than for all rows. Compared with the original update, it can provide large improvements in model training throughput for some applications. However, its semantics differ slightly from the original update, and it may lead to different empirical results.

For details of the update algorithm, see ftrl_update.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • lamda1 (float, optional) – L1 regularization coefficient.

  • learning_rate (float, optional) – The initial learning rate.

  • beta (float, optional) – Per-coordinate learning rate correlation parameter.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

class mxnet.optimizer.LARS(momentum=0.0, lazy_update=True, eta=0.001, eps=0, momentum_correction=True, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The LARS optimizer from ‘Large Batch Training of Convolutional Networks’ (https://arxiv.org/abs/1708.03888).

Behaves mostly like SGD with momentum and weight decay, but adaptively scales the learning rate for each layer (except bias and batch norm parameters):

w_norm = L2norm(weights)
g_norm = L2norm(gradients)
if w_norm > 0 and g_norm > 0:
    lr_layer = lr * lr_mult * eta * w_norm / (g_norm + weight_decay * w_norm + eps)
else:
    lr_layer = lr * lr_mult

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

create_state_multi_precision(index, weight)

Creates auxiliary state for a given weight, including an FP32 high-precision copy if the original weight is FP16.

set_wd_mult(args_wd_mult)

Sets an individual weight decay multiplier for each parameter.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

update_multi_precision(index, weight, grad, …)

Updates the given parameter using the corresponding gradient and state.

Parameters
  • momentum (float, optional) – The momentum value.

  • lazy_update (bool, optional) – Default is True. If True, lazy updates are applied if the storage types of weight and grad are both row_sparse.

  • eta (float, optional) – LARS coefficient used to scale the learning rate. Default set to 0.001.

  • eps (float, optional) – Optional epsilon in case of very small gradients. Default set to 0.

  • momentum_correction (bool, optional) – If True, scale momentum with respect to global learning rate changes (e.g. from an lr_scheduler), as indicated in ‘Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour’ (https://arxiv.org/pdf/1706.02677.pdf). Default set to True.
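The layer-wise learning rate rule can be sketched in plain Python on flat lists of floats. lars_layer_lr and l2norm are hypothetical helpers. eta, eps, lr_mult and weight_decay mirror the symbols in the rule above. Deciding which layers to exclude (bias, batch norm) is left to the caller.

```python
import math


def l2norm(values):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(x * x for x in values))


def lars_layer_lr(weights, grads, lr, lr_mult=1.0, eta=0.001,
                  weight_decay=0.0, eps=0.0):
    """Hypothetical helper: layer-wise learning rate for one layer."""
    w_norm = l2norm(weights)
    g_norm = l2norm(grads)
    if w_norm > 0 and g_norm > 0:
        # Trust ratio: large weights relative to gradients allow larger steps.
        return lr * lr_mult * eta * w_norm / (g_norm + weight_decay * w_norm + eps)
    # Degenerate layer (all-zero weights or gradients): use the plain rate.
    return lr * lr_mult
```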

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

create_state_multi_precision(index, weight)[source]

Creates auxiliary state for a given weight, including an FP32 high-precision copy if the original weight is FP16.

This method is provided to perform automatic mixed precision training for optimizers that do not support it themselves.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

set_wd_mult(args_wd_mult)[source]

Sets an individual weight decay multiplier for each parameter.

By default, if param_idx2name was provided in the constructor, the weight decay multiplier is set to 0 for all parameters whose names don’t end with _weight or _gamma.

Note

The default weight decay multiplier for a Variable can be set with its wd_mult argument in the constructor.

Parameters

args_wd_mult (dict of string/int to float) –

For each of its key-value entries, the weight decay multiplier for the parameter specified in the key will be set to the given value.

You can specify the parameter with either its name or its index. If you use the name, you should pass sym in the constructor, and the name you specified in the key of args_wd_mult should match the name of the parameter in sym. If you use the index, it should correspond to the index of the parameter used in the update method.

Specifying a parameter by its index is only supported for backward compatibility, and we recommend using the name instead.
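The default rule in the first paragraph, plus the override behavior of args_wd_mult, can be sketched on a list of parameter names. Both helpers are hypothetical. Parameters that are not zeroed by default are assumed to keep a multiplier of 1.0, and only name keys are modeled (index keys are not).

```python
def default_wd_mults(param_names):
    """Hypothetical helper: default wd multiplier for each parameter name."""
    return {name: 1.0 if name.endswith(("_weight", "_gamma")) else 0.0
            for name in param_names}


def resolve_wd_mults(param_names, args_wd_mult):
    """Hypothetical helper: explicit args_wd_mult entries override defaults."""
    mults = default_wd_mults(param_names)
    mults.update(args_wd_mult)
    return mults


names = ["fc1_weight", "fc1_bias", "bn1_gamma", "bn1_beta"]
print(resolve_wd_mults(names, {"fc1_bias": 0.5}))
# {'fc1_weight': 1.0, 'fc1_bias': 0.5, 'bn1_gamma': 1.0, 'bn1_beta': 0.0}
```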

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

update_multi_precision(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state. Mixed precision version.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

class mxnet.optimizer.LBSGD(momentum=0.0, multi_precision=False, warmup_strategy='linear', warmup_epochs=5, batch_scale=1, updates_per_epoch=32, begin_epoch=0, num_epochs=60, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The Large Batch SGD optimizer with momentum and weight decay.

The optimizer updates the weight by:

state = momentum * state + lr * (rescale_grad * clip(grad, clip_gradient) + wd * weight)
weight = weight - state
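A scalar sketch of the momentum update written above follows. The weight-decay term is scaled by lr, as in the standard SGD-with-momentum step (an assumption about operator precedence). sgd_momentum_step is a hypothetical name. The LARS scaling and warmup strategies that LBSGD adds on top are not modeled.

```python
def sgd_momentum_step(w, g, state, lr=0.01, momentum=0.0,
                      rescale_grad=1.0, clip_gradient=None, wd=0.0):
    """Hypothetical scalar SGD-with-momentum step; returns (w, state)."""
    if clip_gradient is not None:
        g = max(-clip_gradient, min(clip_gradient, g))
    # Accumulate a velocity from the rescaled gradient plus weight decay.
    state = momentum * state + lr * (rescale_grad * g + wd * w)
    w = w - state
    return w, state


w, s = 1.0, 0.0  # momentum state assumed to start at zero
for _ in range(2):
    w, s = sgd_momentum_step(w, 1.0, s, lr=0.1, momentum=0.9)
```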

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

For details of the update algorithm, see sgd_update and sgd_mom_update. In addition to the SGD updates, the LBSGD optimizer uses the LARS (Layer-wise Adaptive Rate Scaling) algorithm to keep a separate learning rate for each layer of the network, which improves stability with large batch sizes.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • momentum (float, optional) – The momentum value.

  • multi_precision (bool, optional) – Flag to control the internal precision of the optimizer. False: results in using the same precision as the weights (default), True: makes internal 32-bit copy of the weights and applies gradients in 32-bit precision even if actual weights used in the model have lower precision. Turning this on can improve convergence and accuracy when training with float16.

  • warmup_strategy (string, optional) – One of 'linear', 'power2', 'sqrt' or 'lars'. Default is 'linear'.

  • warmup_epochs (unsigned, optional) – Number of warmup epochs. Default is 5.

  • batch_scale (unsigned, optional) – Default is 1 (same as batch size * number of workers).

  • updates_per_epoch (unsigned, optional) – Number of updates per epoch, used for warmup. Default is 32, which might not reflect the true number of batches per epoch.

  • begin_epoch (unsigned, optional) – Starting epoch. Default is 0.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

class mxnet.optimizer.NAG(momentum=0.0, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

Nesterov accelerated gradient.

This optimizer updates each weight by:

state = momentum * state + grad + wd * weight
weight = weight - (lr * (grad + momentum * state))
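A scalar sketch of the two-line rule above follows, assuming grad has already been rescaled and clipped. nag_step is a hypothetical name. The weight step looks ahead using momentum * state, which is what distinguishes Nesterov momentum from the plain momentum update.

```python
def nag_step(w, g, state, lr=0.01, momentum=0.0, wd=0.0):
    """Hypothetical scalar NAG step following the rule above; returns (w, state)."""
    state = momentum * state + g + wd * w
    # Look-ahead: combine the current gradient with the updated momentum.
    w = w - lr * (g + momentum * state)
    return w, state
```

With momentum=0.0 the sketch reduces to plain SGD with weight decay folded into state.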

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

create_state_multi_precision(index, weight)

Creates auxiliary state for a given weight, including an FP32 high-precision copy if the original weight is FP16.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

update_multi_precision(index, weight, grad, …)

Updates the given parameter using the corresponding gradient and state.

Parameters
  • momentum (float, optional) – The momentum value.

  • multi_precision (bool, optional) – Flag to control the internal precision of the optimizer. False: results in using the same precision as the weights (default), True: makes internal 32-bit copy of the weights and applies gradients in 32-bit precision even if actual weights used in the model have lower precision. Turning this on can improve convergence and accuracy when training with float16.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

create_state_multi_precision(index, weight)[source]

Creates auxiliary state for a given weight, including an FP32 high-precision copy if the original weight is FP16.

This method is provided to perform automatic mixed precision training for optimizers that do not support it themselves.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

update_multi_precision(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state. Mixed precision version.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

mxnet.optimizer.NDabs(data=None, out=None, name=None, **kwargs)

Returns element-wise absolute value of the input.

Example:

abs([-2, 0, 3]) = [2, 0, 3]

The storage type of abs output depends upon the input storage type:

  • abs(default) = default

  • abs(row_sparse) = row_sparse

  • abs(csr) = csr

Defined in src/operator/tensor/elemwise_unary_op_basic.cc:L721

Parameters
  • data (NDArray) – The input array.

  • out (NDArray, optional) – The output NDArray to hold the result.

Returns

out – The output of this function.

Return type

NDArray or list of NDArrays

class mxnet.optimizer.Nadam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, schedule_decay=0.004, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The Nesterov Adam optimizer.

Much like Adam is essentially RMSProp with momentum, Nadam is Adam with Nesterov momentum, as described in http://cs229.stanford.edu/proj2015/054_report.pdf.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • beta1 (float, optional) – Exponential decay rate for the first moment estimates.

  • beta2 (float, optional) – Exponential decay rate for the second moment estimates.

  • epsilon (float, optional) – Small value to avoid division by 0.

  • schedule_decay (float, optional) – Exponential decay rate for the momentum schedule.
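This page does not spell out the Nadam update. The sketch below follows the widely used Keras-style formulation of the report linked above, in which schedule_decay warms up a momentum schedule. Treat the schedule constants (0.96 and 0.5) as an assumption rather than a description of MXNet's implementation. nadam_step is a hypothetical name, and rescale_grad, clip_gradient and wd are omitted for brevity.

```python
import math


def nadam_step(w, g, m, v, t, m_schedule, learning_rate=0.001, beta1=0.9,
               beta2=0.999, epsilon=1e-8, schedule_decay=0.004):
    """Hypothetical scalar Nadam step at step t >= 1.

    m_schedule is the running product of the momentum schedule; start it
    at 1.0. Returns (w, m, v, m_schedule).
    """
    # Assumed Keras-style momentum schedule driven by schedule_decay.
    momentum_t = beta1 * (1.0 - 0.5 * 0.96 ** (t * schedule_decay))
    momentum_next = beta1 * (1.0 - 0.5 * 0.96 ** ((t + 1) * schedule_decay))
    m_schedule = m_schedule * momentum_t
    m_schedule_next = m_schedule * momentum_next
    # Ordinary Adam moment estimates.
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    # Bias-corrected current gradient, first moment and second moment.
    g_hat = g / (1.0 - m_schedule)
    m_hat = m / (1.0 - m_schedule_next)
    v_hat = v / (1.0 - beta2 ** t)
    # Nesterov step: mix the current gradient with the look-ahead momentum.
    m_bar = (1.0 - momentum_t) * g_hat + momentum_next * m_hat
    w = w - learning_rate * m_bar / (math.sqrt(v_hat) + epsilon)
    return w, m, v, m_schedule
```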

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

class mxnet.optimizer.Optimizer(rescale_grad=1.0, param_idx2name=None, wd=0.0, clip_gradient=None, learning_rate=None, lr_scheduler=None, sym=None, begin_num_update=0, multi_precision=False, param_dict=None)[source]

Bases: object

The base class inherited by all optimizers.

Parameters
  • rescale_grad (float, optional, default 1.0) – Multiply the gradient with rescale_grad before updating. Often chosen to be 1.0/batch_size.

  • param_idx2name (dict from int to string, optional, default None) – A dictionary that maps int index to string name.

  • clip_gradient (float, optional, default None) – Clip the gradient by projecting onto the box [-clip_gradient, clip_gradient].

  • learning_rate (float) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • lr_scheduler (LRScheduler, optional, default None) – The learning rate scheduler.

  • wd (float, optional, default 0.0) – The weight decay (or L2 regularization) coefficient. Modifies objective by adding a penalty for having large weights.

  • sym (Symbol, optional, default None) – The Symbol this optimizer is applying to.

  • begin_num_update (int, optional, default 0) – The initial number of updates.

  • multi_precision (bool, optional, default False) – Flag to control the internal precision of the optimizer. False: results in using the same precision as the weights (default), True: makes internal 32-bit copy of the weights and applies gradients in 32-bit precision even if actual weights used in the model have lower precision. Turning this on can improve convergence and accuracy when training with float16.

  • param_dict (dict of int -> gluon.Parameter, default None) – Dictionary of parameter index to gluon.Parameter, used to lookup parameter attributes such as lr_mult, wd_mult, etc. param_dict shall not be deep copied.

Properties

learning_rate – The current learning rate of the optimizer. Given an Optimizer object optimizer, its learning rate can be accessed as optimizer.learning_rate.
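The precedence rule for learning_rate and lr_scheduler in the parameter list above can be sketched in plain Python. This is a hypothetical helper, not part of mxnet.optimizer: ToyScheduler stands in for an LRScheduler and models only a base_lr attribute, which is an assumption about where the scheduler keeps its starting rate.

```python
class ToyScheduler:
    """Hypothetical stand-in for an LRScheduler; only base_lr is modeled."""

    def __init__(self, base_lr):
        self.base_lr = base_lr


def resolve_learning_rate(learning_rate=None, lr_scheduler=None, default=0.01):
    """Return the initial learning rate following the documented precedence."""
    if lr_scheduler is not None:
        if learning_rate is not None:
            # An explicit learning_rate overwrites the scheduler's rate.
            lr_scheduler.base_lr = learning_rate
        return lr_scheduler.base_lr
    # No scheduler: use the explicit value, else fall back to 0.01.
    return default if learning_rate is None else learning_rate


sched = ToyScheduler(base_lr=0.5)
print(resolve_learning_rate())                            # 0.01
print(resolve_learning_rate(0.1))                         # 0.1
print(resolve_learning_rate(None, sched))                 # 0.5
print(resolve_learning_rate(0.2, sched), sched.base_lr)   # 0.2 0.2
```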

Methods

create_optimizer(name, **kwargs)

Instantiates an optimizer with a given name and kwargs.

create_state(index, weight)

Creates auxiliary state for a given weight.

create_state_multi_precision(index, weight)

Creates auxiliary state for a given weight, including FP32 high precision copy if original weight is FP16.

register(klass)

Registers a new optimizer.

set_learning_rate(lr)

Sets a new learning rate of the optimizer.

set_lr_mult(args_lr_mult)

Sets an individual learning rate multiplier for each parameter.

set_lr_scale(args_lrscale)

[DEPRECATED] Sets lr scale. Use set_lr_mult instead.

set_wd_mult(args_wd_mult)

Sets an individual weight decay multiplier for each parameter.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

update_multi_precision(index, weight, grad, …)

Updates the given parameter using the corresponding gradient and state.

static create_optimizer(name, **kwargs)[source]

Instantiates an optimizer with a given name and kwargs.

Note

We can use the alias create for Optimizer.create_optimizer.

Parameters
  • name (str) – Name of the optimizer. Should be the name of a subclass of Optimizer. Case insensitive.

  • kwargs (dict) – Parameters for the optimizer.

Returns

An instantiated optimizer.

Return type

Optimizer

Examples

>>> sgd = mx.optimizer.Optimizer.create_optimizer('sgd')
>>> type(sgd)
<class 'mxnet.optimizer.SGD'>
>>> adam = mx.optimizer.create('adam', learning_rate=.1)
>>> type(adam)
<class 'mxnet.optimizer.Adam'>

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which is then used in update. This function is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

create_state_multi_precision(index, weight)[source]

Creates auxiliary state for a given weight, including FP32 high precision copy if original weight is FP16.

This method is provided to perform automatic mixed precision training for optimizers that do not support it themselves.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

static register(klass)[source]

Registers a new optimizer.

Once an optimizer is registered, we can create an instance of this optimizer with create_optimizer later.

Examples

>>> @mx.optimizer.Optimizer.register
... class MyOptimizer(mx.optimizer.Optimizer):
...     pass
>>> optim = mx.optimizer.Optimizer.create_optimizer('MyOptimizer')
>>> print(type(optim))
<class '__main__.MyOptimizer'>
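The register/create_optimizer pair behaves like a name-keyed registry with case-insensitive lookup. The sketch below is a self-contained approximation, not MXNet's source: Optimizer, _registry, register and create_optimizer here are hypothetical stand-ins, and keying the registry by the lower-cased class name is an assumption consistent with create_optimizer being documented as case insensitive.

```python
class Optimizer:
    """Hypothetical minimal base class standing in for mx.optimizer.Optimizer."""

    def __init__(self, learning_rate=0.01, **kwargs):
        self.lr = learning_rate


_registry = {}


def register(klass):
    """Record klass under its lower-cased class name and return it unchanged."""
    _registry[klass.__name__.lower()] = klass
    return klass


def create_optimizer(name, **kwargs):
    """Look up a registered class case-insensitively and instantiate it."""
    key = name.lower()
    if key not in _registry:
        raise ValueError("Cannot find optimizer %s" % name)
    return _registry[key](**kwargs)


@register
class MyOptimizer(Optimizer):
    pass


optim = create_optimizer("MYOPTIMIZER", learning_rate=0.5)
print(type(optim).__name__, optim.lr)  # MyOptimizer 0.5
```

Returning klass from register is what lets it be used as a class decorator, as in the doctest above.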
set_learning_rate(lr)[source]

Sets a new learning rate of the optimizer.

Parameters

lr (float) – The new learning rate of the optimizer.

set_lr_mult(args_lr_mult)[source]

Sets an individual learning rate multiplier for each parameter.

If you specify a learning rate multiplier for a parameter, then the learning rate for the parameter will be set as the product of the global learning rate self.lr and its multiplier.

Note

The default learning rate multiplier of a Variable can be set with lr_mult argument in the constructor.

Parameters

args_lr_mult (dict of str/int to float) –

For each of its key-value entries, the learning rate multiplier for the parameter specified in the key will be set as the given value.

You can specify the parameter with either its name or its index. If you use the name, you should pass sym in the constructor, and the name you specified in the key of args_lr_mult should match the name of the parameter in sym. If you use the index, it should correspond to the index of the parameter used in the update method.

Specifying a parameter by its index is only supported for backward compatibility, and we recommend using the name instead.
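A minimal sketch of how a per-parameter learning rate follows from the global rate and the multipliers described above, assuming that an integer-index key is checked before the name obtained through param_idx2name. The helper lr_for is hypothetical and is not an Optimizer method.

```python
def lr_for(index, base_lr, lr_mult, idx2name):
    """Hypothetical helper: effective learning rate for the parameter at index.

    lr_mult may be keyed by parameter name or, for backward compatibility,
    by integer index; idx2name maps indices to names.
    """
    if index in lr_mult:
        # Index keys are checked first (assumed lookup order).
        return base_lr * lr_mult[index]
    name = idx2name.get(index)
    # Parameters without a multiplier use 1.0, i.e. the global rate.
    return base_lr * lr_mult.get(name, 1.0)


idx2name = {0: "fc1_weight", 1: "fc1_bias", 2: "fc2_weight"}
lr_mult = {"fc1_weight": 0.1, 2: 2.0}
print([lr_for(i, 0.5, lr_mult, idx2name) for i in range(3)])  # [0.05, 0.5, 1.0]
```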

set_lr_scale(args_lrscale)[source]

[DEPRECATED] Sets lr scale. Use set_lr_mult instead.

set_wd_mult(args_wd_mult)[source]

Sets an individual weight decay multiplier for each parameter.

By default, if param_idx2name was provided in the constructor, the weight decay multiplier is set to 0 for all parameters whose names do not end with _weight or _gamma.

Note

The default weight decay multiplier for a Variable can be set with its wd_mult argument in the constructor.

Parameters

args_wd_mult (dict of string/int to float) –

For each of its key-value entries, the weight decay multiplier for the parameter specified in the key will be set as the given value.

You can specify the parameter with either its name or its index. If you use the name, you should pass sym in the constructor, and the name you specified in the key of args_wd_mult should match the name of the parameter in sym. If you use the index, it should correspond to the index of the parameter used in the update method.

Specifying a parameter by its index is only supported for backward compatibility, and we recommend using the name instead.
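The default weight decay rule above can be sketched as a small dictionary builder. build_wd_mult is a hypothetical helper, not an Optimizer method; it assumes that explicit entries from args_wd_mult are applied after the defaults and override them.

```python
def build_wd_mult(idx2name, args_wd_mult=None):
    """Hypothetical helper mirroring the documented default rule for wd_mult."""
    wd_mult = {}
    for name in idx2name.values():
        # e.g. biases and BatchNorm betas get no weight decay by default.
        if not (name.endswith("_weight") or name.endswith("_gamma")):
            wd_mult[name] = 0.0
    # Explicit multipliers override the defaults (assumed order).
    wd_mult.update(args_wd_mult or {})
    return wd_mult


idx2name = {0: "fc1_weight", 1: "fc1_bias", 2: "bn_gamma", 3: "bn_beta"}
print(build_wd_mult(idx2name, {"bn_beta": 0.5}))
# {'fc1_bias': 0.0, 'bn_beta': 0.5}
```

Parameters absent from the resulting dictionary keep a multiplier of 1.0, so fc1_weight and bn_gamma receive the full wd.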

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

update_multi_precision(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state. Mixed precision version.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state_multi_precision().
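Why a 32-bit master copy helps can be shown without MXNet. This sketch emulates half- and single-precision rounding with the struct module's 'e' and 'f' formats (Python 3.6+); the learning rate, gradient and step count are arbitrary illustrative values, and the loop is a simplification of what a mixed precision update does, not its implementation.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_fp32(x):
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


lr, grad, steps = 0.1, 1e-3, 10  # each step subtracts lr * grad = 1e-4

# FP16 only: every small step is rounded away, so the weight never moves.
w16 = to_fp16(1.0)
for _ in range(steps):
    w16 = to_fp16(w16 - lr * grad)

# Multi-precision: accumulate into an FP32 master copy, then cast back.
master = to_fp32(1.0)
for _ in range(steps):
    master = to_fp32(master - lr * grad)
w16_mp = to_fp16(master)

print(w16, master, w16_mp)
```

The FP16-only weight stays at 1.0, while the master copy reaches roughly 0.999 and its FP16 cast is 0.9990234375.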

class mxnet.optimizer.RMSProp(learning_rate=0.001, gamma1=0.9, gamma2=0.9, epsilon=1e-08, centered=False, clip_weights=None, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The RMSProp optimizer.

Two versions of RMSProp are implemented:

If centered=False, we follow http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf by Tieleman & Hinton, 2012. For details of the update algorithm see rmsprop_update.

If centered=True, we follow equations (38)-(45) of http://arxiv.org/pdf/1308.0850v5.pdf by Alex Graves, 2013. For details of the update algorithm see rmspropalex_update.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • gamma1 (float, optional) – A decay factor of moving average over past squared gradient.

  • gamma2 (float, optional) – A “momentum” factor. Only used if centered=True.

  • epsilon (float, optional) – Small value to avoid division by 0.

  • centered (bool, optional) – Flag to control which version of RMSProp to use:

    True: use Graves's version of RMSProp.
    False: use Tieleman & Hinton's version of RMSProp.

  • clip_weights (float, optional) – Clips weights into range [-clip_weights, clip_weights].
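The two RMSProp variants can be sketched for a single scalar weight. This is plain Python, not the rmsprop_update or rmspropalex_update operators: placing epsilon inside the square root and the exact centered recurrences are assumptions based on the cited references, and rescale_grad, clipping and weight decay are omitted.

```python
import math


def rmsprop_step(w, g, n, lr=0.001, gamma1=0.9, epsilon=1e-8):
    """Tieleman & Hinton form (centered=False) for one scalar weight."""
    n = gamma1 * n + (1 - gamma1) * g * g  # running mean of squared gradients
    w = w - lr * g / math.sqrt(n + epsilon)
    return w, n


def rmsprop_centered_step(w, g, n, g_avg, delta,
                          lr=0.001, gamma1=0.9, gamma2=0.9, epsilon=1e-8):
    """Graves form (centered=True): also tracks the mean gradient."""
    n = gamma1 * n + (1 - gamma1) * g * g
    g_avg = gamma1 * g_avg + (1 - gamma1) * g
    # n - g_avg^2 estimates the gradient variance; gamma2 acts as momentum.
    delta = gamma2 * delta - lr * g / math.sqrt(n - g_avg * g_avg + epsilon)
    return w + delta, n, g_avg, delta


w_plain, _ = rmsprop_step(1.0, 1.0, 0.0, lr=0.1)
w_centered, _, _, _ = rmsprop_centered_step(1.0, 1.0, 0.0, 0.0, 0.0, lr=0.1)
print(w_plain, w_centered)  # roughly 0.6838 and 0.6667
```

On the first step the centered variant divides by a smaller variance estimate, so it takes a slightly larger step.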

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which is then used in update. This function is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

class mxnet.optimizer.SGD(momentum=0.0, lazy_update=True, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The SGD optimizer with momentum and weight decay.

If the storage type of grad is row_sparse and lazy_update is True, lazy updates are applied by:

for row in grad.indices:
    rescaled_grad[row] = lr * (rescale_grad * clip(grad[row], clip_gradient) + wd * weight[row])
    state[row] = momentum * state[row] + rescaled_grad[row]
    weight[row] = weight[row] - state[row]
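The lazy row-sparse update above can be emulated in plain Python. This is a sketch, not MXNet code: a dict mapping row index to gradient row stands in for a row_sparse NDArray (its keys play the role of grad.indices), weights and states are lists of rows, and clip and lazy_sgd_mom_update are hypothetical helpers.

```python
def clip(x, bound):
    """Clip x into [-bound, bound]; bound=None disables clipping."""
    if bound is None:
        return x
    return max(-bound, min(bound, x))


def lazy_sgd_mom_update(weight, state, grad_rows, lr, momentum=0.0, wd=0.0,
                        rescale_grad=1.0, clip_gradient=None):
    """Update only the rows present in grad_rows (row index -> row values)."""
    for row, g_row in grad_rows.items():
        for j, g in enumerate(g_row):
            rescaled = lr * (rescale_grad * clip(g, clip_gradient)
                             + wd * weight[row][j])
            state[row][j] = momentum * state[row][j] + rescaled
            weight[row][j] -= state[row][j]


weight = [[1.0, 1.0], [2.0, 2.0]]
state = [[0.0, 0.0], [0.5, 0.5]]  # row 1 carries momentum from an earlier step
lazy_sgd_mom_update(weight, state, {0: [0.5, -0.5]}, lr=0.1, momentum=0.9)
print(weight)  # row 1 is absent from the gradient, so it is left untouched
print(state)
```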

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

create_state_multi_precision(index, weight)

Creates auxiliary state for a given weight, including FP32 high precision copy if original weight is FP16.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

update_multi_precision(index, weight, grad, …)

Updates the given parameter using the corresponding gradient and state.

The sparse update only updates the momentum for the weights whose row_sparse gradient indices appear in the current batch, rather than updating it for all indices. Compared with the original update, it can provide large improvements in model training throughput for some applications. However, it provides slightly different semantics than the original update, and may lead to different empirical results.

When update_on_kvstore is set to False (either globally via the MXNET_UPDATE_ON_KVSTORE=0 environment variable or as a parameter of Trainer), the SGD optimizer can perform aggregated updates of parameters, which may improve performance. The aggregation size is controlled by the MXNET_OPTIMIZER_AGGREGATION_SIZE environment variable and defaults to 4.

Otherwise, standard updates are applied by:

rescaled_grad = lr * (rescale_grad * clip(grad, clip_gradient) + wd * weight)
state = momentum * state + rescaled_grad
weight = weight - state

For details of the update algorithm see sgd_update and sgd_mom_update.
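For comparison, here is a dense sketch of the standard update in plain Python rather than the sgd_update and sgd_mom_update operators; clip and sgd_mom_update are hypothetical helpers. With a zero gradient, element 1 still moves because its momentum flows into the weight. This is the semantic difference from the lazy update that the paragraph on sparse updates describes.

```python
def clip(x, bound):
    """Clip x into [-bound, bound]; bound=None disables clipping."""
    if bound is None:
        return x
    return max(-bound, min(bound, x))


def sgd_mom_update(weight, state, grad, lr, momentum=0.0, wd=0.0,
                   rescale_grad=1.0, clip_gradient=None):
    """Dense update applied to every element, following the formula above."""
    for i, g in enumerate(grad):
        rescaled = lr * (rescale_grad * clip(g, clip_gradient) + wd * weight[i])
        state[i] = momentum * state[i] + rescaled
        weight[i] -= state[i]


weight = [1.0, 2.0]
state = [0.0, 0.5]  # element 1 carries momentum from an earlier step
sgd_mom_update(weight, state, [0.5, 0.0], lr=0.1, momentum=0.9)
print(weight, state)  # element 1 moves from 2.0 to about 1.55
```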

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • momentum (float, optional) – The momentum value.

  • lazy_update (bool, optional) – Default is True. If True, lazy updates are applied if the storage types of weight and grad are both row_sparse.

  • multi_precision (bool, optional) – Flag to control the internal precision of the optimizer. False: results in using the same precision as the weights (default), True: makes internal 32-bit copy of the weights and applies gradients in 32-bit precision even if actual weights used in the model have lower precision. Turning this on can improve convergence and accuracy when training with float16.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which is then used in update. This function is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

create_state_multi_precision(index, weight)[source]

Creates auxiliary state for a given weight, including FP32 high precision copy if original weight is FP16.

This method is provided to perform automatic mixed precision training for optimizers that do not support it themselves.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

update_multi_precision(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state. Mixed precision version.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state_multi_precision().

class mxnet.optimizer.SGLD(**kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

Stochastic Gradient Riemannian Langevin Dynamics.

This class implements the optimizer described in the paper Stochastic Gradient Riemannian Langevin Dynamics on the Probability Simplex, available at https://papers.nips.cc/paper/4883-stochastic-gradient-riemannian-langevin-dynamics-on-the-probability-simplex.pdf.

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which is then used in update. This function is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

class mxnet.optimizer.Signum(learning_rate=0.01, momentum=0.9, wd_lh=0.0, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The Signum optimizer that takes the sign of gradient or momentum.

The optimizer updates the weight by:

rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
state = momentum * state + (1-momentum)*rescaled_grad
weight = (1 - lr * wd_lh) * weight - lr * sign(state)

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

References

Jeremy Bernstein, Yu-Xiang Wang, Kamyar Azizzadenesheli & Anima Anandkumar. (2018). signSGD: Compressed Optimisation for Non-Convex Problems. In ICML’18.

See: https://arxiv.org/abs/1802.04434

For details of the update algorithm see signsgd_update and signum_update.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • momentum (float, optional) – The momentum value.

  • wd_lh (float, optional) – The amount of decoupled weight decay regularization; see the original paper for details: https://arxiv.org/abs/1711.05101
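The Signum update formula above can be sketched element-wise in plain Python. This is not the signum_update operator; clip, sign and signum_update are hypothetical helpers written only to follow the three documented lines.

```python
def clip(x, bound):
    """Clip x into [-bound, bound]; bound=None disables clipping."""
    if bound is None:
        return x
    return max(-bound, min(bound, x))


def sign(x):
    """Return 1, -1 or 0 according to the sign of x."""
    return (x > 0) - (x < 0)


def signum_update(weight, state, grad, lr, momentum=0.9, wd=0.0, wd_lh=0.0,
                  rescale_grad=1.0, clip_gradient=None):
    for i, g in enumerate(grad):
        rescaled = rescale_grad * clip(g, clip_gradient) + wd * weight[i]
        state[i] = momentum * state[i] + (1 - momentum) * rescaled
        # Only the sign of the momentum reaches the weight; wd_lh is decoupled decay.
        weight[i] = (1 - lr * wd_lh) * weight[i] - lr * sign(state[i])


weight, state = [1.0, 1.0], [0.0, 0.0]
signum_update(weight, state, [3.0, -0.001], lr=0.01)
print(weight)  # both elements move by exactly lr, whatever the gradient size
```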

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which is then used in update. This function is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

class mxnet.optimizer.LAMB(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-06, lower_bound=None, upper_bound=None, bias_correction=True, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

LAMB Optimizer.

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

update(index, weight, grad, state)

Updates the given parameter using the corresponding gradient and state.

update_multi_precision(index, weight, grad, …)

Updates the given parameter using the corresponding gradient and state.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which is then used in update. This function is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().

update_multi_precision(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state. Mixed precision version.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state_multi_precision().

class mxnet.optimizer.Test(**kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The Test optimizer.

Methods

create_state(index, weight)

Creates a state to duplicate weight.

update(index, weight, grad, state)

Performs w += rescale_grad * grad.

create_state(index, weight)[source]

Creates a state to duplicate weight.

update(index, weight, grad, state)[source]

Performs w += rescale_grad * grad.

class mxnet.optimizer.Updater(optimizer)[source]

Bases: object

Updater for kvstore.

Methods

get_states([dump_optimizer])

Gets updater states.

set_states(states)

Sets updater states.

sync_state_context(state, context)

Syncs the state to the given context.

get_states(dump_optimizer=False)[source]

Gets updater states.

Parameters

dump_optimizer (bool, default False) – Whether to also save the optimizer itself. This would also save optimizer information such as learning rate and weight decay schedules.

set_states(states)[source]

Sets updater states.

sync_state_context(state, context)[source]

Syncs the state to the given context.

class mxnet.optimizer.ccSGD(*args, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.SGD

[DEPRECATED] Same as SGD. Left here for backward compatibility.

mxnet.optimizer.create(name, **kwargs)

Instantiates an optimizer with a given name and kwargs.

Note

We can use the alias create for Optimizer.create_optimizer.

Parameters
  • name (str) – Name of the optimizer. Should be the name of a subclass of Optimizer. Case insensitive.

  • kwargs (dict) – Parameters for the optimizer.

Returns

An instantiated optimizer.

Return type

Optimizer

Examples

>>> sgd = mx.optimizer.Optimizer.create_optimizer('sgd')
>>> type(sgd)
<class 'mxnet.optimizer.SGD'>
>>> adam = mx.optimizer.create('adam', learning_rate=.1)
>>> type(adam)
<class 'mxnet.optimizer.Adam'>

mxnet.optimizer.get_updater(optimizer)[source]

Returns a closure of the updater needed for kvstore.

Parameters

optimizer (Optimizer) – The optimizer.

Returns

updater – The closure of the updater.

Return type

function
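The closure returned by get_updater keeps one state per weight index and creates it lazily, which is why create_state runs only once per weight. The sketch below models that pattern in plain Python. ToyOptimizer and get_updater_sketch are hypothetical names, and the (index, grad, weight) argument order of the closure is an assumption modeled on kvstore updater callbacks.

```python
class ToyOptimizer:
    """Hypothetical stand-in for an Optimizer subclass (plain SGD, no momentum)."""

    def __init__(self, learning_rate=0.1):
        self.lr = learning_rate
        self.num_create_state = 0  # counts create_state calls for illustration

    def create_state(self, index, weight):
        self.num_create_state += 1
        return None  # plain SGD needs no auxiliary state

    def update(self, index, weight, grad, state):
        for i, g in enumerate(grad):
            weight[i] -= self.lr * g


def get_updater_sketch(optimizer):
    """Return a closure that owns per-index states, like an updater for kvstore."""
    states = {}

    def updater(index, grad, weight):
        if index not in states:
            # State is created lazily, once per weight index.
            states[index] = optimizer.create_state(index, weight)
        optimizer.update(index, weight, grad, states[index])

    return updater


opt = ToyOptimizer(learning_rate=0.1)
updater = get_updater_sketch(opt)
w = [1.0, 2.0]
for _ in range(2):
    updater(0, [1.0, -1.0], w)
print(w, opt.num_create_state)  # roughly [0.8, 2.2] and 1
```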

mxnet.optimizer.register(klass)

Registers a new optimizer.

Once an optimizer is registered, we can create an instance of this optimizer with create_optimizer later.

Examples

>>> @mx.optimizer.Optimizer.register
... class MyOptimizer(mx.optimizer.Optimizer):
...     pass
>>> optim = mx.optimizer.Optimizer.create_optimizer('MyOptimizer')
>>> print(type(optim))
<class '__main__.MyOptimizer'>