Source code for mxnet.executor_manager

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-locals, too-many-arguments, too-many-statements
"""Executor manager."""
from __future__ import absolute_import

import logging
import numpy as np

from .base import mx_real_t
from . import ndarray as nd
from .context import cpu
from .io import DataDesc

def _split_input_slice(batch_size, work_load_list):
    """Get input slice from the input shape.

    Parameters
    ----------
    batch_size : int
        The number of samples in a mini-batch.
    work_load_list : list of float or int, optional
        The list of work load for different devices,
        in the same order as `ctx`.

    Returns
    -------
    slices : list of slice
        The split slices to get a specific slice.

    Raises
    ------
    ValueError
        In case of too many splits, leading to some empty slices.
    """
    total_work_load = sum(work_load_list)
    batch_num_list = [round(work_load * batch_size / total_work_load)
                      for work_load in work_load_list]
    batch_num_sum = sum(batch_num_list)
    if batch_num_sum < batch_size:
        batch_num_list[-1] += batch_size - batch_num_sum
    slices = []
    end = 0
    for batch_num in batch_num_list:
        begin = int(min((end, batch_size)))
        end = int(min((begin + batch_num, batch_size)))
        if begin >= end:
            raise ValueError('Too many slices. Some splits are empty.')
        slices.append(slice(begin, end))
    return slices

def _check_arguments(symbol):
    """Check the argument names of symbol.
    This function checks the duplication of arguments in Symbol.
    The check is done for feedforward net for now.

    Parameters
    ----------
    symbol : Symbol
        The network configuration.
    """
    arg_set = set()
    arg_names = symbol.list_arguments()
    for name in arg_names:
        if name in arg_set:
            raise ValueError(('Find duplicated argument name \"%s\", ' +
                              'please make the weight name non-duplicated(using name arguments), ' +
                              'arguments are %s') % (name, str(arg_names)))
        arg_set.add(name)

    aux_set = set()
    aux_names = symbol.list_auxiliary_states()
    for name in aux_names:
        if name in aux_set:
            raise ValueError(
                ('Find duplicated auxiliary param name \"%s\", ' +
                 'please make the weight name non-duplicated(using name arguments), ' +
                 'arguments are %s, auxiliary params are %s'
                ) % (name, str(arg_names), str(aux_names)))
        aux_set.add(name)

def _load_general(data, targets):
    """Copy each source array into its destination(s).

    ``data`` and ``targets`` are walked in parallel. Each target entry is one
    of two things:

    * an ``nd.NDArray``: the whole source array is copied into it;
    * a sequence of ``(slice, array)`` pairs: each slice of the source along
      its first axis is copied into the paired array. The last slice's
      ``stop`` is asserted to equal the source's first dimension.
    """
    for src, dst in zip(data, targets):
        if not isinstance(dst, nd.NDArray):
            assert dst[-1][0].stop == src.shape[0], \
                "Batch size miss match. Expected %d, got %d" % (dst[-1][0].stop, src.shape[0])
            for piece, target in dst:
                src[piece].copyto(target)
        else:
            src.copyto(dst)

def _load_data(batch, targets):
    """Copy ``batch.data`` into ``targets`` via ``_load_general``.

    ``targets`` uses the layout ``_load_general`` accepts: per data array,
    either a single NDArray or a list of ``(slice, NDArray)`` pairs.
    """
    source = batch.data
    _load_general(source, targets)

def _load_label(batch, targets):
    """Copy ``batch.label`` into ``targets`` via ``_load_general``.

    ``targets`` uses the layout ``_load_general`` accepts: per label array,
    either a single NDArray or a list of ``(slice, NDArray)`` pairs.
    """
    source = batch.label
    _load_general(source, targets)

