Source code for mxnet.optimizer.updater
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Updater class."""
from __future__ import absolute_import
import pickle
import numpy
from ..base import py_str
from ..ndarray import NDArray
from ..profiler import scope as profiler_scope
from ..util import is_np_array
from .utils import _as_classic
__all__ = ['Updater', 'get_updater']


class Updater(object):
    """Updater for kvstore."""
    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.states = {}
        self.states_synced = {}
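        # Batch several parameter updates into a single optimizer call when the
        # optimizer supports aggregation (aggregate_num > 1).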
        self.aggregate_updates = optimizer.aggregate_num > 1

    def __call__(self, index, grad, weight):
        """Updates weight given gradient and index."""
        allow_np = self.optimizer.allow_np_array if hasattr(self.optimizer, "allow_np_array") \
            else is_np_array()
        if not isinstance(index, (list, tuple)):
            indices = [index]
            grads = [_as_classic(grad, allow_np)]
            weights = [_as_classic(weight, allow_np)]
        else:
            indices = index
            grads = _as_classic(grad, allow_np)
            weights = _as_classic(weight, allow_np)
        if weights:
            self.optimizer._set_current_device(weights[0].context.device_id)
        for i, idx in enumerate(indices):
            # convert ctypes.char_p.value back to python str if needed
            if isinstance(idx, bytes):
                indices[i] = py_str(idx)
                idx = indices[i]
            if idx not in self.states:
                with profiler_scope("updater:optimizer_state"):
                    self.states[idx] = self.optimizer.create_state_multi_precision(idx, weights[i])
                self.states_synced[idx] = True
            elif not self.states_synced[idx]:
                self.states[idx] = \
                    self.sync_state_context(self.states[idx], weights[i].context)
                self.states_synced[idx] = True
        if self.aggregate_updates:
            # segregate values based on type
            if self.optimizer.aggregate_num != numpy.inf:
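                # Group (index, weight, grad) triples by dtype so that each
                # aggregated update_multi_precision call sees a single dtype,
                # then flush them in chunks of at most ``aggregate_num``.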
                type_map = {}
                for i, w, g in zip(indices, weights, grads):
                    if w.dtype in type_map:
                        type_map[w.dtype].append((i, w, g))
                    else:
                        type_map[w.dtype] = [(i, w, g)]
                for idx in type_map:
                    current_index = 0
                    indices, weights, grads = zip(*type_map[idx])
                    while current_index < len(indices):
                        states = []
                        step = min(self.optimizer.aggregate_num, len(indices) - current_index)
                        for j in range(step):
                            states.append(self.states[indices[current_index + j]])
                        self.optimizer.update_multi_precision(
                            indices[current_index:current_index + self.optimizer.aggregate_num],
                            weights[current_index:current_index + self.optimizer.aggregate_num],
                            grads[current_index:current_index + self.optimizer.aggregate_num],
                            states)
                        current_index += self.optimizer.aggregate_num
            else:
                states = [self.states[i] for i in indices]
                self.optimizer.update_multi_precision(indices, weights, grads, states)
        else:
            for i, w, g in zip(indices, weights, grads):
                self.optimizer.update_multi_precision([i], [w], [g], [self.states[i]])
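
    # A hypothetical invocation sketch (``grad_array`` and ``weight_array`` are
    # illustrative names, not defined in this module); kvstore calls the updater
    # with a key, a gradient, and the weight to update in place:
    #
    #     >>> updater = Updater(mx.optimizer.SGD(learning_rate=0.1))
    #     >>> updater(0, grad_array, weight_array)  # applies one SGD step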

    def sync_state_context(self, state, context):
        """Sync state context."""
        if isinstance(state, NDArray):
            return state.as_in_context(context)
        elif isinstance(state, (tuple, list)):
            synced_state = (self.sync_state_context(i, context) for i in state)
            if isinstance(state, tuple):
                return tuple(synced_state)
            else:
                return list(synced_state)
        else:
            return state
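
    # Illustrative note: a state stored as a tuple of NDArrays (e.g. Adam's
    # (mean, var) pair) is copied element-wise to ``context``, while non-NDArray
    # entries such as step counters are returned unchanged.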

    def set_states(self, states):
        """Sets updater states."""
        states = pickle.loads(states)
        if isinstance(states, tuple) and len(states) == 2:
            self.states, self.optimizer = states
        else:
            self.states = states
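        # Mark all restored states as unsynced so the next update copies each one
        # to the device of its corresponding weight (see __call__).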
        self.states_synced = dict.fromkeys(self.states.keys(), False)

    def get_states(self, dump_optimizer=False):
        """Gets updater states.

        Parameters
        ----------
        dump_optimizer : bool, default False
            Whether to also save the optimizer itself. This would also save optimizer
            information such as learning rate and weight decay schedules.
        """
        return pickle.dumps((self.states, self.optimizer) if dump_optimizer else self.states)
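
# A hypothetical save/restore round trip (``updater`` as in the sketches above):
#
#     >>> blob = updater.get_states(dump_optimizer=True)  # pickled (states, optimizer) tuple
#     >>> updater.set_states(blob)  # restores both; states re-sync to devices lazily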


def get_updater(optimizer):
    """Returns a closure of the updater needed for kvstore.

    Parameters
    ----------
    optimizer : Optimizer
        The optimizer.

    Returns
    -------
    updater : function
        The closure of the updater.
    """
    return Updater(optimizer)
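

# A minimal end-to-end sketch (hypothetical, not part of the original module);
# the SGD settings and shapes below are assumptions. Because this file uses
# relative imports, run it as ``python -m mxnet.optimizer.updater``.
if __name__ == '__main__':
    import mxnet as mx

    updater = get_updater(mx.optimizer.SGD(learning_rate=0.1))
    weight = mx.nd.ones((2, 2))
    grad = mx.nd.ones((2, 2))
    updater(0, grad, weight)   # one in-place SGD step: weight -= 0.1 * grad (wd=0)
    print(weight.asnumpy())    # expect all entries ~0.9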