mx.symbol.lamb_update_phase1

Description

Phase I of the LAMB update. It performs the following operations and returns g:

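\begin{gather*}
\mathrm{grad} = \mathrm{grad} \cdot \mathrm{rescale\_grad} \\
\text{if } \mathrm{grad} < -\mathrm{clip\_gradient} \text{ then } \mathrm{grad} = -\mathrm{clip\_gradient} \\
\text{if } \mathrm{grad} > \mathrm{clip\_gradient} \text{ then } \mathrm{grad} = \mathrm{clip\_gradient} \\
\mathrm{mean} = \mathrm{beta1} \cdot \mathrm{mean} + (1 - \mathrm{beta1}) \cdot \mathrm{grad} \\
\mathrm{var} = \mathrm{beta2} \cdot \mathrm{var} + (1 - \mathrm{beta2}) \cdot \mathrm{grad}^2 \\
\text{if bias\_correction:} \\
\mathrm{mean\_hat} = \mathrm{mean} / (1 - \mathrm{beta1}^t) \\
\mathrm{var\_hat} = \mathrm{var} / (1 - \mathrm{beta2}^t) \\
g = \mathrm{mean\_hat} / (\sqrt{\mathrm{var\_hat}} + \mathrm{epsilon}) + \mathrm{wd} \cdot \mathrm{weight} \\
\text{else:} \\
g = \mathrm{mean} / (\sqrt{\mathrm{var}} + \mathrm{epsilon}) + \mathrm{wd} \cdot \mathrm{weight}
\end{gather*}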
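The operations in the description can be sketched in plain Python to make the order of the steps concrete. This is an illustrative reference only, not the MXNet implementation: it works element by element on Python lists, the function name simply mirrors the operator, and returning the updated moments alongside g is an assumption made for readability. It also assumes that a negative clip_gradient (such as the default of -1) disables clipping.

```python
import math


def lamb_update_phase1(weight, grad, mean, var, t, wd,
                       beta1=0.9, beta2=0.999, epsilon=1e-6,
                       bias_correction=True,
                       rescale_grad=1.0, clip_gradient=-1.0):
    """Illustrative sketch of LAMB phase 1 on plain Python lists.

    Returns (g, new_mean, new_var). The input lists are not modified;
    returning the moments is an assumption of this sketch.
    """
    g_out, new_mean, new_var = [], [], []
    for w, gr, m, v in zip(weight, grad, mean, var):
        # Rescale the raw gradient.
        gr = gr * rescale_grad
        # Assumption: a negative clip_gradient turns clipping off.
        if clip_gradient >= 0:
            gr = max(-clip_gradient, min(gr, clip_gradient))
        # Exponential moving averages of the gradient and its square.
        m = beta1 * m + (1.0 - beta1) * gr
        v = beta2 * v + (1.0 - beta2) * gr * gr
        if bias_correction:
            m_hat = m / (1.0 - beta1 ** t)
            v_hat = v / (1.0 - beta2 ** t)
            g = m_hat / (math.sqrt(v_hat) + epsilon) + wd * w
        else:
            g = m / (math.sqrt(v) + epsilon) + wd * w
        g_out.append(g)
        new_mean.append(m)
        new_var.append(v)
    return g_out, new_mean, new_var


# Example call: first update step (t = 1) with zero-initialised moments.
g, mean, var = lamb_update_phase1(
    weight=[2.0], grad=[0.5], mean=[0.0], var=[0.0], t=1, wd=0.01)
print(g, mean, var)
```

Note that t must be at least 1 when bias correction is enabled; with t = 0 the denominators 1 - beta1^t and 1 - beta2^t would be zero.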

Usage

mx.symbol.lamb_update_phase1(...)


Arguments

Argument

Description

weight

NDArray-or-Symbol.

Weight

grad

NDArray-or-Symbol.

Gradient

mean

NDArray-or-Symbol.

Moving mean

var

NDArray-or-Symbol.

Moving variance

beta1

float, optional, default=0.899999976.

The decay rate for the 1st moment estimates.

beta2

float, optional, default=0.999000013.

The decay rate for the 2nd moment estimates.

epsilon

float, optional, default=9.99999997e-07.

A small constant for numerical stability.

t

int, required.

Index update count.

bias.correction

boolean, optional, default=1.

Whether to use bias correction.

wd

float, required.

Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.

rescale.grad

float, optional, default=1.

Rescale gradient to grad = rescale_grad * grad.

clip.gradient

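float, optional, default=-1.

Clip gradient to the range [-clip_gradient, clip_gradient]. With the default of -1, clipping is turned off.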

name

string, optional.

Name of the resulting symbol.

Value

out: The result mx.symbol.
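To illustrate what the bias.correction argument does, the following minimal Python sketch tracks the moving mean over a few steps with a constant gradient. The helper name bias_corrected and the chosen values are hypothetical and exist only for this illustration. Because the mean starts at zero, the raw moving average underestimates the gradient in early steps; dividing by 1 - beta1^t compensates for that.

```python
def bias_corrected(moment, beta, t):
    """Divide a moving average by (1 - beta^t) to undo zero initialisation."""
    return moment / (1.0 - beta ** t)


beta1 = 0.9   # decay rate for the first moment
grad = 2.0    # constant gradient, chosen only for illustration
mean = 0.0    # moving mean starts at zero

history = []
for t in range(1, 4):
    mean = beta1 * mean + (1.0 - beta1) * grad
    corrected = bias_corrected(mean, beta1, t)
    history.append((t, mean, corrected))
    print(t, mean, corrected)
```

With a constant gradient, the raw mean after t steps is grad * (1 - beta1^t), so the corrected value stays at the gradient itself (up to floating-point rounding) while the raw mean approaches it only gradually.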