# pylint: disable=too-many-branches
def _bind_exec(sym, ctx, input_shapes, param_names, need_grad=False,
               base_exec=None, shared_data_arrays=None, input_types=None, logger=logging):
    """Bind ``sym`` on ``ctx``, reusing memory from earlier executors where possible.

    Intended for bucketing:

    * Parameter, gradient and auxiliary arrays are borrowed from ``base_exec``
      when it is given.
    * Data/label arrays are taken from ``shared_data_arrays`` when the pooled
      array has at least as many elements as needed.

    Parameters
    ----------
    sym : Symbol
        Network to bind.
    ctx : Context
        Device to bind on.
    input_shapes : dict of str to tuple
        Shapes of the data/label inputs, forwarded to ``sym.infer_shape``.
    param_names : list of str
        Names of trainable parameters. Every other argument is treated as
        data/label.
    need_grad : bool or set of str
        ``False``: ``grad_req`` is 'null' everywhere and ``args_grad`` is None.
        ``True``: 'write' for every argument not in ``input_shapes``.
        A set: 'write' for exactly those names. Gradient arrays are only
        created or borrowed for selected names that are in ``param_names``.
    base_exec : Executor, optional
        Executor whose parameter, gradient and auxiliary arrays are reused.
        It is also passed to ``sym.bind`` as ``shared_exec``.
    shared_data_arrays : dict of str to NDArray, optional
        Pool of data/label arrays. It is read for reuse and updated in place
        whenever a new data/label array has to be allocated.
    input_types : dict of str to dtype, optional
        Input dtypes for ``sym.infer_type``. Defaults to ``mx_real_t`` for
        every key of ``input_shapes``.
    logger : logging.Logger or the logging module
        Receives a warning when a pooled data array is too small.

    Returns
    -------
    Executor
        The result of ``sym.bind``.

    Raises
    ------
    AssertionError
        If ``need_grad`` is neither a bool nor a set, or if an inferred
        shape/dtype disagrees with an array being reused.
    """
    arg_shapes, _, aux_shapes = sym.infer_shape(**input_shapes)
    assert arg_shapes is not None
    if input_types is None:
        input_types = dict.fromkeys(input_shapes.keys(), mx_real_t)
    arg_dtypes, _, aux_dtypes = sym.infer_type(**input_types)
    assert arg_dtypes is not None

    # ``args_grad`` is None only when gradients are switched off entirely.
    grads = None if need_grad is False else {}

    arg_names = sym.list_arguments()

    if need_grad is True:
        wants_grad = set(arg_names) - set(input_shapes.keys())
    elif need_grad is False:
        wants_grad = set()
    elif isinstance(need_grad, set):
        wants_grad = need_grad
    else:
        raise AssertionError("need_grad must be boolean or set.")
    grad_req = {name: 'write' if name in wants_grad else 'null' for name in arg_names}

    # Create or borrow one array per argument, in argument order.
    args = []
    for pos, name in enumerate(arg_names):
        shape = arg_shapes[pos]
        dtype = arg_dtypes[pos]
        if name in param_names:
            # model parameter
            if base_exec is not None:
                arr = base_exec.arg_dict[name]
                assert arr.shape == shape
                assert arr.dtype == dtype
                if name in wants_grad:
                    grads[name] = base_exec.grad_dict[name]
            else:
                arr = nd.zeros(shape, ctx, dtype=dtype)
                if name in wants_grad:
                    grads[name] = nd.zeros(shape, ctx, dtype=dtype)
        elif shared_data_arrays is not None and name in shared_data_arrays:
            # data or label already present in the shared pool
            arr = shared_data_arrays[name]
            if np.prod(arr.shape) >= np.prod(shape):
                # pooled buffer is large enough: reuse it under the new shape
                assert dtype == arr.dtype
                arr = arr.reshape(shape)
            else:
                logger.warning('bucketing: data "%s" has a shape %s' % (name, shape)
                               + ', which is larger than already allocated '
                               + 'shape %s' % (arr.shape,)
                               + '. Need to re-allocate. Consider putting '
                               + 'default_bucket_key to be the bucket taking the largest '
                               + 'input for better memory sharing.')
                arr = nd.zeros(shape, ctx, dtype=dtype)
                # the new, larger array replaces the pooled one
                shared_data_arrays[name] = arr
        else:
            # data or label seen for the first time
            arr = nd.zeros(shape, ctx, dtype=dtype)
            if shared_data_arrays is not None:
                shared_data_arrays[name] = arr
        args.append(arr)

    # Create or borrow auxiliary states.
    if base_exec is not None:
        for pos, existing in enumerate(base_exec.aux_arrays):
            assert aux_shapes[pos] == existing.shape
            assert aux_dtypes[pos] == existing.dtype
        aux_states = list(base_exec.aux_arrays)
    else:
        aux_states = [nd.zeros(s, ctx, dtype=t) for s, t in zip(aux_shapes, aux_dtypes)]

    return sym.bind(ctx=ctx, args=args, args_grad=grads,
                    aux_states=aux_states,
                    grad_req=grad_req, shared_exec=base_exec)

