org.apache.mxnet.module

BaseModule


abstract class BaseModule extends AnyRef

The base class of a module. A module represents a computation component. The design purpose of a module is to abstract a computation "machine" on which one can run forward and backward passes, update parameters, and so on. We aim to make the APIs easy to use, especially when we need to use the imperative API to work with multiple modules (e.g. a stochastic depth network).

A module has several states:

- Initial state: memory is not allocated yet; the module is not ready for computation.
- Bound: shapes for inputs, outputs, and parameters are all known, memory is allocated, and the module is ready for computation.
- Parameters initialized: for modules with parameters, doing computation before initializing the parameters might result in undefined outputs.
- Optimizer installed: an optimizer can be installed to a module, after which the parameters of the module can be updated according to the optimizer once gradients are computed (forward-backward).

In order for a module to interact with others, it should be able to report the following information in its raw stage (before being bound):

- dataNames: the names of the required input data.
- outputNames: the names of the outputs it produces.

And also the following richer information after being bound:

- state information: whether it is bound, bound for training, parameters initialized, and optimizer installed.
- input/output information: dataShapes, labelShapes, and outputShapes.
- parameters (for modules with parameters): getParams, setParams, and initParams.
- setup: bind and initOptimizer.
- computation: forward, backward, update, getOutputs, getInputGrads, and updateMetric.

When those intermediate-level APIs are implemented properly, the following high-level APIs become automatically available for a module: fit (train the module parameters on a data set), predict (run prediction on a data set and collect outputs), and score (run prediction on a data set and evaluate performance). A lifecycle sketch follows.
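
As a hedged illustration (a minimal sketch, not part of the official docs), the code below walks a module through the states listed above. It assumes a concrete BaseModule instance mod and a trainIter: DataIter constructed elsewhere; per the parameter notes on bind, the shapes typically come from the iterator.

    import org.apache.mxnet.DataIter
    import org.apache.mxnet.module.BaseModule
    import org.apache.mxnet.optimizer.SGD

    // `mod` and `trainIter` are assumed to be constructed elsewhere.
    def setUp(mod: BaseModule, trainIter: DataIter): Unit = {
      // Initial -> Bound: shapes become known, memory is allocated.
      mod.bind(dataShapes = trainIter.provideData,
               labelShapes = Some(trainIter.provideLabel),
               forTraining = true)
      // Bound -> Parameters initialized.
      mod.initParams()
      // Optimizer installed: update() can now apply gradients.
      mod.initOptimizer(optimizer = new SGD())
    }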

Linear Supertypes
AnyRef, Any

Instance Constructors

  1. new BaseModule()

Abstract Value Members

  1. abstract def backward(outGrads: Array[NDArray] = null): Unit

    Backward computation.

    outGrads

    Gradient on the outputs to be propagated back. This parameter is only needed when bind is called on outputs that are not a loss function.

  2. abstract def bind(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]] = None, forTraining: Boolean = true, inputsNeedGrad: Boolean = false, forceRebind: Boolean = false, sharedModule: Option[BaseModule] = None, gradReq: String = "write"): Unit

    Bind the symbols to construct executors. This is necessary before one can perform computation with the module.

    dataShapes

    Typically DataIter.provideData.

    labelShapes

    Typically DataIter.provideLabel.

    forTraining

    Default is true. Whether the executors should be bound for training.

    inputsNeedGrad

    Default is false. Whether the gradients to the input data need to be computed. Typically this is not needed, but it might be needed when implementing composition of modules.

    forceRebind

    Default is false. This function does nothing if the executors are already bound, but if this is true, the executors will be forced to rebind.

    sharedModule

    Default is None. This is used in bucketing. When not None, the shared module essentially corresponds to a different bucket -- a module with different symbol but with the same sets of parameters (e.g. unrolled RNNs with different lengths).

    gradReq

    Requirement for gradient accumulation (globally). Can be 'write', 'add', or 'null' (defaults to 'write').

  3. abstract def dataNames: IndexedSeq[String]

  4. abstract def dataShapes: IndexedSeq[DataDesc]

  5. abstract def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit

    Forward computation.

    dataBatch

    A batch of data. Could be anything with a similar API implemented.

    isTrain

    Default is None, which means isTrain takes the value of this.forTraining.
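
    As a hedged sketch: assuming mod is already bound, with parameters and an optimizer initialized, one manual pass over a trainIter: DataIter using the intermediate-level API documented on this page (forward, backward, update, updateMetric) could look like:

    import org.apache.mxnet.{DataIter, EvalMetric}
    import org.apache.mxnet.module.BaseModule

    def runEpoch(mod: BaseModule, trainIter: DataIter, metric: EvalMetric): Unit = {
      trainIter.reset()
      while (trainIter.hasNext) {
        val batch = trainIter.next()
        mod.forward(batch, isTrain = true)    // forward pass
        mod.backward()                        // compute gradients
        mod.update()                          // apply the installed optimizer
        mod.updateMetric(metric, batch.label) // accumulate the metric
      }
    }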

  6. abstract def getInputGrads(): IndexedSeq[IndexedSeq[NDArray]]

    Get the gradients to the inputs, computed in the previous backward computation.

    returns

    In the case when data-parallelism is used, the gradients will be collected from multiple devices. The results will look like [ [grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2] ]; those NDArrays might live on different devices.

  7. abstract def getInputGradsMerged(): IndexedSeq[NDArray]

    Get the gradients to the inputs, computed in the previous backward computation.

    returns

    In the case when data-parallelism is used, the gradients will be merged from multiple devices, as if they came from a single executor. The results will look like [grad1, grad2].

  8. abstract def getOutputs(): IndexedSeq[IndexedSeq[NDArray]]

    Get outputs of the previous forward computation.

    returns

    In the case when data-parallelism is used, the outputs will be collected from multiple devices. The results will look like [ [out1_dev1, out1_dev2], [out2_dev1, out2_dev2] ]; those NDArrays might live on different devices.

  9. abstract def getOutputsMerged(): IndexedSeq[NDArray]

    Get outputs of the previous forward computation.

    returns

    In the case when data-parallelism is used, the outputs will be merged from multiple devices, as if they came from a single executor. The results will look like [out1, out2].

  10. abstract def getParams: (Map[String, NDArray], Map[String, NDArray])

    Get parameters; these are potentially copies of the actual parameters used to do computation on the device.

    returns

    (argParams, auxParams), a pair of maps from parameter names to values.
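
    For instance, a minimal sketch (assuming a bound, parameter-initialized module mod) that inspects the returned pair:

    val (argParams, auxParams) = mod.getParams
    argParams.foreach { case (name, arr) =>
      println(s"$name: ${arr.shape}") // each value is an NDArray
    }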

  11. abstract def initOptimizer(kvstore: String = "local", optimizer: Optimizer = new SGD(), resetOptimizer: Boolean = true, forceInit: Boolean = false): Unit

    Install and initialize the optimizer.

  12. abstract def initParams(initializer: Initializer = new Uniform(0.01f), argParams: Map[String, NDArray] = null, auxParams: Map[String, NDArray] = null, allowMissing: Boolean = false, forceInit: Boolean = false, allowExtra: Boolean = false): Unit

    Initialize the parameters and auxiliary states.

    initializer

    Called to initialize parameters if needed.

    argParams

    If not null, should be a Map of existing argParams; initialization will be copied from it.

    auxParams

    If not null, should be a Map of existing auxParams; initialization will be copied from it.

    allowMissing

    If true, params could contain missing values, and the initializer will be called to fill those missing params.

    forceInit

    If true, will force re-initialization even if parameters have already been initialized.

    allowExtra

    Whether to allow extra parameters that are not needed by the symbol. If true, no error will be thrown when argParams or auxParams contain extra parameters that are not needed by the executor.

  13. abstract def installMonitor(monitor: Monitor): Unit

  14. abstract def labelShapes: IndexedSeq[DataDesc]

    A list of (name, shape) pairs specifying the label inputs to this module. If this module does not accept labels -- either it is a module without a loss function, or it is not bound for training -- then this should return an empty list.

  15. abstract def outputNames: IndexedSeq[String]

  16. abstract def outputShapes: IndexedSeq[(String, Shape)]

  17. abstract def update(): Unit

    Update the parameters according to the installed optimizer and the gradients computed in the previous forward-backward batch.

  18. abstract def updateMetric(evalMetric: EvalMetric, labels: IndexedSeq[NDArray]): Unit

    Evaluate and accumulate evaluation metric on outputs of the last forward computation.

    evalMetric
    labels

    Typically DataBatch.label.

Concrete Value Members

  1. final def !=(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  2. final def ##(): Int

    Definition Classes
    AnyRef → Any
  3. final def ==(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  4. final def asInstanceOf[T0]: T0

    Definition Classes
    Any
  5. def bind(forTraining: Boolean, inputsNeedGrad: Boolean, forceRebind: Boolean, dataShape: DataDesc*): Unit

    Bind the symbols to construct executors. This is necessary before one can perform computation with the module.

    forTraining

    Default is true. Whether the executors should be bound for training.

    inputsNeedGrad

    Default is false. Whether the gradients to the input data need to be computed. Typically this is not needed, but it might be needed when implementing composition of modules.

    forceRebind

    Default is false. This function does nothing if the executors are already bound, but if this is true, the executors will be forced to rebind.

    dataShape

    Typically DataIter.provideData.

    Annotations
    @varargs()
  6. def clone(): AnyRef

    Attributes
    protected[java.lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  7. final def eq(arg0: AnyRef): Boolean

    Definition Classes
    AnyRef
  8. def equals(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  9. def finalize(): Unit

    Attributes
    protected[java.lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( classOf[java.lang.Throwable] )
  10. def fit(trainData: DataIter, evalData: Option[DataIter] = None, numEpoch: Int = 1, fitParams: FitParams = new FitParams): Unit

    Train the module parameters.

    trainData
    evalData

    If not None, it will be used as a validation set, and performance will be evaluated after each epoch.

    numEpoch

    Number of epochs to run training.

    fitParams

    Extra parameters for training.
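
    A hedged usage sketch, assuming mod plus training and validation iterators trainIter and valIter exist (fitParams is left at its defaults):

    // fit binds the module and initializes parameters and the
    // optimizer internally when that has not been done yet.
    mod.fit(trainData = trainIter,
            evalData = Some(valIter),
            numEpoch = 10)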

  11. def forward(dataBatch: DataBatch, isTrain: Boolean): Unit

    Forward computation.

    dataBatch

    A batch of data.

    isTrain

    Whether it is for training or not.

  12. def forwardBackward(dataBatch: DataBatch): Unit

  13. final def getClass(): Class[_]

    Definition Classes
    AnyRef → Any
  14. def getSymbol: Symbol

  15. def hashCode(): Int

    Definition Classes
    AnyRef → Any
  16. final def isInstanceOf[T0]: Boolean

    Definition Classes
    Any
  17. def loadParams(fname: String): Unit

    Load model parameters from file.

    fname

    Path to input param file.

    Annotations
    @throws( classOf[IOException] )
    Exceptions thrown

    IOException if param file is invalid

  18. final def ne(arg0: AnyRef): Boolean

    Definition Classes
    AnyRef
  19. final def notify(): Unit

    Definition Classes
    AnyRef
  20. final def notifyAll(): Unit

    Definition Classes
    AnyRef
  21. def predict(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true): IndexedSeq[NDArray]

    Run prediction and collect the outputs.

    evalData

    The DataIter on which to run inference.

    numBatch

    Default is -1, indicating running all the batches in the data iterator.

    reset

    Default is true, indicating whether we should reset the data iterator before starting prediction.

    returns

    The return value will be a list [out1, out2, out3]. The concatenation process will look like:

    outputBatches = [
      [a1, a2, a3], // batch a
      [b1, b2, b3]  // batch b
    ]
    result = [
      NDArray, // [a1, b1]
      NDArray, // [a2, b2]
      NDArray, // [a3, b3]
    ]

    Each element is the concatenation of the outputs for all the mini-batches.
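
    A hedged sketch, assuming a bound, initialized module mod, an iterator evalIter, and org.apache.mxnet.NDArray in scope:

    // Run over the entire iterator (numBatch = -1) and collect merged outputs.
    val outputs: IndexedSeq[NDArray] = mod.predict(evalIter, numBatch = -1)
    outputs.foreach(out => println(out.shape))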

  22. def predict(batch: DataBatch): IndexedSeq[NDArray]

  23. def predictEveryBatch(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true): IndexedSeq[IndexedSeq[NDArray]]

    Run prediction and collect the outputs.

    evalData
    numBatch

    Default is -1, indicating running all the batches in the data iterator.

    reset

    Default is true, indicating whether we should reset the data iterator before starting prediction.

    returns

    The return value will be a nested list like [ [out1_batch1, out2_batch1, ...], [out1_batch2, out2_batch2, ...] ] This mode is useful because in some cases (e.g. bucketing), the module does not necessarily produce the same number of outputs.

  24. def saveParams(fname: String): Unit

    Save model parameters to file.

    fname

    Path to output param file.
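
    Together with loadParams above, this supports a simple checkpoint round trip; a hedged sketch (the file name is illustrative, mod assumed to exist):

    mod.saveParams("net-epoch-1.params") // writes argParams and auxParams
    // ... later, on a module bound with the same symbol:
    mod.loadParams("net-epoch-1.params")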

  25. def score(evalData: DataIter, evalMetric: EvalMetric, numBatch: Int = Integer.MAX_VALUE, batchEndCallback: Option[BatchEndCallback] = None, scoreEndCallback: Option[BatchEndCallback] = None, reset: Boolean = true, epoch: Int = 0): EvalMetric

    Run prediction on evalData and evaluate the performance according to evalMetric.

    evalData

    The DataIter on which to run the evaluation.

    evalMetric

    The EvalMetric used to accumulate the evaluation score.

    numBatch

    Number of batches to run. Default is Integer.MAX_VALUE, indicating run until the DataIter finishes.

    batchEndCallback

    Callback to invoke at the end of each evaluated batch.

    reset

    Default is true, indicating whether we should reset evalData before starting evaluation.

    epoch

    Default 0. For compatibility, this will be passed to callbacks (if any). During training, this will correspond to the training epoch number.
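
    A hedged sketch, assuming mod and a validation iterator valIter, using Accuracy (one of the EvalMetric implementations in org.apache.mxnet):

    import org.apache.mxnet.Accuracy

    val metric = mod.score(valIter, new Accuracy())
    val (names, values) = metric.get // accumulated metric names and values
    names.zip(values).foreach { case (n, v) => println(s"$n: $v") }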

  26. def setParams(argParams: Map[String, NDArray], auxParams: Map[String, NDArray], allowMissing: Boolean = false, forceInit: Boolean = true, allowExtra: Boolean = false): Unit

    Assign parameter and aux state values.

    argParams

    Map of names to values (NDArray) for the parameters.

    auxParams

    Map of names to values (NDArray) for the auxiliary states.

    allowMissing

    If true, params could contain missing values, and the initializer will be called to fill those missing params.

    forceInit

    If true, will force re-initialization even if parameters have already been initialized.

    allowExtra

    Whether to allow extra parameters that are not needed by the symbol. If true, no error will be thrown when argParams or auxParams contain extra parameters that are not needed by the executor.
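
    A hedged sketch of a common use: copying parameters from a training module into a separately bound inference module. The names trainMod and inferMod are illustrative; both are assumed to share the same symbol.

    val (argParams, auxParams) = trainMod.getParams
    inferMod.setParams(argParams, auxParams)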

  27. final def synchronized[T0](arg0: ⇒ T0): T0

    Definition Classes
    AnyRef
  28. def toString(): String

    Definition Classes
    AnyRef → Any
  29. final def wait(): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  30. final def wait(arg0: Long, arg1: Int): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  31. final def wait(arg0: Long): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
