Source code for mxnet.optimizer.lamb
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Lamb optimizer."""
from __future__ import absolute_import
import numpy
from ..ndarray import (zeros, clip, sqrt, where, square, ones_like,
maximum, minimum)
from ..ndarray import (lamb_update_phase1, lamb_update_phase2,
mp_lamb_update_phase1, mp_lamb_update_phase2)
from ..ndarray.contrib import (multi_lamb_update, multi_mp_lamb_update)
from .optimizer import Optimizer, register
__all__ = ['LAMB']
@register
class LAMB(Optimizer):
    """LAMB Optimizer.

    Referenced from 'Large Batch Optimization for Deep Learning: Training BERT in 76 minutes'
    (https://arxiv.org/pdf/1904.00962.pdf)

    Parameters
    ----------
    learning_rate : float, default 0.001
        The initial learning rate. If None, the optimization will use the
        learning rate from ``lr_scheduler``. If not None, it will overwrite
        the learning rate in ``lr_scheduler``. If None and ``lr_scheduler``
        is also None, then it will be set to 0.01 by default.
    beta1 : float, default 0.9
        Exponential decay rate for the first moment estimates.
    beta2 : float, default 0.999
        Exponential decay rate for the second moment estimates.
    epsilon : float, default 1e-6
        Small value added to the denominator to avoid division by 0.
    lower_bound : float, default None
        Lower limit applied to the weight norm used in the trust ratio.
        Ignored when None.
    upper_bound : float, default None
        Upper limit applied to the weight norm used in the trust ratio.
        Ignored when None.
    bias_correction : bool, default True
        Whether or not to apply bias correction to the moment estimates.
    aggregate_num : int, default 4
        Number of weights to be aggregated in a list.
        They are passed to the optimizer for a single optimization step.
        When use_fused_step is True, at most 45 weights may be aggregated.
    use_fused_step : bool, default True
        Whether or not to use fused kernels for optimizer.
        When use_fused_step=False, step is called,
        otherwise, fused_step is called.
    """
    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
                 lower_bound=None, upper_bound=None, bias_correction=True,
                 aggregate_num=4, use_fused_step=True, **kwargs):
        # The 45-weight limit only concerns the fused path (see the message);
        # the unfused ``step`` updates parameters one at a time.
        if use_fused_step:
            assert aggregate_num <= 45,\
                'When use_fused_step is True, LAMB only supports aggregate_num <= 45,' \
                ' and receives {}'.format(aggregate_num)
        super(LAMB, self).__init__(learning_rate=learning_rate,
                                   aggregate_num=aggregate_num,
                                   use_fused_step=use_fused_step,
                                   **kwargs)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.bias_correction = bias_correction

    def _bound_kwargs(self):
        """Return the weight-norm bounds to forward to the LAMB kernels.

        Only bounds that are not None are included. This mirrors how ``step``
        applies them, so both code paths treat a bound of 0.0 the same way.
        """
        kwargs = {}
        if self.lower_bound is not None:
            kwargs['lower_bound'] = self.lower_bound
        if self.upper_bound is not None:
            kwargs['upper_bound'] = self.upper_bound
        return kwargs

    def create_state(self, index, weight):
        """Create the optimizer state for one weight.

        Returns a ``(mean, var)`` tuple of zero-initialized float32 arrays.
        Both have the weight's shape, context and storage type.
        """
        stype = weight.stype
        return (zeros(weight.shape, weight.context, dtype=numpy.float32, stype=stype),  # mean
                zeros(weight.shape, weight.context, dtype=numpy.float32, stype=stype))  # var

    def step(self, indices, weights, grads, states):
        """Perform an optimization step using gradients and states.

        This is the unfused implementation: each parameter is updated with
        separate NDArray operations rather than a fused kernel. Weights and
        their (mean, var) states are updated in place. The caller's gradient
        arrays are not modified.

        Parameters
        ----------
        indices : list of int
            List of unique indices of the parameters into the individual learning rates
            and weight decays. Learning rates and weight decay may be set via `set_lr_mult()`
            and `set_wd_mult()`, respectively.
        weights : list of NDArray
            List of parameters to be updated.
        grads : list of NDArray
            List of gradients of the objective with respect to this parameter.
        states : List of any obj
            List of state returned by `create_state()`.
        """
        for index, weight, grad, state in zip(indices, weights, grads, states):
            self._update_count(index)
            lr = self._get_lr(index)
            wd = self._get_wd(index)
            t = self._index_update_count[index]

            # preprocess grad: build a new array instead of scaling in place, so the
            # caller's gradient buffer is left untouched (as in the fused path)
            grad = grad * self.rescale_grad
            if self.clip_gradient is not None:
                grad = clip(grad, -self.clip_gradient, self.clip_gradient)

            # update mean, var (exponential moving averages of grad and grad^2)
            mean, var = state
            mean[:] *= self.beta1
            mean[:] += (1. - self.beta1) * grad
            var[:] *= self.beta2
            var[:] += (1. - self.beta2) * square(grad)

            # r1: (optionally clamped) weight norm used in the trust ratio
            r1 = weight.norm()
            if self.lower_bound is not None:
                r1 = maximum(r1, self.lower_bound)
            if self.upper_bound is not None:
                r1 = minimum(r1, self.upper_bound)

            # g = mean_hat / (sqrt(var_hat) + epsilon) + wd * weight
            if self.bias_correction:
                # apply bias correction
                coef1 = 1. - self.beta1**t
                coef2 = 1. - self.beta2**t
                mean_hat = mean / coef1
                var_hat = var / coef2
                sqrt(var_hat, out=var_hat)
                var_hat += self.epsilon
                mean_hat /= var_hat
                mean_hat += wd * weight
            else:
                mean_hat = sqrt(var)
                mean_hat += self.epsilon
                mean_hat[:] = mean / mean_hat
                mean_hat += wd * weight
            g = mean_hat
            r2 = g.norm()

            # calculate lamb_trust_ratio
            ratio = r1 / r2
            # 1 - ratio / ratio is NaN when ratio is NaN, 0 or inf, and 0 otherwise;
            # a non-zero (NaN) condition selects a trust ratio of 1
            nan_or_zero = 1 - ratio / ratio
            r = where(nan_or_zero, ones_like(ratio), ratio)
            lr *= r

            # update weight
            g *= lr
            weight[:] -= g

    def fused_step(self, indices, weights, grads, states):
        """Perform a fused optimization step using gradients and states.

        Fused kernels are used for the update. When ``aggregate_num > 1`` and
        every weight and gradient has 'default' storage, all parameters are
        updated by a single multi-tensor kernel call. Otherwise each parameter
        is updated with the two-phase LAMB kernels.

        Parameters
        ----------
        indices : list of int
            List of unique indices of the parameters into the individual learning rates
            and weight decays. Learning rates and weight decay may be set via `set_lr_mult()`
            and `set_wd_mult()`, respectively.
        weights : list of NDArray
            List of parameters to be updated.
        grads : list of NDArray
            List of gradients of the objective with respect to this parameter.
        states : List of any obj
            List of state returned by `create_state()`.
        """
        aggregate = self.aggregate_num > 1
        for weight, grad in zip(weights, grads):
            aggregate = (aggregate and
                         weight.stype == 'default' and
                         grad.stype == 'default')
        # Advance each parameter's step count exactly once for this call; both
        # branches below read the counts and lrs/wds computed here.
        self._update_count(indices)
        lrs = self._get_lrs(indices)
        wds = self._get_wds(indices)

        if aggregate:
            kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
                      'bias_correction': self.bias_correction,
                      'rescale_grad': self.rescale_grad}
            if self.clip_gradient:
                kwargs['clip_gradient'] = self.clip_gradient
            kwargs.update(self._bound_kwargs())

            step_counts = [self._index_update_count[index] for index in indices]

            multi_precision = self.multi_precision and weights[0].dtype == numpy.float16
            if not multi_precision:
                mean, var = list(zip(*states))
                multi_lamb_update(weights, grads, mean, var,
                                  out=weights, step_count=step_counts,
                                  lrs=lrs, wds=wds, **kwargs)
            else:
                # multi-precision state layout: (weight32, (mean, var))
                weights32, mean_var = list(zip(*states))
                mean, var = list(zip(*mean_var))
                multi_mp_lamb_update(weights, grads,
                                     mean, var, weights32,
                                     out=weights, step_count=step_counts,
                                     lrs=lrs, wds=wds, **kwargs)
        else:
            for index, weight, grad, state, lr, wd in zip(indices, weights, grads,
                                                          states, lrs, wds):
                # no second _update_count here: the count was already advanced above
                t = self._index_update_count[index]

                kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
                          'bias_correction': self.bias_correction,
                          'rescale_grad': self.rescale_grad, 't': t}
                if self.clip_gradient:
                    kwargs['clip_gradient'] = self.clip_gradient

                multi_precision = self.multi_precision and weight.dtype == numpy.float16
                if multi_precision:
                    weight32 = state[0]
                    mean, var = state[1]
                    g = mp_lamb_update_phase1(weight, grad, mean, var, weight32, wd=wd, **kwargs)
                    r_1 = weight32.norm()
                    r_2 = g.norm()
                    mp_lamb_update_phase2(weight, g, r_1, r_2, weight32, lr=lr,
                                          out=weight, **self._bound_kwargs())
                else:
                    mean, var = state
                    g = lamb_update_phase1(weight, grad, mean, var, wd=wd, **kwargs)
                    r_1 = weight.norm()
                    r_2 = g.norm()
                    lamb_update_phase2(weight, g, r_1, r_2, lr=lr, out=weight,
                                       **self._bound_kwargs())

    def update_multi_precision(self, indices, weights, grads, states):
        """Override update_multi_precision.

        With fused steps, multi-precision handling happens inside ``fused_step``,
        so the plain ``update`` is used. Otherwise the base-class implementation
        is used.
        """
        if self.use_fused_step:
            self.update(indices, weights, grads, states)
        else:
            super(LAMB, self).update_multi_precision(indices, weights, grads, states)
Did this page help you?
Yes
No
Thanks for your feedback!