class DataParallelExecutorGroup(object):
    """A group of executors living on different devices, for data parallelization.

    Parameters
    ----------
    sym: Symbol
        The network configuration.
    arg_names: list of str
        Equals `sym.list_arguments()`.
    param_names: list of str
        List of names of all trainable parameters.
    ctx: list of Context
        List of devices for training (data parallelization).
    slices: list of slice
        Describes how the data parallelization splits data into different devices.
        ``slices[i]`` is the part of every batch handled by ``ctx[i]``.
    train_data: DataIter (or DataBatch)
        The dataset for training. It could be any object with `provide_data` and
        `provide_label` properties. Loading of actual data is not necessarily needed
        at this stage.
    shared_group: DataParallelExecutorGroup
        An existing executor group, if to share parameters with it.
    """
    def __init__(self, sym, arg_names, param_names, ctx, slices, train_data, shared_group=None):
        # make sure the architecture is valid
        _check_arguments(sym)

        # per-device pools of data/label arrays, shared between bucket groups
        if shared_group is None:
            self.shared_data_arrays = [{} for _ in ctx]
        else:
            self.shared_data_arrays = shared_group.shared_data_arrays

        self.data_names = [x[0] for x in train_data.provide_data]
        self.label_names = [x[0] for x in train_data.provide_label]
        self.aux_names = sym.list_auxiliary_states()
        param_set = set(param_names)
        # positions (within arg_names) of the trainable parameters
        self.param_idx = [i for i, name in enumerate(arg_names) if name in param_set]
        self.param_names = [arg_names[i] for i in self.param_idx]

        self.train_execs = []
        for i, ctxi in enumerate(ctx):
            num_samples = slices[i].stop - slices[i].start
            data_shapes = {}
            data_types = {}
            for x in train_data.provide_data + train_data.provide_label:
                # the leading (batch) dimension becomes this device's share
                data_shapes[x[0]] = tuple([num_samples] + list(x[1][1:]))
                if isinstance(x, DataDesc):
                    data_types[x.name] = x.dtype
                else:
                    data_types[x[0]] = mx_real_t
            shared_exec = None if shared_group is None else shared_group.train_execs[i]
            train_exec = _bind_exec(sym, ctxi, data_shapes, self.param_names,
                                    need_grad=True, base_exec=shared_exec,
                                    shared_data_arrays=self.shared_data_arrays[i],
                                    input_types=data_types)
            self.train_execs.append(train_exec)

        # data structure: per input name, one (slice, array) pair per device
        self.data_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)]
                            for name in self.data_names]
        self.label_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)]
                             for name in self.label_names]

        # per parameter / aux state, one array per device
        self.param_arrays = [[e.arg_arrays[i] for e in self.train_execs]
                             for i in self.param_idx]
        self.grad_arrays = [[e.grad_arrays[i] for e in self.train_execs]
                            for i in self.param_idx]
        self.aux_arrays = [[e.aux_arrays[i] for e in self.train_execs]
                           for i in range(len(self.aux_names))]

        self.slices = slices

    def load_data_batch(self, data_batch):
        """Load data and labels into arrays."""
        _load_data(data_batch, self.data_arrays)
        _load_label(data_batch, self.label_arrays)

    def forward(self, is_train=False):
        """Perform a forward pass on each executor."""
        for texec in self.train_execs:
            texec.forward(is_train=is_train)

    def backward(self):
        """Perform a backward pass on each executor."""
        for texec in self.train_execs:
            texec.backward()

    def update_metric(self, metric, labels, pre_sliced=False):
        """Update evaluation metric with label and current outputs.

        Parameters
        ----------
        metric : EvalMetric
            Receives ``update(labels, outputs)`` once per executor.
        labels : list
            If ``pre_sliced`` is False, the full-batch label arrays; each one is
            sliced with this group's ``slices``. Otherwise ``labels[i]`` is
            already the label list for executor ``i``.
        pre_sliced : bool
            Whether ``labels`` is already split per executor.
        """
        for current_exec, (texec, islice) in enumerate(zip(self.train_execs, self.slices)):
            if not pre_sliced:
                labels_slice = [label[islice] for label in labels]
            else:
                labels_slice = labels[current_exec]
            metric.update(labels_slice, texec.outputs)
