Source code for mxnet.optimizer.contrib
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
"""Contrib optimizers."""
from ..ndarray import (NDArray, clip, contrib, mean, sqrt, square, zeros)
from .optimizer import Optimizer
# convenience wrapper for Optimizer.Register
register = Optimizer.register # pylint: disable=invalid-name
__all__ = ['GroupAdaGrad']


@register
class GroupAdaGrad(Optimizer):
"""Adagrad optimizer with row-wise learning rates.
This class implements the AdaGrad optimizer described in *Adaptive
Subgradient Methods for Online Learning and Stochastic Optimization*, and
available at http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf but
uses only a single learning rate for every row of the parameter array.
This optimizer updates each weight by::
grad = clip(grad * rescale_grad, clip_gradient)
history += mean(square(grad), axis=1, keepdims=True)
div = grad / sqrt(history + float_stable_eps)
weight -= div * lr
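
    As an illustration (arbitrary values, with ``rescale_grad = 1``, no
    gradient clipping, and ``float_stable_eps`` omitted for readability), a
    single step on a 2x2 gradient accumulates one history value per row::

        grad    = [[0.3, 0.1],
                   [0.0, 0.2]]
        history = [[0.05],    # mean([0.3**2, 0.1**2])
                   [0.02]]    # mean([0.0**2, 0.2**2])
        weight -= lr * grad / sqrt(history)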

    Weights are updated lazily if the gradient is sparse.

    For details of the update algorithm see
    :class:`~mxnet.ndarray.contrib.group_adagrad_update`.

    This optimizer accepts the following parameters in addition to those
    accepted by :class:`.Optimizer`. Weight decay is not supported.

    Parameters
    ----------
    eps : float, optional
        Initial value of the history accumulator. Avoids division by 0.
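
    Examples
    --------
    A minimal usage sketch (assuming ``import mxnet as mx`` and the Gluon
    ``Trainer`` API; ``net`` stands in for an initialized network whose
    parameters are all 2-D)::

        optimizer = GroupAdaGrad(learning_rate=0.1, eps=1e-5)
        trainer = mx.gluon.Trainer(net.collect_params(), optimizer)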
"""

    def __init__(self, eps=1e-5, **kwargs):
        super(GroupAdaGrad, self).__init__(**kwargs)
        self.float_stable_eps = eps

    def create_state(self, index, weight):
        assert len(weight.shape) == 2
        # One accumulator entry per row of the weight; storage type matches
        # the weight.
        history = zeros(
            (weight.shape[0], 1), weight.context, stype=weight.stype)
        return history

    def update(self, index, weight, grad, state):
        assert isinstance(weight, NDArray)
        assert isinstance(grad, NDArray)
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)
        assert wd == 0, 'Weight decay is not supported for GroupAdaGrad'

        is_sparse = grad.stype == 'row_sparse'
        if is_sparse:
            # Sparse gradients use the fused kernel, which updates only the
            # rows present in the gradient (lazy update).
            kwargs = {
                'epsilon': self.float_stable_eps,
                'rescale_grad': self.rescale_grad
            }
            if self.clip_gradient:
                kwargs['clip_gradient'] = self.clip_gradient
            contrib.group_adagrad_update(
                weight,
                grad,
                state,
                out=weight,
                lr=lr,
                **kwargs)
        else:
            # Dense gradients apply the update rule from the class docstring
            # explicitly.
            grad = grad * self.rescale_grad
            if self.clip_gradient is not None:
                grad = clip(grad, -self.clip_gradient, self.clip_gradient)
            state[:] += mean(square(grad), axis=1, keepdims=True)
            div = lr * grad / sqrt(state + self.float_stable_eps)
            weight[:] -= div
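

# A minimal, self-contained sketch (illustrative only, not part of the library
# API) showing one dense update step; it assumes a 2-D weight, since
# ``create_state`` requires two-dimensional parameters.
if __name__ == '__main__':
    import mxnet as mx

    weight = mx.nd.ones((4, 8))
    grad = mx.nd.random.uniform(shape=(4, 8))

    opt = GroupAdaGrad(learning_rate=0.1)
    state = opt.create_state(0, weight)  # one history entry per row
    opt.update(0, weight, grad, state)

    print(state.shape)  # (4, 1): a single shared accumulator per row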