# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-instance-attributes, too-many-arguments, protected-access
# pylint: disable=too-many-public-methods
"""A `BucketingModule` implement the `BaseModule` API, and allows multiple
symbols to be used depending on the `bucket_key` provided by each different
mini-batch of data.
"""
import logging
import warnings
from .. import context as ctx
from ..initializer import Uniform
from .base_module import BaseModule, _check_input_names
from .module import Module
[docs]class BucketingModule(BaseModule):
"""This module helps to deal efficiently with varying-length inputs.
Parameters
----------
sym_gen : function
A function when called with a bucket key, returns a triple
``(symbol, data_names, label_names)``.
default_bucket_key : str (or any python object)
The key for the default bucket.
logger : Logger
context : Context or list of Context
Defaults to ``mx.cpu()``
work_load_list : list of number
Defaults to ``None``, indicating uniform workload.
fixed_param_names: list of str
Defaults to ``None``, indicating no network parameters are fixed.
state_names : list of str
States are similar to data and label, but not provided by data iterator.
Instead they are initialized to 0 and can be set by set_states()
group2ctxs : dict of str to context or list of context,
or list of dict of str to context
Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
compression_params : dict
Specifies type of gradient compression and additional arguments depending
on the type of compression being used. For example, 2bit compression requires a threshold.
Arguments would then be {'type':'2bit', 'threshold':0.5}
See mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
"""
def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
context=ctx.cpu(), work_load_list=None,
fixed_param_names=None, state_names=None, group2ctxs=None,
compression_params=None):
super(BucketingModule, self).__init__(logger=logger)
assert default_bucket_key is not None
self._default_bucket_key = default_bucket_key
self._sym_gen = sym_gen
symbol, data_names, label_names = sym_gen(default_bucket_key)
data_names = list(data_names) if data_names is not None else []
label_names = list(label_names) if label_names is not None else []
state_names = list(state_names) if state_names is not None else []
fixed_param_names = list(fixed_param_names) if fixed_param_names is not None else []
_check_input_names(symbol, data_names, "data", True)
_check_input_names(symbol, label_names, "label", False)
_check_input_names(symbol, state_names, "state", True)
_check_input_names(symbol, fixed_param_names, "fixed_param", True)
self._compression_params = compression_params
self._fixed_param_names = fixed_param_names
self._state_names = state_names
self._context = context
self._work_load_list = work_load_list
self._group2ctxs = group2ctxs
self._buckets = {}
self._curr_module = None
self._curr_bucket_key = None
self._params_dirty = False
self._monitor = None
def _reset_bind(self):
"""Internal utility function to reset binding."""
self.binded = False
self._buckets = {}
self._curr_module = None
self._curr_bucket_key = None
@property
def data_names(self):
"""A list of names for data required by this module."""
if self.binded:
return self._curr_module.data_names
else:
_, data_names, _ = self._sym_gen(self._default_bucket_key)
return data_names
@property
def output_names(self):
"""A list of names for the outputs of this module."""
if self.binded:
return self._curr_module.output_names
else:
symbol, _, _ = self._sym_gen(self._default_bucket_key)
return symbol.list_outputs()
@property
def data_shapes(self):
"""Get data shapes.
Returns
-------
A list of `(name, shape)` pairs.
"""
assert self.binded
return self._curr_module.data_shapes
@property
def label_shapes(self):
"""Get label shapes.
Returns
-------
A list of `(name, shape)` pairs.
The return value could be ``None`` if the module does not need labels,
or if the module is not bound for training (in this case, label information
is not available).
"""
assert self.binded
return self._curr_module.label_shapes
@property
def output_shapes(self):
"""Gets output shapes.
Returns
-------
A list of `(name, shape)` pairs.
"""
assert self.binded
return self._curr_module.output_shapes
[docs] def get_params(self):
"""Gets current parameters.
Returns
-------
`(arg_params, aux_params)`
A pair of dictionaries each mapping parameter names to NDArray values.
"""
assert self.binded and self.params_initialized
self._curr_module._params_dirty = self._params_dirty
params = self._curr_module.get_params()
self._params_dirty = False
return params
[docs] def set_params(self, arg_params, aux_params, allow_missing=False, force_init=True,
allow_extra=False):
"""Assigns parameters and aux state values.
Parameters
----------
arg_params : dict
Dictionary of name to value (`NDArray`) mapping.
aux_params : dict
Dictionary of name to value (`NDArray`) mapping.
allow_missing : bool
If true, params could contain missing values, and the initializer will be
called to fill those missing params.
force_init : bool
If true, will force re-initialize even if already initialized.
allow_extra : boolean, optional
Whether allow extra parameters that are not needed by symbol.
If this is True, no error will be thrown when arg_params or aux_params
contain extra parameters that is not needed by the executor.
Examples
--------
>>> # An example of setting module parameters.
>>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, n_epoch_load)
>>> mod.set_params(arg_params=arg_params, aux_params=aux_params)
"""
if not allow_missing:
self.init_params(initializer=None, arg_params=arg_params, aux_params=aux_params,
allow_missing=allow_missing, force_init=force_init)
return
if self.params_initialized and not force_init:
warnings.warn("Parameters already initialized and force_init=False. "
"set_params call ignored.", stacklevel=2)
return
self._curr_module.set_params(arg_params, aux_params, allow_missing=allow_missing,
force_init=force_init, allow_extra=allow_extra)
# because we didn't update self._arg_params, they are dirty now.
self._params_dirty = True
self.params_initialized = True
[docs] def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None,
allow_missing=False, force_init=False, allow_extra=False):
"""Initializes parameters.
Parameters
----------
initializer : Initializer
arg_params : dict
Defaults to ``None``. Existing parameters. This has higher priority
than `initializer`.
aux_params : dict
Defaults to ``None``. Existing auxiliary states. This has higher priority
than `initializer`.
allow_missing : bool
Allow missing values in `arg_params` and `aux_params` (if not ``None``).
In this case, missing values will be filled with `initializer`.
force_init : bool
Defaults to ``False``.
allow_extra : boolean, optional
Whether allow extra parameters that are not needed by symbol.
If this is True, no error will be thrown when arg_params or aux_params
contain extra parameters that is not needed by the executor.
"""
if self.params_initialized and not force_init:
return
assert self.binded, 'call bind before initializing the parameters'
self._curr_module.init_params(initializer=initializer, arg_params=arg_params,
aux_params=aux_params, allow_missing=allow_missing,
force_init=force_init, allow_extra=allow_extra)
self._params_dirty = False
self.params_initialized = True
[docs] def get_states(self, merge_multi_context=True):
"""Gets states from all devices.
Parameters
----------
merge_multi_context : bool
Default is `True`. In the case when data-parallelism is used, the states
will be collected from multiple devices. A `True` value indicate that we
should merge the collected results so that they look like from a single
executor.
Returns
-------
list of NDArrays or list of list of NDArrays
If `merge_multi_context` is ``True``, it is like ``[out1, out2]``. Otherwise, it
is like ``[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]``. All the output
elements are `NDArray`.
"""
assert self.binded and self.params_initialized
return self._curr_module.get_states(merge_multi_context=merge_multi_context)
[docs] def set_states(self, states=None, value=None):
"""Sets value for states. Only one of states & values can be specified.
Parameters
----------
states : list of list of NDArrays
Source states arrays formatted like ``[[state1_dev1, state1_dev2],
[state2_dev1, state2_dev2]]``.
value : number
A single scalar value for all state arrays.
"""
assert self.binded and self.params_initialized
self._curr_module.set_states(states, value)
[docs] def bind(self, data_shapes, label_shapes=None, for_training=True,
inputs_need_grad=False, force_rebind=False, shared_module=None,
grad_req='write'):
"""Binding for a `BucketingModule` means setting up the buckets and binding the
executor for the default bucket key. Executors corresponding to other keys are
bound afterwards with `switch_bucket`.
Parameters
----------
data_shapes : list of (str, tuple)
This should correspond to the symbol for the default bucket.
label_shapes : list of (str, tuple)
This should correspond to the symbol for the default bucket.
for_training : bool
Default is ``True``.
inputs_need_grad : bool
Default is ``False``.
force_rebind : bool
Default is ``False``.
shared_module : BucketingModule
Default is ``None``. This value is currently not used.
grad_req : str, list of str, dict of str to str
Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
(default to 'write').
Can be specified globally (str) or for each argument (list, dict).
bucket_key : str (or any python object)
bucket key for binding. by default use the default_bucket_key
"""
# in case we already initialized params, keep it
if self.params_initialized:
arg_params, aux_params = self.get_params()
# force rebinding is typically used when one want to switch from
# training to prediction phase.
if force_rebind:
self._reset_bind()
if self.binded:
self.logger.warning('Already bound, ignoring bind()')
return
assert shared_module is None, 'shared_module for BucketingModule is not supported'
self.for_training = for_training
self.inputs_need_grad = inputs_need_grad
self.binded = True
symbol, data_names, label_names = self._sym_gen(self._default_bucket_key)
module = Module(symbol, data_names, label_names, logger=self.logger,
context=self._context, work_load_list=self._work_load_list,
fixed_param_names=self._fixed_param_names,
state_names=self._state_names,
group2ctxs=self._group2ctxs,
compression_params=self._compression_params)
module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
force_rebind=False, shared_module=None, grad_req=grad_req)
self._curr_module = module
self._curr_bucket_key = self._default_bucket_key
self._buckets[self._default_bucket_key] = module
# copy back saved params, if already initialized
if self.params_initialized:
self.set_params(arg_params, aux_params)
[docs] def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
"""Switches to a different bucket. This will change ``self.curr_module``.
Parameters
----------
bucket_key : str (or any python object)
The key of the target bucket.
data_shapes : list of (str, tuple)
Typically ``data_batch.provide_data``.
label_shapes : list of (str, tuple)
Typically ``data_batch.provide_label``.
"""
assert self.binded, 'call bind before switching bucket'
if not bucket_key in self._buckets:
symbol, data_names, label_names = self._sym_gen(bucket_key)
module = Module(symbol, data_names, label_names,
logger=self.logger, context=self._context,
work_load_list=self._work_load_list,
fixed_param_names=self._fixed_param_names,
state_names=self._state_names,
group2ctxs=self._group2ctxs,
compression_params=self._compression_params)
module.bind(data_shapes, label_shapes, self._curr_module.for_training,
self._curr_module.inputs_need_grad,
force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
if self._monitor is not None:
module.install_monitor(self._monitor)
self._buckets[bucket_key] = module
self._curr_module = self._buckets[bucket_key]
self._curr_bucket_key = bucket_key
[docs] def init_optimizer(self, kvstore='local', optimizer='sgd',
optimizer_params=(('learning_rate', 0.01),),
force_init=False):
"""Installs and initializes optimizers.
Parameters
----------
kvstore : str or KVStore
Defaults to `'local'`.
optimizer : str or Optimizer
Defaults to `'sgd'`
optimizer_params : dict
Defaults to `(('learning_rate', 0.01),)`. The default value is not a dictionary,
just to avoid pylint warning of dangerous default values.
force_init : bool
Defaults to ``False``, indicating whether we should force re-initializing the
optimizer in the case an optimizer is already installed.
"""
assert self.binded and self.params_initialized
if self.optimizer_initialized and not force_init:
self.logger.warning('optimizer already initialized, ignoring.')
return
self._curr_module.init_optimizer(kvstore, optimizer, optimizer_params,
force_init=force_init)
for mod in self._buckets.values():
if mod is not self._curr_module:
mod.borrow_optimizer(self._curr_module)
self.optimizer_initialized = True
[docs] def prepare(self, data_batch):
"""Prepares a data batch for forward.
Parameters
----------
data_batch : DataBatch
"""
# perform bind if haven't done so
assert self.binded and self.params_initialized
bucket_key = data_batch.bucket_key
original_bucket_key = self._curr_bucket_key
data_shapes = data_batch.provide_data
label_shapes = data_batch.provide_label
self.switch_bucket(bucket_key, data_shapes, label_shapes)
# switch back
self.switch_bucket(original_bucket_key, None, None)
[docs] def forward(self, data_batch, is_train=None):
"""Forward computation.
Parameters
----------
data_batch : DataBatch
is_train : bool
Defaults to ``None``, in which case `is_train` is take as ``self.for_training``.
"""
assert self.binded and self.params_initialized
self.switch_bucket(data_batch.bucket_key, data_batch.provide_data,
data_batch.provide_label)
self._curr_module.forward(data_batch, is_train=is_train)
[docs] def backward(self, out_grads=None):
"""Backward computation."""
assert self.binded and self.params_initialized
self._curr_module.backward(out_grads=out_grads)
[docs] def update(self):
"""Updates parameters according to installed optimizer and the gradient computed
in the previous forward-backward cycle.
"""
assert self.binded and self.params_initialized and self.optimizer_initialized
self._params_dirty = True
self._curr_module.update()
[docs] def get_outputs(self, merge_multi_context=True):
"""Gets outputs from a previous forward computation.
Parameters
----------
merge_multi_context : bool
Defaults to ``True``. In the case when data-parallelism is used, the outputs
will be collected from multiple devices. A ``True`` value indicate that we
should merge the collected results so that they look like from a single
executor.
Returns
-------
list of numpy arrays or list of list of numpy arrays
If `merge_multi_context` is ``True``, it is like ``[out1, out2]``. Otherwise, it
is like ``[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]``. All the output
elements are numpy arrays.
"""
assert self.binded and self.params_initialized
return self._curr_module.get_outputs(merge_multi_context=merge_multi_context)
[docs] def update_metric(self, eval_metric, labels):
"""Evaluates and accumulates evaluation metric on outputs of the last forward computation.
Parameters
----------
eval_metric : EvalMetric
labels : list of NDArray
Typically ``data_batch.label``.
"""
assert self.binded and self.params_initialized
self._curr_module.update_metric(eval_metric, labels)
@property
def symbol(self):
"""The symbol of the current bucket being used."""
assert self.binded
return self._curr_module.symbol
[docs] def install_monitor(self, mon):
"""Installs monitor on all executors """
assert self.binded
self._monitor = mon
for mod in self._buckets.values():
mod.install_monitor(mon)