RNN Cell API
Warning
This package is currently experimental and may change in the near future.
Overview
The rnn module includes the recurrent neural network (RNN) cell APIs, a suite of tools for building an RNN’s symbolic graph.
Note
The rnn module offers a higher-level interface, while symbol.RNN is a lower-level interface. In most cases, the cell APIs in the rnn module are easier to use.
The rnn module
Cell interfaces
BaseRNNCell.__call__ | Unroll the RNN for one time step.
BaseRNNCell.unroll | Unroll an RNN cell across time steps.
BaseRNNCell.reset | Reset before re-using the cell for another graph.
BaseRNNCell.begin_state | Initial state for this cell.
BaseRNNCell.unpack_weights | Unpack fused weight matrices into separate weight matrices.
BaseRNNCell.pack_weights | Pack separate weight matrices into a single packed weight.
When working with the cell API, the precise input and output symbols depend on the type of RNN you are using. Take Long Short-Term Memory (LSTM) for example:
import mxnet as mx
# Shape of 'step_data' is (batch_size,).
step_input = mx.symbol.Variable('step_data')
# First we embed our raw input data to be used as LSTM's input.
embedded_step = mx.symbol.Embedding(data=step_input,
                                    input_dim=input_dim,
                                    output_dim=embed_dim)
# Then we create an LSTM cell.
lstm_cell = mx.rnn.LSTMCell(num_hidden=50)
# Initialize its hidden and memory states.
# 'begin_state' method takes an initialization function, and uses 'zeros' by default.
begin_state = lstm_cell.begin_state()
The LSTM cell and other non-fused RNN cells are callable. Calling the cell updates its state once. This transformation depends on both the current input and the previous states. See this blog post for a great introduction to LSTMs and other RNNs.
# Call the cell to get the output of one time step for a batch.
output, states = lstm_cell(embedded_step, begin_state)
# 'output' is lstm_t0_out_output of shape (batch_size, hidden_dim).
# 'states' has the recurrent states that will be carried over to the next step,
# which includes both the "hidden state" and the "cell state":
# Both 'lstm_t0_out_output' and 'lstm_t0_state_output' have shape (batch_size, hidden_dim).
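To make the single-step update concrete, here is a minimal pure-Python sketch of one LSTM step for a single hidden unit. It is not the library's implementation: `lstm_step`, the scalar weights in `w`, and the gate names are hypothetical. They are chosen only to show how the output and the two returned states depend on the current input and the previous states.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def lstm_step(x, states, w):
    """One LSTM step for a single hidden unit (toy scalar weights).

    `states` is [h_prev, c_prev]; `w` maps a gate name to (w_x, w_h, bias).
    All names here are illustrative and not part of the rnn module.
    """
    h_prev, c_prev = states

    def pre(gate):
        # Every gate sees the current input and the previous hidden state.
        w_x, w_h, b = w[gate]
        return w_x * x + w_h * h_prev + b

    i = sigmoid(pre('i'))      # input gate
    f = sigmoid(pre('f'))      # forget gate
    g = math.tanh(pre('g'))    # candidate cell value
    o = sigmoid(pre('o'))      # output gate
    c = f * c_prev + i * g     # new cell state
    h = o * math.tanh(c)       # new hidden state, also the step output
    return h, [h, c]


weights = {gate: (0.5, 0.5, 0.0) for gate in 'ifgo'}
output, states = lstm_step(1.0, [0.0, 0.0], weights)
# The step output and the first returned state are the same value.
print(output == states[0], len(states))
```

The returned `states` list can be fed directly into the next call, which is how the recurrent states are carried from one step to the next.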
Most of the time our goal is to process a sequence of many steps. For this, we need to unroll the LSTM according to the sequence length.
# Embed a sequence. 'seq_data' has the shape of (batch_size, sequence_length).
seq_input = mx.symbol.Variable('seq_data')
embedded_seq = mx.symbol.Embedding(data=seq_input,
                                   input_dim=input_dim,
                                   output_dim=embed_dim)
Note
Remember to reset the cell when unrolling/stepping for a new sequence by calling lstm_cell.reset().
# Note that when unrolling, if 'merge_outputs' is set to True, the 'outputs' is merged into a single symbol
# In the layout, 'N' represents batch size, 'T' represents sequence length, and 'C' represents the
# number of dimensions in hidden states.
outputs, states = lstm_cell.unroll(length=sequence_length,
                                   inputs=embedded_seq,
                                   layout='NTC',
                                   merge_outputs=True)
# 'outputs' is concat0_output of shape (batch_size, sequence_length, hidden_dim).
# The hidden state and cell state from the final time step are returned:
# Both 'lstm_t4_out_output' and 'lstm_t4_state_output' have shape (batch_size, hidden_dim).
# If merge_outputs is set to False, a list of symbols for each of the time steps is returned.
outputs, states = lstm_cell.unroll(length=sequence_length,
                                   inputs=embedded_seq,
                                   layout='NTC',
                                   merge_outputs=False)
# In this case, 'outputs' is a list of symbols. Each symbol is of shape (batch_size, hidden_dim).
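The relationship between stepping and unrolling can be illustrated without MXNet. The sketch below is a toy model of the idea, not the library's code: `unroll` and `running_sum_step` are hypothetical helpers, inputs are plain nested lists indexed as [batch][time], and "merging" simply regroups the per-step outputs into batch-major order.

```python
def unroll(step, length, inputs, begin_state, merge_outputs=False):
    """Call `step` once per time step, carrying the states forward.

    `inputs` is a nested list indexed as [batch][time]; `step` is any
    function (x_t, states) -> (output_t, new_states).
    """
    states = begin_state
    outputs = []
    for t in range(length):
        x_t = [row[t] for row in inputs]   # slice out step t for the whole batch
        out_t, states = step(x_t, states)
        outputs.append(out_t)
    if merge_outputs:
        # Regroup the per-step outputs into one batch-major structure.
        outputs = [list(col) for col in zip(*outputs)]
    return outputs, states


def running_sum_step(x_t, states):
    # Toy recurrent step: the state is a running sum per batch element.
    new_states = [s + x for s, x in zip(states, x_t)]
    return new_states, new_states


batch = [[1, 2, 3], [10, 20, 30]]
print(unroll(running_sum_step, 3, batch, [0, 0]))
print(unroll(running_sum_step, 3, batch, [0, 0], merge_outputs=True))
```

With `merge_outputs=False` the result is one entry per time step; with `merge_outputs=True` it is a single structure indexed by batch first and time second. In both cases only the final states are returned.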
Note
Loading and saving models built with the RNN cell API requires using mx.rnn.load_rnn_checkpoint, mx.rnn.save_rnn_checkpoint, and mx.rnn.do_rnn_checkpoint. The list of all cells used should be provided as the first argument to those functions.
Basic RNN cells
The rnn module supports the following RNN cell types.
LSTMCell | Long Short-Term Memory (LSTM) network cell.
GRUCell | Gated Recurrent Unit (GRU) network cell.
RNNCell | Simple recurrent neural network cell.
Modifier cells
BidirectionalCell | Bidirectional RNN cell.
DropoutCell | Apply dropout on input.
ZoneoutCell | Apply Zoneout on base cell.
ResidualCell | Adds residual connection as described in Wu et al, 2016 (https://arxiv.org/abs/1609.08144).
A modifier cell takes in one or more cells and transforms the output of those cells. BidirectionalCell is one example. It takes two cells for forward unroll and backward unroll respectively. After unrolling, the outputs of the forward and backward pass are concatenated.
# Bidirectional cell takes two RNN cells, for forward and backward pass respectively.
# Having different types of cells for forward and backward unrolling is allowed.
bi_cell = mx.rnn.BidirectionalCell(
    mx.rnn.LSTMCell(num_hidden=50),
    mx.rnn.GRUCell(num_hidden=75))
outputs, states = bi_cell.unroll(length=sequence_length,
                                 inputs=embedded_seq,
                                 merge_outputs=True)
# The output feature is the concatenation of the forward and backward pass.
# Thus, the number of output dimensions is the sum of the dimensions of the two cells.
# 'outputs' is the symbol 'bi_out_output' of shape (batch_size, sequence_length, 125).
# The states of the BidirectionalCell are a list of two lists, corresponding to the
# states of the forward and backward cells respectively.
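The concatenation of the two passes can be sketched in plain Python. This is only an illustration under simplifying assumptions: `run`, `bidirectional_unroll`, and `make_step` are hypothetical helpers, each feature vector is a Python list, and the toy step emits a running sum repeated `width` times so that the output sizes of the two directions are easy to tell apart.

```python
def run(step, inputs, states):
    """Apply `step` over `inputs` in order, carrying `states` forward."""
    outputs = []
    for x in inputs:
        out, states = step(x, states)
        outputs.append(out)
    return outputs, states


def bidirectional_unroll(l_step, r_step, inputs, l_states, r_states):
    fwd_out, l_states = run(l_step, inputs, l_states)
    # The backward cell reads the sequence from the last step to the first.
    bwd_out, r_states = run(r_step, inputs[::-1], r_states)
    bwd_out.reverse()  # realign backward outputs with the original time order
    # Concatenate the feature vectors of both directions at every time step.
    outputs = [f_vec + b_vec for f_vec, b_vec in zip(fwd_out, bwd_out)]
    return outputs, [l_states, r_states]


def make_step(width):
    # Toy cell: the state is a running sum, the output repeats it `width` times.
    def step(x, states):
        total = states[0] + x
        return [total] * width, [total]
    return step


outputs, states = bidirectional_unroll(make_step(2), make_step(3), [1, 2, 3], [0], [0])
print(outputs)  # every step has 2 + 3 = 5 features
print(states)   # one state list per direction
```

Because the backward pass starts at the last element, the whole sequence must be available before any output can be produced.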
Note
BidirectionalCell cannot be called or stepped, because the backward unroll requires the output of future steps, and thus the whole sequence is required.
Dropout and zoneout are popular regularization techniques that can be applied to RNNs. The rnn module provides DropoutCell and ZoneoutCell for regularization on the output and recurrent states of an RNN. ZoneoutCell takes one RNN cell in the constructor, and supports unrolling like other cells.
zoneout_cell = mx.rnn.ZoneoutCell(lstm_cell, zoneout_states=0.5)
outputs, states = zoneout_cell.unroll(length=sequence_length,
                                      inputs=embedded_seq,
                                      merge_outputs=True)
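As a rough illustration of the zoneout idea during training, the sketch below keeps each element of the previous state with probability `p` and takes the newly computed element otherwise. The `zoneout` helper and its arguments are hypothetical and stand in for what ZoneoutCell does on symbols; this is not the library's implementation.

```python
import random


def zoneout(prev, new, p, rng=random):
    """Keep each previous value with probability `p`, else take the new one."""
    return [old if rng.random() < p else cur for old, cur in zip(prev, new)]


prev_state = [0.1, 0.2, 0.3]
new_state = [0.9, 0.8, 0.7]
print(zoneout(prev_state, new_state, 0.0))  # p = 0: always the new state
print(zoneout(prev_state, new_state, 1.0))  # p = 1: always the previous state
```

Unlike dropout, which zeroes values, zoneout preserves the previous value of the units it selects.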
DropoutCell performs dropout on the input sequence. It can be used in a stacked multi-layer RNN setting, which we will cover next.
Residual connections are a useful technique for training deep neural models because they help the propagation of gradients by shortening the paths. ResidualCell provides this functionality for RNN models.
residual_cell = mx.rnn.ResidualCell(lstm_cell)
outputs, states = residual_cell.unroll(length=sequence_length,
                                       inputs=embedded_seq,
                                       merge_outputs=True)
The outputs are the element-wise sum of the input and the output of the LSTM cell.
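A small sketch of the residual sum, using plain lists in place of symbols. The helpers `residual_step` and `halve` are hypothetical. Since the sum is element-wise, the sketch assumes the base cell's output has the same size as its input and raises an error otherwise.

```python
def residual_step(step, x, states):
    """Run the base step, then add the input to its output element-wise."""
    out, states = step(x, states)
    if len(out) != len(x):
        raise ValueError('residual connection needs equal input and output sizes')
    return [o + i for o, i in zip(out, x)], states


def halve(x, states):
    # Toy stateless base cell: halves every input feature.
    return [v / 2.0 for v in x], states


print(residual_step(halve, [2.0, 4.0], []))
```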
Multi-layer cells
SequentialRNNCell | Sequentially stacking multiple RNN cells.
SequentialRNNCell.add | Append a cell into the stack.
The SequentialRNNCell allows stacking multiple layers of RNN cells to improve the expressiveness and performance of the model. Cells can be added to a SequentialRNNCell in order, from bottom to top. When unrolling, the output of a lower-level cell is automatically passed to the cell above.
stacked_rnn_cells = mx.rnn.SequentialRNNCell()
stacked_rnn_cells.add(mx.rnn.BidirectionalCell(
    mx.rnn.LSTMCell(num_hidden=50),
    mx.rnn.LSTMCell(num_hidden=50)))
# Apply dropout to the output of the bottom-layer BidirectionalCell with a drop probability of 0.5.
stacked_rnn_cells.add(mx.rnn.DropoutCell(0.5))
stacked_rnn_cells.add(mx.rnn.LSTMCell(num_hidden=50))
outputs, states = stacked_rnn_cells.unroll(length=sequence_length,
                                           inputs=embedded_seq,
                                           merge_outputs=True)
# The output of SequentialRNNCell is the same as that of the last layer.
# In this case 'outputs' is the symbol 'concat6_output' of shape (batch_size, sequence_length, hidden_dim)
# The states of the SequentialRNNCell are a list of lists, with each list
# corresponding to the states of each of the added cells respectively.
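The way a stack passes data upward and keeps one state list per layer can be sketched as follows. This is a conceptual toy, not the library's code: `stacked_step`, `accumulate`, and `double` are hypothetical, and scalars stand in for symbols.

```python
def stacked_step(steps, x, states):
    """Feed `x` through every layer in order, bottom to top.

    `states` holds one state list per layer; the result has the same shape.
    """
    new_states = []
    for step, layer_states in zip(steps, states):
        x, layer_states = step(x, layer_states)  # output becomes the next input
        new_states.append(layer_states)
    return x, new_states


def accumulate(x, states):
    # Toy recurrent layer: adds the input to a running total.
    total = states[0] + x
    return total, [total]


def double(x, states):
    # Toy stateless layer: it has an empty state list.
    return 2 * x, states


output, states = stacked_step([accumulate, double, accumulate], 3, [[1], [], [10]])
print(output, states)
```

The final output comes from the top layer only, while the returned states keep the per-layer grouping.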
Fused RNN cell
FusedRNNCell | Fusing RNN layers across time step into one kernel.
FusedRNNCell.unfuse | Unfuse the fused RNN into a stack of RNN cells.
The computation of an RNN for an input sequence consists of many GEMM and point-wise operations with temporal dependencies. This can make the computation memory-bound, especially on GPU, resulting in longer wall-clock time. By combining the computation of many small matrices into that of larger ones and streaming the computation whenever possible, the ratio of computation to memory I/O can be increased, which results in better performance on GPU. This optimization technique is called “fusing”. This post discusses it in greater detail.
The rnn module includes a FusedRNNCell, which provides the optimized fused implementation.
The FusedRNNCell supports bidirectional RNNs and dropout.
fused_lstm_cell = mx.rnn.FusedRNNCell(num_hidden=50,
                                      num_layers=3,
                                      mode='lstm',
                                      bidirectional=True,
                                      dropout=0.5)
outputs, _ = fused_lstm_cell.unroll(length=sequence_length,
                                    inputs=embedded_seq,
                                    merge_outputs=True)
# The 'outputs' is the symbol 'lstm_rnn_output' that has the shape
# (batch_size, sequence_length, forward_backward_concat_dim)
Note
FusedRNNCell is supported on GPU only. It cannot be called or stepped.
Note
When dropout is set to non-zero in FusedRNNCell, the dropout is applied to the output of all layers except the last layer. If there is only one layer in the FusedRNNCell, the dropout rate is ignored.
Note
Similar to BidirectionalCell, when the bidirectional flag is set to True, the output of FusedRNNCell is twice the size specified by num_hidden.
When training a deep, complex model on multiple GPUs, it’s recommended to stack fused RNN cells (one layer per cell) instead of using one cell with all layers. The reason is that fused RNN cells don’t mark gradients as ready until the computation for the entire layer is completed. Breaking a multi-layer fused RNN cell into several one-layer cells allows gradients to be processed earlier. This reduces communication overhead, especially with multiple GPUs.
The unfuse() method can be used to convert the FusedRNNCell into an equivalent and CPU-compatible SequentialRNNCell that mirrors the settings of the FusedRNNCell.
unfused_lstm_cell = fused_lstm_cell.unfuse()
unfused_outputs, _ = unfused_lstm_cell.unroll(length=sequence_length,
                                              inputs=embedded_seq,
                                              merge_outputs=True)
# 'unfused_outputs' is the symbol 'lstm_bi_l2_out_output' that has the shape
# (batch_size, sequence_length, forward_backward_concat_dim)
RNN checkpoint methods and parameters
save_rnn_checkpoint | Save checkpoint for model using RNN cells.
load_rnn_checkpoint | Load model checkpoint from file.
do_rnn_checkpoint | Make a callback to checkpoint Module to prefix every epoch.
RNNParams | Container for holding variables.
RNNParams.get | Get the variable given a name if one exists or create a new one if missing.
The model parameters from training with a fused cell can be used for inference with an unfused cell, and vice versa. As the parameters of fused and unfused cells are organized differently, they need to be converted first. FusedRNNCell’s parameters are merged and flattened. In the fused example above, the model has lstm_parameters of shape (total_num_params,), whereas the equivalent SequentialRNNCell’s parameters are separate:
'lstm_l0_i2h_weight': (out_dim, embed_dim)
'lstm_l0_i2h_bias': (out_dim,)
'lstm_l0_h2h_weight': (out_dim, hidden_dim)
'lstm_l0_h2h_bias': (out_dim,)
'lstm_r0_i2h_weight': (out_dim, embed_dim)
...
All cells in the rnn module support the method unpack_weights() for converting FusedRNNCell parameters to the unfused format and pack_weights() for fusing the parameters. The RNN-specific checkpointing methods (load_rnn_checkpoint, save_rnn_checkpoint, do_rnn_checkpoint) handle the conversion transparently based on the provided cells.
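The idea of packing and unpacking can be illustrated with a toy round trip between a dictionary of named matrices and one flat vector. This is only a sketch: `pack`, `unpack`, the parameter names, and the row-major ordering are assumptions made for illustration. The actual memory layout of a fused cell's parameter vector is determined by the library and is not reproduced here.

```python
def pack(weights, order):
    """Flatten named 2-D weights (lists of rows) into one vector."""
    flat = []
    for name in order:
        for row in weights[name]:
            flat.extend(row)
    return flat


def unpack(flat, shapes, order):
    """Rebuild the named 2-D weights from the flat vector."""
    weights, pos = {}, 0
    for name in order:
        rows, cols = shapes[name]
        weights[name] = [flat[pos + r * cols: pos + (r + 1) * cols]
                         for r in range(rows)]
        pos += rows * cols
    return weights


order = ['i2h_weight', 'h2h_weight', 'i2h_bias']
weights = {
    'i2h_weight': [[1, 2, 3], [4, 5, 6]],
    'h2h_weight': [[7, 8], [9, 10]],
    'i2h_bias': [[11, 12]],
}
shapes = {name: (len(w), len(w[0])) for name, w in weights.items()}
packed = pack(weights, order)
print(packed)
print(unpack(packed, shapes, order) == weights)
```

Both directions need the same ordering and shapes, which is why the checkpoint helpers take the list of cells as an argument.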
I/O utilities
BucketSentenceIter | Simple bucketing iterator for language model.
encode_sentences | Encode sentences and (optionally) build a mapping from string tokens to integer indices.
API Reference
class mxnet.rnn.BaseRNNCell(prefix='', params=None)
Abstract base class for RNN cells.
Parameters:
- prefix (str, optional) – Prefix for names of layers. This prefix is also used for names of weights if params is None, i.e. if params are being created and not reused.
- params (RNNParams, default None) – Container for weight sharing between cells. A new RNNParams container is created if params is None.
__call__(inputs, states)
Unroll the RNN for one time step.
Parameters:
- inputs (sym.Variable) – Input symbol, 2D, batch * num_units.
- states (list of sym.Variable) – RNN state from the previous step or the output of begin_state().
Returns:
- output (Symbol) – Symbol corresponding to the output from the RNN when unrolling for a single time step.
- states (nested list of Symbol) – The new state of this RNN after this unrolling. The type of this symbol is the same as the output of begin_state(). This can be used as the input state to the next time step of this RNN.
See also:
- begin_state() – This function can provide the states for the first time step.
- unroll() – This function unrolls an RNN for a given number of (>=1) time steps.
params – Parameters of this cell.
state_info – Shape and layout information of states.
state_shape – Shape(s) of states.
begin_state(func=symbol.zeros, **kwargs)
Initial state for this cell.
Parameters:
- func (callable, default symbol.zeros) – Function for creating the initial state. Can be symbol.zeros, symbol.uniform, symbol.Variable, etc. Use symbol.Variable if you want to directly feed input as states.
- **kwargs – More keyword arguments passed to func, for example mean, std, dtype, etc.
Returns: states – Starting states for the first RNN step.
Return type: nested list of Symbol
unpack_weights(args)
Unpack fused weight matrices into separate weight matrices.
For example, say you use a module object mod to run a network that has an LSTM cell. In mod.get_params()[0], the LSTM parameters are all represented as a single big vector. cell.unpack_weights(mod.get_params()[0]) will unpack this vector into a dictionary of more readable LSTM parameters: c, f, i, o gates for i2h (input to hidden) and h2h (hidden to hidden) weights.
Parameters: args (dict of str -> NDArray) – Dictionary containing packed weights, usually from Module.get_params()[0].
Returns: args – Dictionary with unpacked weights associated with this cell.
Return type: dict of str -> NDArray
See also:
- pack_weights() – Performs the reverse operation of this function.
pack_weights(args)
Pack separate weight matrices into a single packed weight.
Parameters: args (dict of str -> NDArray) – Dictionary containing unpacked weights.
Returns: args – Dictionary with packed weights associated with this cell.
Return type: dict of str -> NDArray
unroll(length, inputs, begin_state=None, layout='NTC', merge_outputs=None)
Unroll an RNN cell across time steps.
Parameters:
- length (int) – Number of steps to unroll.
- inputs (Symbol, list of Symbol, or None) – If inputs is a single Symbol (usually the output of an Embedding symbol), it should have shape (batch_size, length, ...) if layout == 'NTC', or (length, batch_size, ...) if layout == 'TNC'. If inputs is a list of symbols (usually the output of a previous unroll), they should all have shape (batch_size, ...).
- begin_state (nested list of Symbol, default None) – Input states created by begin_state() or the output state of another cell. Created from begin_state() if None.
- layout (str, optional) – Layout of the input symbol. Only used if inputs is a single Symbol.
- merge_outputs (bool, optional) – If False, return outputs as a list of Symbols. If True, concatenate output across time steps and return a single symbol with shape (batch_size, length, ...) if layout == 'NTC', or (length, batch_size, ...) if layout == 'TNC'. If None, output whatever is faster.
Returns:
- outputs (list of Symbol or Symbol) – Symbol (if merge_outputs is True) or list of Symbols (if merge_outputs is False) corresponding to the output from the RNN from this unrolling.
- states (nested list of Symbol) – The new state of this RNN after this unrolling. The type of this symbol is the same as the output of begin_state().
class mxnet.rnn.LSTMCell(num_hidden, prefix='lstm_', params=None, forget_bias=1.0)
Long Short-Term Memory (LSTM) network cell.
Parameters:
- num_hidden (int) – Number of units in output symbol.
- prefix (str, default 'lstm_') – Prefix for name of layers (and name of weight if params is None).
- params (RNNParams, default None) – Container for weight sharing between cells. Created if None.
- forget_bias (bias added to forget gate, default 1.0) – Jozefowicz et al. 2015 recommends setting this to 1.0.
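class mxnet.rnn.GRUCell(num_hidden, prefix='gru_', params=None)
Gated Recurrent Unit (GRU) network cell. Note: this is an implementation of the cuDNN version of GRUs (slight modification compared to Cho et al. 2014).
Parameters:
- num_hidden (int) – Number of units in output symbol.
- prefix (str, default 'gru_') – Prefix for name of layers (and name of weight if params is None).
- params (RNNParams, default None) – Container for weight sharing between cells. Created if None.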
class mxnet.rnn.RNNCell(num_hidden, activation='tanh', prefix='rnn_', params=None)
Simple recurrent neural network cell.
Parameters:
- num_hidden (int) – Number of units in output symbol.
- activation (str or Symbol, default 'tanh') – Type of activation function. Options are 'relu' and 'tanh'.
- prefix (str, default 'rnn_') – Prefix for name of layers (and name of weight if params is None).
- params (RNNParams, default None) – Container for weight sharing between cells. Created if None.
class mxnet.rnn.FusedRNNCell(num_hidden, num_layers=1, mode='lstm', bidirectional=False, dropout=0.0, get_next_state=False, forget_bias=1.0, prefix=None, params=None)
Fusing RNN layers across time step into one kernel. Improves speed but is less flexible. Currently only supported if using cuDNN on GPU.
Parameters:
- num_hidden (int) – Number of units in output symbol.
- num_layers (int, default 1) – Number of layers in the cell.
- mode (str, default 'lstm') – Type of RNN. Options are 'rnn_relu', 'rnn_tanh', 'lstm', 'gru'.
- bidirectional (bool, default False) – Whether to use bidirectional unroll. The output dimension size is doubled if bidirectional.
- dropout (float, default 0.) – Fraction of the input that gets dropped out during training time.
- get_next_state (bool, default False) – Whether to return the states that can be used as starting states next time.
- forget_bias (bias added to forget gate, default 1.0) – Jozefowicz et al. 2015 recommends setting this to 1.0.
- prefix (str, default '$mode_' such as 'lstm_') – Prefix for names of layers. This prefix is also used for names of weights if params is None, i.e. if params are being created and not reused.
- params (RNNParams, default None) – Container for weight sharing between cells. Created if None.
unfuse()
Unfuse the fused RNN into a stack of RNN cells.
Returns: cell – Unfused cell that can be used for stepping, and can run on CPU.
Return type: SequentialRNNCell
class mxnet.rnn.SequentialRNNCell(params=None)
Sequentially stacking multiple RNN cells.
Parameters: params (RNNParams, default None) – Container for weight sharing between cells. Created if None.
add(cell)
Append a cell into the stack.
Parameters: cell (BaseRNNCell) – The cell to be appended. During unroll, the previous cell’s output (or raw inputs if there is no previous cell) is used as the input to this cell.
class mxnet.rnn.BidirectionalCell(l_cell, r_cell, params=None, output_prefix='bi_')
Bidirectional RNN cell.
Parameters:
- l_cell (BaseRNNCell) – Cell for forward unrolling.
- r_cell (BaseRNNCell) – Cell for backward unrolling.
- params (RNNParams, default None) – Container for weight sharing between cells. A new RNNParams container is created if params is None.
- output_prefix (str, default 'bi_') – Prefix for name of output.
class mxnet.rnn.DropoutCell(dropout, prefix='dropout_', params=None)
Apply dropout on input.
Parameters:
- dropout (float) – Percentage of elements to drop out, which is 1 - percentage to retain.
- prefix (str, default 'dropout_') – Prefix for names of layers. This prefix is also used for names of weights if params is None, i.e. if params are being created and not reused.
- params (RNNParams, default None) – Container for weight sharing between cells. Created if None.
class mxnet.rnn.ZoneoutCell(base_cell, zoneout_outputs=0.0, zoneout_states=0.0)
Apply Zoneout on base cell.
Parameters:
- base_cell (BaseRNNCell) – Cell on whose states to perform zoneout.
- zoneout_outputs (float, default 0.) – Fraction of the output that gets dropped out during training time.
- zoneout_states (float, default 0.) – Fraction of the states that gets dropped out during training time.
class mxnet.rnn.ResidualCell(base_cell)
Adds residual connection as described in Wu et al, 2016 (https://arxiv.org/abs/1609.08144). Output of the cell is output of the base cell plus input.
Parameters: base_cell (BaseRNNCell) – Cell on whose outputs to add residual connection.
class mxnet.rnn.RNNParams(prefix='')
Container for holding variables. Used by RNN cells for parameter sharing between cells.
Parameters: prefix (str) – Names of all variables created by this container will be prepended with prefix.
class mxnet.rnn.BucketSentenceIter(sentences, batch_size, buckets=None, invalid_label=-1, data_name='data', label_name='softmax_label', dtype='float32', layout='NT')
Simple bucketing iterator for language model. The label at each sequence step is the following token in the sequence.
Parameters:
- sentences (list of list of int) – Encoded sentences.
- batch_size (int) – Batch size of the data.
- invalid_label (int, optional) – Key for invalid label. The default is -1.
- dtype (str, optional) – Data type of the encoding. The default data type is 'float32'.
- buckets (list of int, optional) – Size of the data buckets. Automatically generated if None.
- data_name (str, optional) – Name of the data. The default name is 'data'.
- label_name (str, optional) – Name of the label. The default name is 'softmax_label'.
- layout (str, optional) – Format of data and label. 'NT' means (batch_size, length) and 'TN' means (length, batch_size).
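The bucketing and labeling described above can be sketched in plain Python. This is a simplified illustration, not the iterator itself: `bucketize` and `next_token_labels` are hypothetical helpers, and the sketch assumes that each sentence goes into the smallest bucket that fits it, that shorter sentences are padded with `invalid_label`, and that sentences longer than the largest bucket are skipped.

```python
import bisect


def bucketize(sentences, buckets, invalid_label=-1):
    """Group encoded sentences by bucket size, padding with `invalid_label`."""
    buckets = sorted(buckets)
    data = {size: [] for size in buckets}
    skipped = 0
    for sentence in sentences:
        # Index of the smallest bucket whose size is >= the sentence length.
        idx = bisect.bisect_left(buckets, len(sentence))
        if idx == len(buckets):
            skipped += 1  # longer than the largest bucket
            continue
        size = buckets[idx]
        padding = [invalid_label] * (size - len(sentence))
        data[size].append(list(sentence) + padding)
    return data, skipped


def next_token_labels(row, invalid_label=-1):
    """The label at each step is the following token in the sequence."""
    return row[1:] + [invalid_label]


data, skipped = bucketize([[1, 2], [3, 4, 5], [6, 7, 8, 9, 10, 11]], buckets=[3, 5])
print(data, skipped)
print(next_token_labels(data[3][0]))
```

Grouping by bucket keeps the amount of padding per batch small while limiting the number of distinct sequence lengths that need to be unrolled.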
rnn.encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n', start_label=0)
Encode sentences and (optionally) build a mapping from string tokens to integer indices. Unknown keys will be added to vocabulary.
Parameters:
- sentences (list of list of str) – A list of sentences to encode. Each sentence should be a list of string tokens.
- vocab (None or dict of str -> int) – Optional input vocabulary.
- invalid_label (int, default -1) – Index for invalid token.
- invalid_key (str, default '\n') – Key for invalid token. Uses '\n' for end of sentence by default.
- start_label (int) – Lowest index.
Returns:
- result (list of list of int) – Encoded sentences.
- vocab (dict of str -> int) – Result vocabulary.
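A minimal sketch of the encoding step, assuming no pre-existing vocabulary is passed in. The function name `encode_tokens` is hypothetical and the code is an illustration of the behavior described above, not the library's implementation.

```python
def encode_tokens(sentences, invalid_label=-1, invalid_key='\n', start_label=0):
    """Map string tokens to integer indices, building the vocabulary on the fly."""
    vocab = {invalid_key: invalid_label}
    next_index = start_label
    encoded = []
    for sentence in sentences:
        row = []
        for token in sentence:
            if token not in vocab:
                if next_index == invalid_label:
                    next_index += 1  # never hand out the invalid index
                vocab[token] = next_index
                next_index += 1
            row.append(vocab[token])
        encoded.append(row)
    return encoded, vocab


result, vocab = encode_tokens([['a', 'b'], ['b', 'c']])
print(result)
print(vocab)
```

The returned vocabulary can then be reused so that later data is encoded with the same indices.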
rnn.save_rnn_checkpoint(cells, prefix, epoch, symbol, arg_params, aux_params)
Save checkpoint for model using RNN cells. Unpacks weights before saving.
Parameters:
- cells (RNNCell or list of RNNCells) – The RNN cells used by this symbol.
- prefix (str) – Prefix of model name.
- epoch (int) – The epoch number of the model.
- symbol (Symbol) – The input symbol.
- arg_params (dict of str to NDArray) – Model parameter, dict of name to NDArray of net’s weights.
- aux_params (dict of str to NDArray) – Model parameter, dict of name to NDArray of net’s auxiliary states.
Notes
- prefix-symbol.json will be saved for symbol.
- prefix-epoch.params will be saved for parameters.
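As a small illustration of the naming pattern in the notes, the helper below builds the two file names for a given prefix and epoch. The function `checkpoint_paths` is hypothetical, and the zero-padded four-digit epoch is an assumption based on MXNet's usual checkpoint naming; check the files written by your version before relying on it.

```python
def checkpoint_paths(prefix, epoch):
    """Return (symbol_file, params_file) for a checkpoint.

    Assumes the epoch is zero-padded to four digits in the params file name.
    """
    return '%s-symbol.json' % prefix, '%s-%04d.params' % (prefix, epoch)


print(checkpoint_paths('lstm_model', 3))
```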
rnn.load_rnn_checkpoint(cells, prefix, epoch)
Load model checkpoint from file. Packs weights after loading.
Parameters:
- cells (RNNCell or list of RNNCells) – The RNN cells used by this symbol.
- prefix (str) – Prefix of model name.
- epoch (int) – Epoch number of model we would like to load.
Returns:
- symbol (Symbol) – The symbol configuration of computation network.
- arg_params (dict of str to NDArray) – Model parameter, dict of name to NDArray of net’s weights.
- aux_params (dict of str to NDArray) – Model parameter, dict of name to NDArray of net’s auxiliary states.
Notes
- symbol will be loaded from prefix-symbol.json.
- parameters will be loaded from prefix-epoch.params.
rnn.do_rnn_checkpoint(cells, prefix, period=1)
Make a callback to checkpoint Module to prefix every epoch. Unpacks weights used by cells before saving.
Parameters:
- cells (RNNCell or list of RNNCells) – The RNN cells used by this symbol.
- prefix (str) – The file prefix to checkpoint to.
- period (int) – How many epochs to wait before checkpointing. Default is 1.
Returns: callback – The callback function that can be passed as epoch_end_callback to fit.
Return type: function
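The effect of `period` can be sketched with a plain closure. This is an illustration only: `make_checkpoint_callback` and the injected `save` function are hypothetical stand-ins, and the sketch assumes the callback receives a zero-based epoch index and saves under the one-based epoch number.

```python
def make_checkpoint_callback(save, period=1):
    """Build a callback that calls `save` once every `period` epochs."""
    period = max(1, int(period))

    def callback(epoch_index, *args):
        # epoch_index is zero-based, so epoch 0 completes the first epoch.
        if (epoch_index + 1) % period == 0:
            save(epoch_index + 1, *args)

    return callback


saved = []
callback = make_checkpoint_callback(lambda epoch, *args: saved.append(epoch), period=2)
for epoch_index in range(5):
    callback(epoch_index)
print(saved)
```

With `period=2` and five epochs, a checkpoint is written after the second and fourth epochs.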