mxnet.executor_manager

Executor manager.

Classes

DataParallelExecutorGroup(sym, arg_names, …)

A group of executors living on different devices, for data parallelization.

DataParallelExecutorManager(symbol, ctx, …)

Helper class to manage multiple executors for data parallelism.

class mxnet.executor_manager.DataParallelExecutorGroup(sym, arg_names, param_names, ctx, slices, train_data, shared_group=None)

Bases: object

A group of executors living on different devices, for data parallelization.

Parameters
  • sym (Symbol) – The network configuration.

  • arg_names (list of str) – Equal to sym.list_arguments().

  • param_names (list of str) – List of names of all trainable parameters.

  • ctx (list of Context) – List of devices for training (data parallelization).

  • slices (list of slice) – Describes how data parallelization splits the data across devices: one range of the batch per device.

  • train_data (DataIter (or DataBatch)) – The dataset for training. It can be any object with provide_data and provide_label properties. The actual data does not need to be loaded at this stage.

  • shared_group (DataParallelExecutorGroup, optional) – An existing executor group with which to share parameters.
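To make the slices parameter concrete, here is a minimal pure-Python sketch of how one batch could be partitioned across devices. It assumes each entry of slices behaves like a Python slice(begin, end) over the batch axis, with one entry per device in ctx order. split_batch is a hypothetical helper, not part of mxnet, and a plain list stands in for an NDArray.

```python
# Hypothetical sketch: how per-device slices partition one batch.
# A plain list stands in for an NDArray; slice(begin, end) entries are assumed.

def split_batch(batch, slices):
    """Return the part of `batch` each device receives, in device order."""
    parts = [batch[s] for s in slices]
    # The slices should cover every sample exactly once, in order.
    covered = [x for part in parts for x in part]
    if covered != list(batch):
        raise ValueError("slices must partition the batch contiguously")
    return parts


batch = list(range(6))                  # six samples
slices = [slice(0, 3), slice(3, 6)]     # two devices, even split
print(split_batch(batch, slices))       # [[0, 1, 2], [3, 4, 5]]
```

Each device then runs its executor on its own part only, which is what makes the group data-parallel.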

Methods

backward()

Perform a backward pass on each executor.

forward([is_train])

Perform a forward pass on each executor.

load_data_batch(data_batch)

Load data and labels into arrays.

update_metric(metric, labels[, pre_sliced])

Update the evaluation metric with the labels and current outputs.

backward()

Perform a backward pass on each executor.

forward(is_train=False)

Perform a forward pass on each executor.

load_data_batch(data_batch)

Load data and labels into arrays.

update_metric(metric, labels, pre_sliced=False)

Update the evaluation metric with the labels and current outputs.
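The sketch below illustrates what updating a metric with the labels and current outputs could look like when the outputs come from several devices. Accuracy and update_metric here are hypothetical; labels are simplified to one flat list rather than a list of label arrays; and the pre_sliced behaviour shown (labels already split per device) is an assumption, not something documented above.

```python
# Hypothetical sketch: accumulate one metric over per-device outputs.
# Lists of ints stand in for label and prediction NDArrays.

class Accuracy:
    """Running accuracy, updated once per device."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, labels, preds):
        self.correct += sum(int(l == p) for l, p in zip(labels, preds))
        self.total += len(labels)

    def get(self):
        return self.correct / self.total if self.total else float("nan")


def update_metric(metric, labels, device_outputs, slices, pre_sliced=False):
    """Feed each device's outputs to the metric with the matching labels.

    Assumption: with pre_sliced=False, labels cover the whole batch and are cut
    with each device's slice; with pre_sliced=True, labels[i] already belongs
    to device i.
    """
    for i, (outputs, s) in enumerate(zip(device_outputs, slices)):
        device_labels = labels[i] if pre_sliced else labels[s]
        metric.update(device_labels, outputs)


acc = Accuracy()
labels = [1, 0, 1, 1]
outputs = [[1, 0], [0, 1]]              # predictions from device 0 and device 1
update_metric(acc, labels, outputs, [slice(0, 2), slice(2, 4)])
print(acc.get())                        # 0.75
```

Accumulating counts across devices, rather than averaging per-device scores, keeps the result correct when devices receive unequal slices.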

class mxnet.executor_manager.DataParallelExecutorManager(symbol, ctx, train_data, arg_names, param_names, aux_names, work_load_list=None, logger=None, sym_gen=None)

Bases: object

Helper class to manage multiple executors for data parallelism.

Parameters
  • symbol (Symbol) – Output symbol.

  • ctx (list of Context) – Devices to run on.

  • param_names (list of str) – Name of all trainable parameters of the network.

  • arg_names (list of str) – Name of all arguments of the network.

  • aux_names (list of str) – Name of all auxiliary states of the network.

  • train_data (DataIter) – Training data iterator.

  • work_load_list (list of float or int, optional) – The relative workload of each device, in the same order as ctx.

  • logger (logging.Logger, optional) – The logger to use. If not specified, the default logger is used.

  • sym_gen (function, optional) – A function that generates new Symbols for different input shapes. Used only for bucketing.
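As an illustration of work_load_list, the following hypothetical helper turns relative workloads into contiguous per-device slices of a batch. The proportional rounding and the rule that the last device absorbs any rounding remainder are illustrative choices, not necessarily how mxnet computes its split.

```python
# Hypothetical sketch: size each device's share of a batch by its workload.

def split_by_workload(batch_size, work_load_list):
    """Split range(batch_size) into contiguous slices sized by relative workload."""
    total = sum(work_load_list)
    sizes = [round(w * batch_size / total) for w in work_load_list]
    sizes[-1] += batch_size - sum(sizes)   # last device absorbs rounding drift
    slices, begin = [], 0
    for size in sizes:
        if size <= 0:
            raise ValueError("some device would receive an empty slice")
        slices.append(slice(begin, begin + size))
        begin += size
    return slices


# A device with three times the workload gets three times the samples.
print(split_by_workload(8, [3, 1]))    # [slice(0, 6, None), slice(6, 8, None)]
```

Raising on an empty slice reflects the practical constraint that every device in ctx needs at least one sample to run its executor.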

Attributes

aux_arrays

Shared aux states.

grad_arrays

Shared gradient arrays.

param_arrays

Shared parameter arrays.

Methods

backward()

Run backward on the current executor.

copy_to(arg_params, aux_params)

Copy data from each executor to arg_params and aux_params.

forward([is_train])

Run forward on the current executor.

install_monitor(monitor)

Install a monitor on all executors.

load_data_batch(data_batch)

Load data and labels into arrays.

set_params(arg_params, aux_params)

Set parameter and aux values.

update_metric(metric, labels[, pre_sliced])

Update metric with the current executor.

property aux_arrays

Shared aux states.

backward()

Run backward on the current executor.

copy_to(arg_params, aux_params)

Copy data from each executor to arg_params and aux_params.

Parameters
  • arg_params (dict of str to NDArray) – Target parameter arrays, keyed by parameter name.

  • aux_params (dict of str to NDArray) – Target aux arrays, keyed by auxiliary state name.

Notes

  • This function updates the NDArrays in arg_params and aux_params in place.
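A minimal sketch of the in-place semantics of copy_to, under two assumptions: each trainable parameter is held as one copy per device (a list of per-device arrays for each name in param_names), and those copies are averaged into the target. Lists of floats stand in for NDArrays, the targets are keyed by parameter name, and aux_params would be handled the same way. The function is hypothetical, not mxnet's implementation.

```python
# Hypothetical sketch: reduce per-device parameter copies into target arrays.

def copy_to(param_names, param_arrays, arg_params):
    """Average each parameter's per-device copies and write the result
    into arg_params[name] in place."""
    for name, copies in zip(param_names, param_arrays):
        n = len(copies)
        averaged = [sum(vals) / n for vals in zip(*copies)]
        arg_params[name][:] = averaged   # overwrite contents, keep the same object


arg_params = {"fc_weight": [0.0, 0.0]}
param_arrays = [[[1.0, 2.0], [3.0, 4.0]]]   # one parameter, copies on two devices
target = arg_params["fc_weight"]
copy_to(["fc_weight"], param_arrays, arg_params)
print(arg_params["fc_weight"])              # [2.0, 3.0]
print(target is arg_params["fc_weight"])    # True: updated in place
```

Writing into the existing target rather than rebinding it is what lets callers keep references to arg_params and see the new values.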

forward(is_train=False)

Run forward on the current executor.

property grad_arrays

Shared gradient arrays.

install_monitor(monitor)

Install a monitor on all executors.
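To show what installing one monitor on all executors means, the sketch below attaches a single monitor callback to several stand-in executors and samples a statistic every interval batches. Monitor, FakeExecutor, and the root-mean-square statistic are illustrative assumptions. Only the idea of one shared callback per executor mirrors the method above.

```python
import math


class Monitor:
    """Hypothetical monitor: records a statistic for each named array it sees,
    but only on every `interval`-th batch."""

    def __init__(self, interval):
        self.interval = interval
        self.step = 0
        self.current = None
        self.active = False
        self.records = []  # (batch index, array name, statistic)

    def tic(self):
        """Call at the start of each batch to decide whether to collect."""
        self.active = self.step % self.interval == 0
        self.current = self.step
        self.step += 1

    def callback(self, name, values):
        """Invoked by every executor for each intermediate array."""
        if self.active:
            rms = math.sqrt(sum(v * v for v in values) / len(values))
            self.records.append((self.current, name, rms))


class FakeExecutor:
    """Stand-in for a per-device executor that reports its intermediate arrays."""

    def __init__(self, arrays):
        self.arrays = arrays  # name -> list of floats
        self.callbacks = []

    def set_monitor_callback(self, callback):
        self.callbacks.append(callback)

    def forward(self):
        for name, values in self.arrays.items():
            for callback in self.callbacks:
                callback(name, values)


def install_monitor(monitor, executors):
    """Attach the same monitor callback to every executor."""
    for executor in executors:
        executor.set_monitor_callback(monitor.callback)


mon = Monitor(interval=2)
execs = [FakeExecutor({"fc1_output": [3.0, 4.0]}),
         FakeExecutor({"fc1_output": [0.0, 0.0]})]
install_monitor(mon, execs)
for _ in range(3):                      # batches 0, 1, 2
    mon.tic()
    for executor in execs:
        executor.forward()
print([(batch, name) for batch, name, _ in mon.records])
```

Because every device reports through the same callback, one record appears per device for each sampled batch.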

load_data_batch(data_batch)

Load data and labels into arrays.

property param_arrays

Shared parameter arrays.

set_params(arg_params, aux_params)

Set parameter and aux values.

Parameters
  • arg_params (dict of str to NDArray) – Source parameter arrays, keyed by parameter name.

  • aux_params (dict of str to NDArray) – Source aux arrays, keyed by auxiliary state name.
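set_params works in the opposite direction from a gather: one source value per parameter is written into every device's copy. Here is a hypothetical sketch of that broadcast. It assumes each parameter is held as a list of per-device copies and that the sources are keyed by parameter name. Lists of floats stand in for NDArrays, and aux_params would be handled the same way.

```python
# Hypothetical sketch: broadcast source values into every device's copy.

def set_params(param_names, param_arrays, arg_params):
    """Copy each source value into every per-device copy of that parameter."""
    for name, copies in zip(param_names, param_arrays):
        src = arg_params[name]
        for dst in copies:
            dst[:] = src   # copy element values; the device copy stays the same object


param_arrays = [[[0.0, 0.0], [9.0, 9.0]]]   # one parameter on two devices
source = {"w": [1.0, 2.0]}
set_params(["w"], param_arrays, source)
print(param_arrays)                         # [[[1.0, 2.0], [1.0, 2.0]]]
```

Copying values, rather than sharing the source object, keeps each device's copy independent, so later per-device updates do not alias the caller's arrays.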

update_metric(metric, labels, pre_sliced=False)

Update metric with the current executor.
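Taken together, the manager's methods are typically driven once per training batch. The order shown is load data, run a forward pass in training mode, run a backward pass, update the parameters outside the manager, then update the metric. That order is an assumption about typical usage, not a contract stated in this reference. RecordingManager is a hypothetical stand-in that only logs calls, so the sketch runs without mxnet.

```python
# Hypothetical sketch of one data-parallel training step.
# RecordingManager mimics the manager's method names and only records calls.

class RecordingManager:
    def __init__(self):
        self.calls = []

    def load_data_batch(self, data_batch):
        self.calls.append("load_data_batch")

    def forward(self, is_train=False):
        self.calls.append(f"forward(is_train={is_train})")

    def backward(self):
        self.calls.append("backward")

    def update_metric(self, metric, labels, pre_sliced=False):
        self.calls.append("update_metric")


def train_one_batch(manager, data_batch, labels, metric, update_params):
    """One training step in the assumed order."""
    manager.load_data_batch(data_batch)    # copy each device's slice into its arrays
    manager.forward(is_train=True)         # forward pass on every device
    manager.backward()                     # gradients on every device
    update_params()                        # optimizer step, done outside the manager
    manager.update_metric(metric, labels)  # score this batch's outputs


mgr = RecordingManager()
train_one_batch(mgr, data_batch=None, labels=None, metric=None,
                update_params=lambda: mgr.calls.append("update_params"))
print(mgr.calls)
# ['load_data_batch', 'forward(is_train=True)', 'backward', 'update_params', 'update_metric']
```

Updating the metric after the parameter step is harmless here because the outputs it scores were already computed by the forward pass.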