class DataParallelExecutorManager(object):
    """Helper class to manage multiple executors for data parallelism.

    Parameters
    ----------
    symbol : Symbol
        Output symbol.
    ctx : list of Context
        Devices to run on.
    train_data : DataIter
        Training data iterator.
    arg_names : list of str
        Names of all arguments of the network.
    param_names : list of str
        Names of all trainable parameters of the network.
    aux_names : list of str
        Names of all auxiliary states of the network.
    work_load_list : list of float or int, optional
        The list of work load for different devices, in the same order as `ctx`.
        Defaults to an equal share for every device.
    logger : logging logger
        When not specified, the `logging` module is used.
    sym_gen : callable, optional
        A function that generates a new Symbol for a given bucket key.
        Used only for bucketing.
    """
    def __init__(self, symbol, ctx, train_data,
                 arg_names, param_names, aux_names,
                 work_load_list=None, logger=None, sym_gen=None):
        if logger is None:
            logger = logging
        # preparation
        num_device = len(ctx)
        logger.info('Start training with %s', str(ctx))

        if work_load_list is None:
            work_load_list = [1] * num_device
        assert isinstance(work_load_list, list) and len(work_load_list) == num_device, \
            "Invalid settings for work load. "

        slices = _split_input_slice(train_data.batch_size, work_load_list)
        self.slices = slices

        self.arg_names = arg_names
        self.param_names = param_names
        self.aux_names = aux_names
        self.ctx = ctx

        self.execgrp = DataParallelExecutorGroup(symbol, self.arg_names, self.param_names, self.ctx,
                                                 self.slices, train_data)
        self.symbol = symbol

        self.sym_gen = sym_gen
        # set by load_data_batch; forward/backward/update_metric need it
        self.curr_execgrp = None
        if self.sym_gen is not None:
            self.execgrp_bucket = {train_data.default_bucket_key: self.execgrp}

    def install_monitor(self, monitor):
        """Install monitor on all executors.

        Raises
        ------
        NotImplementedError
            When bucketing (``sym_gen``) is in use.
        """
        if self.sym_gen is not None:
            raise NotImplementedError("Monitoring is not implemented for bucketing")

        for train_exec in self.execgrp.train_execs:
            monitor.install(train_exec)

    def set_params(self, arg_params, aux_params):
        """Set parameter and aux values on every executor of the default group.

        Parameters
        ----------
        arg_params : dict of str to NDArray
            Source parameter arrays.
        aux_params : dict of str to NDArray
            Source aux arrays.
        """
        for texec in self.execgrp.train_execs:
            texec.copy_params_from(arg_params, aux_params)

    def copy_to(self, arg_params, aux_params):
        """Copy the device-averaged parameters into ``arg_params`` and ``aux_params``.

        Each parameter and aux state is copied to CPU from every device and
        averaged. The result is cast to the target's dtype and written into the
        target array.

        Parameters
        ----------
        arg_params : dict of str to NDArray
            Target parameter arrays, keyed by parameter name.
        aux_params : dict of str to NDArray
            Target aux arrays, keyed by aux state name.

        Notes
        -----
        - This function will inplace update the NDArrays in arg_params and aux_params.
        """
        for name, block in zip(self.param_names, self.param_arrays):
            weight = sum(w.copyto(cpu()) for w in block) / len(block)
            weight.astype(arg_params[name].dtype).copyto(arg_params[name])
        for name, block in zip(self.aux_names, self.aux_arrays):
            weight = sum(w.copyto(cpu()) for w in block) / len(block)
            weight.astype(aux_params[name].dtype).copyto(aux_params[name])

    @property
    def param_arrays(self):
        """Shared parameter arrays."""
        # param arrays should be shared by all executor groups
        return self.execgrp.param_arrays

    @property
    def grad_arrays(self):
        """Shared gradient arrays."""
        # grad arrays should be shared by all executor groups
        return self.execgrp.grad_arrays

    @property
    def aux_arrays(self):
        """Shared aux states."""
        # aux arrays are also shared by all executor groups
        return self.execgrp.aux_arrays

    def load_data_batch(self, data_batch):
        """Load data and labels into arrays, selecting the executor group first.

        With bucketing, the group for ``data_batch.bucket_key`` is used. It is
        created on first use and shares memory with the default group.
        Without bucketing, the default group is used.
        """
        if self.sym_gen is not None:
            key = data_batch.bucket_key
            if key not in self.execgrp_bucket:
                # create new bucket entry
                symbol = self.sym_gen(key)
                execgrp = DataParallelExecutorGroup(symbol, self.arg_names,
                                                    self.param_names, self.ctx,
                                                    self.slices, data_batch,
                                                    shared_group=self.execgrp)
                self.execgrp_bucket[key] = execgrp

            self.curr_execgrp = self.execgrp_bucket[key]
        else:
            self.curr_execgrp = self.execgrp

        self.curr_execgrp.load_data_batch(data_batch)

    def forward(self, is_train=False):
        """Run forward on the current executor group (set by `load_data_batch`)."""
        self.curr_execgrp.forward(is_train=is_train)

    def backward(self):
        """Run backward on the current executor group (set by `load_data_batch`)."""
        self.curr_execgrp.backward()

    def update_metric(self, metric, labels, pre_sliced=False):
        """Update metric with the current executor group (set by `load_data_batch`)."""
        self.curr_execgrp.update_metric(metric, labels, pre_sliced)