# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=no-member, invalid-name, protected-access, no-self-use
# pylint: disable=too-many-branches, too-many-arguments, no-self-use
# pylint: disable=too-many-lines, arguments-differ
"""Definition of various recurrent neural network cells."""
__all__ = ['RecurrentCell', 'HybridRecurrentCell',
           'RNNCell', 'LSTMCell', 'GRUCell',
           'SequentialRNNCell', 'DropoutCell',
           'ModifierCell', 'ZoneoutCell', 'ResidualCell',
           'BidirectionalCell']
from ... import symbol, ndarray
from ...base import string_types, numeric_types, _as_list
from ..block import Block, HybridBlock
from ..utils import _indent
from .. import tensor_types
from ..nn import LeakyReLU
def _cells_state_info(cells, batch_size):
    """Concatenate the `state_info` lists of all `cells`, in order.

    Each cell is asked for `state_info(batch_size)` and the resulting
    lists are joined into one new flat list.
    """
    per_cell = [cell.state_info(batch_size) for cell in cells]
    merged = []
    for info in per_cell:
        merged = merged + info
    return merged
def _cells_begin_state(cells, **kwargs):
    """Concatenate the initial states of all `cells`, in order.

    `kwargs` are forwarded unchanged to every cell's `begin_state`; the
    per-cell state lists are joined into one new flat list.
    """
    per_cell = [cell.begin_state(**kwargs) for cell in cells]
    merged = []
    for states in per_cell:
        merged = merged + states
    return merged
def _get_begin_state(cell, F, begin_state, inputs, batch_size):
    """Return `begin_state`, creating a zero state from `cell` if it is None.

    When `F` is the `ndarray` module the state is created inside the
    context of `inputs` (or of `inputs[0]` when `inputs` is not a single
    tensor); otherwise it is created without entering any context.
    """
    if begin_state is not None:
        return begin_state
    if F is not ndarray:
        return cell.begin_state(func=F.zeros, batch_size=batch_size)
    if isinstance(inputs, tensor_types):
        ctx = inputs.context
    else:
        ctx = inputs[0].context
    with ctx:
        return cell.begin_state(func=F.zeros, batch_size=batch_size)
def _format_sequence(length, inputs, layout, merge, in_layout=None):
    """Convert sequence `inputs` to a per-step list or a single merged tensor.

    Parameters
    ----------
    length : int or None
        Expected number of time steps. Used as the number of split
        outputs for a Symbol input and checked against the actual
        length for NDArray and list inputs.
    inputs : Symbol, NDArray, or sequence of either
        The sequence data. Must not be None.
    layout : str
        Target layout; the positions of 'T' and 'N' in it give the time
        and batch axes.
    merge : bool or None
        `False` splits a single tensor into a list along the time axis,
        `True` stacks a list into a single tensor, and `None` leaves the
        container type of `inputs` unchanged.
    in_layout : str, optional
        Layout of the incoming data. Defaults to `layout`.

    Returns
    -------
    inputs, axis, F, batch_size
        The converted inputs, the time axis of `layout`, the API module
        in use (`symbol` or `ndarray`), and the batch size (0 for
        symbolic inputs).
    """
    assert inputs is not None, \
        "unroll(inputs=None) has been deprecated. " \
        "Please create input variables outside unroll."

    time_axis = layout.find('T')
    batch_axis = layout.find('N')
    # Time axis of the incoming data; same as the target unless in_layout is given.
    src_axis = time_axis if in_layout is None else in_layout.find('T')
    batch_size = 0

    if isinstance(inputs, symbol.Symbol):
        F = symbol
        if merge is False:
            assert len(inputs.list_outputs()) == 1, \
                "unroll doesn't allow grouped symbol as input. Please convert " \
                "to list with list(inputs) first or let unroll handle splitting."
            steps = symbol.split(inputs, axis=src_axis, num_outputs=length,
                                 squeeze_axis=1)
            inputs = list(steps)
    elif isinstance(inputs, ndarray.NDArray):
        F = ndarray
        batch_size = inputs.shape[batch_axis]
        if merge is False:
            assert length is None or length == inputs.shape[src_axis]
            steps = ndarray.split(inputs, axis=src_axis,
                                  num_outputs=inputs.shape[src_axis],
                                  squeeze_axis=1)
            inputs = _as_list(steps)
    else:
        # A sequence of per-step tensors.
        assert length is None or len(inputs) == length
        if isinstance(inputs[0], symbol.Symbol):
            F = symbol
        else:
            F = ndarray
            batch_size = inputs[0].shape[batch_axis]
        if merge is True:
            expanded = [F.expand_dims(step, axis=time_axis) for step in inputs]
            inputs = F.concat(*expanded, dim=time_axis)
            # The stacked tensor is already in the target layout.
            src_axis = time_axis

    if isinstance(inputs, tensor_types) and time_axis != src_axis:
        inputs = F.swapaxes(inputs, dim1=time_axis, dim2=src_axis)
    return inputs, time_axis, F, batch_size
class RecurrentCell(Block):
    """Abstract base class for RNN cells

    Parameters
    ----------
    prefix : str, optional
        Prefix for names of `Block`s
        (this prefix is also used for names of weights if `params` is `None`
        i.e. if `params` are being created and not reused)
    params : Parameter or None, optional
        Container for weight sharing between cells.
        A new Parameter container is created if `params` is `None`.
    """
    def __init__(self, prefix=None, params=None):
        super(RecurrentCell, self).__init__(prefix=prefix, params=params)
        # Set to True by a ModifierCell that wraps this cell.
        self._modified = False
        self.reset()

    def reset(self):
        """Reset before re-using the cell for another graph."""
        # Both counters start at -1 so that the first increment yields 0.
        self._init_counter = -1
        self._counter = -1
        for cell in self._children:
            cell.reset()

    def state_info(self, batch_size=0):
        """shape and layout information of states"""
        raise NotImplementedError()

    def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
        """Initial state for this cell.

        Parameters
        ----------
        func : callable, default ndarray.zeros
            Function for creating initial state.

            For Symbol API, func can be `symbol.zeros`, `symbol.uniform`,
            `symbol.var etc`. Use `symbol.var` if you want to directly
            feed input as states.

            For NDArray API, func can be `ndarray.zeros`, `ndarray.ones`, etc.
        batch_size: int, default 0
            Only required for NDArray API. Size of the batch ('N' in layout)
            dimension of input.

        **kwargs :
            Additional keyword arguments passed to func. For example
            `mean`, `std`, `dtype`, etc.

        Returns
        -------
        states : nested list of Symbol
            Starting states for the first RNN step.
        """
        assert not self._modified, \
            "After applying modifier cells (e.g. ZoneoutCell) the base " \
            "cell cannot be called directly. Call the modifier cell instead."
        states = []
        for info in self.state_info(batch_size):
            self._init_counter += 1
            # Caller-supplied kwargs override the entries from state_info.
            if info is not None:
                info.update(kwargs)
            else:
                info = kwargs
            state = func(name='%sbegin_state_%d'%(self._prefix, self._init_counter),
                         **info)
            states.append(state)
        return states

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
        """Unrolls an RNN cell across time steps.

        Parameters
        ----------
        length : int
            Number of steps to unroll.
        inputs : Symbol, list of Symbol, or None
            If `inputs` is a single Symbol (usually the output
            of Embedding symbol), it should have shape
            (batch_size, length, ...) if `layout` is 'NTC',
            or (length, batch_size, ...) if `layout` is 'TNC'.

            If `inputs` is a list of symbols (usually output of
            previous unroll), they should all have shape
            (batch_size, ...).
        begin_state : nested list of Symbol, optional
            Input states created by `begin_state()`
            or output state of another cell.
            Created from `begin_state()` if `None`.
        layout : str, optional
            `layout` of input symbol. Only used if inputs
            is a single Symbol.
        merge_outputs : bool, optional
            If `False`, returns outputs as a list of Symbols.
            If `True`, concatenates output across time steps
            and returns a single symbol with shape
            (batch_size, length, ...) if layout is 'NTC',
            or (length, batch_size, ...) if layout is 'TNC'.
            If `None`, output whatever is faster.

        Returns
        -------
        outputs : list of Symbol or Symbol
            Symbol (if `merge_outputs` is True) or list of Symbols
            (if `merge_outputs` is False) corresponding to the output from
            the RNN from this unrolling.

        states : list of Symbol
            The new state of this RNN after this unrolling.
            The type of this symbol is same as the output of `begin_state()`.
        """
        self.reset()

        # Split the inputs into one entry per time step.
        inputs, _, F, batch_size = _format_sequence(length, inputs, layout, False)
        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)

        states = begin_state
        outputs = []
        for i in range(length):
            output, states = self(inputs[i], states)
            outputs.append(output)

        outputs, _, _, _ = _format_sequence(length, outputs, layout, merge_outputs)

        return outputs, states

    #pylint: disable=no-self-use
    def _get_activation(self, F, inputs, activation, **kwargs):
        """Get activation function. Convert if is string"""
        if isinstance(activation, string_types):
            return F.Activation(inputs, act_type=activation, **kwargs)
        elif isinstance(activation, LeakyReLU):
            return F.LeakyReLU(inputs, act_type='leaky', slope=activation._alpha, **kwargs)
        else:
            return activation(inputs, **kwargs)

    def forward(self, inputs, states):
        """Unrolls the recurrent cell for one time step.

        Parameters
        ----------
        inputs : sym.Variable
            Input symbol, 2D, of shape (batch_size * num_units).
        states : list of sym.Variable
            RNN state from previous step or the output of begin_state().

        Returns
        -------
        output : Symbol
            Symbol corresponding to the output from the RNN when unrolling
            for a single time step.
        states : list of Symbol
            The new state of this RNN after this unrolling.
            The type of this symbol is same as the output of `begin_state()`.
            This can be used as an input state to the next time step
            of this RNN.

        See Also
        --------
        begin_state: This function can provide the states for the first time step.
        unroll: This function unrolls an RNN for a given number of (>=1) time steps.
        """
        # pylint: disable= arguments-differ
        # The step counter is used by subclasses to name per-step operators.
        self._counter += 1
        return super(RecurrentCell, self).forward(inputs, states)
class HybridRecurrentCell(RecurrentCell, HybridBlock):
    """HybridRecurrentCell supports hybridize.

    Subclasses must implement `hybrid_forward`.
    """
    def __init__(self, prefix=None, params=None):
        super(HybridRecurrentCell, self).__init__(prefix=prefix, params=params)

    def hybrid_forward(self, F, x, *args, **kwargs):
        raise NotImplementedError
class RNNCell(HybridRecurrentCell):
    r"""Elman RNN recurrent neural network cell.

    Each call computes the following function:

    .. math::

        h_t = \tanh(w_{ih} * x_t + b_{ih}  +  w_{hh} * h_{(t-1)} + b_{hh})

    where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is the hidden
    state of the previous layer at time `t` or :math:`input_t` for the first layer.
    If activation='relu', then `ReLU` is used instead of `tanh`.

    Parameters
    ----------
    hidden_size : int
        Number of units in output symbol
    activation : str or Symbol, default 'tanh'
        Type of activation function.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    input_size : int, default 0
        Second dimension of the input-to-hidden weight. The parameter is
        created with deferred initialization allowed.
    prefix : str, default 'rnn_'
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.


    Inputs:
        - **data**: input tensor with shape `(batch_size, input_size)`.
        - **states**: a list of one initial recurrent state tensor with shape
          `(batch_size, num_hidden)`.

    Outputs:
        - **out**: output tensor with shape `(batch_size, num_hidden)`.
        - **next_states**: a list of one output recurrent state tensor with the
          same shape as `states`.
    """
    def __init__(self, hidden_size, activation='tanh',
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, prefix=None, params=None):
        super(RNNCell, self).__init__(prefix=prefix, params=params)
        self._hidden_size = hidden_size
        self._activation = activation
        self._input_size = input_size
        self.i2h_weight = self.params.get('i2h_weight', shape=(hidden_size, input_size),
                                          init=i2h_weight_initializer,
                                          allow_deferred_init=True)
        self.h2h_weight = self.params.get('h2h_weight', shape=(hidden_size, hidden_size),
                                          init=h2h_weight_initializer,
                                          allow_deferred_init=True)
        self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size,),
                                        init=i2h_bias_initializer,
                                        allow_deferred_init=True)
        self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size,),
                                        init=h2h_bias_initializer,
                                        allow_deferred_init=True)

    def state_info(self, batch_size=0):
        return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]

    def _alias(self):
        return 'rnn'

    def __repr__(self):
        s = '{name}({mapping}'
        if hasattr(self, '_activation'):
            s += ', {_activation}'
        s += ')'
        shape = self.i2h_weight.shape
        # A zero input dimension is shown as None.
        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
        return s.format(name=self.__class__.__name__,
                        mapping=mapping,
                        **self.__dict__)

    def hybrid_forward(self, F, inputs, states, i2h_weight,
                       h2h_weight, i2h_bias, h2h_bias):
        # Operator names are prefixed with the current step index.
        prefix = 't%d_'%self._counter
        i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
                               num_hidden=self._hidden_size,
                               name=prefix+'i2h')
        h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
                               num_hidden=self._hidden_size,
                               name=prefix+'h2h')
        output = self._get_activation(F, i2h + h2h, self._activation,
                                      name=prefix+'out')

        # The output doubles as the single recurrent state.
        return output, [output]
class LSTMCell(HybridRecurrentCell):
    r"""Long-Short Term Memory (LSTM) network cell.

    Each call computes the following function:

    .. math::
        \begin{array}{ll}
        i_t = sigmoid(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
        f_t = sigmoid(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
        g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
        o_t = sigmoid(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
        c_t = f_t * c_{(t-1)} + i_t * g_t \\
        h_t = o_t * \tanh(c_t)
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the
    cell state at time `t`, :math:`x_t` is the hidden state of the previous
    layer at time `t` or :math:`input_t` for the first layer, and :math:`i_t`,
    :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell, and
    out gates, respectively.

    Parameters
    ----------
    hidden_size : int
        Number of units in output symbol.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the input-to-hidden bias vector.
    h2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the hidden-to-hidden bias vector.
    input_size : int, default 0
        Second dimension of the input-to-hidden weight. The parameter is
        created with deferred initialization allowed.
    prefix : str, default 'lstm_'
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.


    Inputs:
        - **data**: input tensor with shape `(batch_size, input_size)`.
        - **states**: a list of two initial recurrent state tensors. Each has shape
          `(batch_size, num_hidden)`.

    Outputs:
        - **out**: output tensor with shape `(batch_size, num_hidden)`.
        - **next_states**: a list of two output recurrent state tensors. Each has
          the same shape as `states`.
    """
    def __init__(self, hidden_size,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, prefix=None, params=None):
        super(LSTMCell, self).__init__(prefix=prefix, params=params)

        self._hidden_size = hidden_size
        self._input_size = input_size
        # The four gates share one weight/bias each, hence the factor of 4.
        self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
                                          init=i2h_weight_initializer,
                                          allow_deferred_init=True)
        self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size),
                                          init=h2h_weight_initializer,
                                          allow_deferred_init=True)
        self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
                                        init=i2h_bias_initializer,
                                        allow_deferred_init=True)
        self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
                                        init=h2h_bias_initializer,
                                        allow_deferred_init=True)

    def state_info(self, batch_size=0):
        return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'},
                {'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]

    def _alias(self):
        return 'lstm'

    def __repr__(self):
        s = '{name}({mapping})'
        shape = self.i2h_weight.shape
        # A zero input dimension is shown as None.
        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
        return s.format(name=self.__class__.__name__,
                        mapping=mapping,
                        **self.__dict__)

    def hybrid_forward(self, F, inputs, states, i2h_weight,
                       h2h_weight, i2h_bias, h2h_bias):
        # Operator names are prefixed with the current step index.
        prefix = 't%d_'%self._counter
        # states[0] is the previous hidden state, states[1] the previous cell state.
        i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
                               num_hidden=self._hidden_size*4, name=prefix+'i2h')
        h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
                               num_hidden=self._hidden_size*4, name=prefix+'h2h')
        gates = i2h + h2h
        # Slice order: input gate, forget gate, cell transform, output gate.
        slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
        in_gate = F.Activation(slice_gates[0], act_type="sigmoid", name=prefix+'i')
        forget_gate = F.Activation(slice_gates[1], act_type="sigmoid", name=prefix+'f')
        in_transform = F.Activation(slice_gates[2], act_type="tanh", name=prefix+'c')
        out_gate = F.Activation(slice_gates[3], act_type="sigmoid", name=prefix+'o')
        next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
                                   name=prefix+'state')
        next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type="tanh"),
                                  name=prefix+'out')

        return next_h, [next_h, next_c]
class GRUCell(HybridRecurrentCell):
    r"""Gated Recurrent Unit (GRU) network cell.
    Note: this is an implementation of the cuDNN version of GRUs
    (slight modification compared to Cho et al. 2014).

    Each call computes the following function:

    .. math::
        \begin{array}{ll}
        r_t = sigmoid(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
        i_t = sigmoid(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
        n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
        h_t = (1 - i_t) * n_t + i_t * h_{(t-1)} \\
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the hidden
    state of the previous layer at time `t` or :math:`input_t` for the first layer,
    and :math:`r_t`, :math:`i_t`, :math:`n_t` are the reset, input, and new gates, respectively.

    Parameters
    ----------
    hidden_size : int
        Number of units in output symbol.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    input_size : int, default 0
        Second dimension of the input-to-hidden weight. The parameter is
        created with deferred initialization allowed.
    prefix : str, default 'gru_'
        prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.


    Inputs:
        - **data**: input tensor with shape `(batch_size, input_size)`.
        - **states**: a list of one initial recurrent state tensor with shape
          `(batch_size, num_hidden)`.

    Outputs:
        - **out**: output tensor with shape `(batch_size, num_hidden)`.
        - **next_states**: a list of one output recurrent state tensor with the
          same shape as `states`.
    """
    def __init__(self, hidden_size,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, prefix=None, params=None):
        super(GRUCell, self).__init__(prefix=prefix, params=params)
        self._hidden_size = hidden_size
        self._input_size = input_size
        # The three gates share one weight/bias each, hence the factor of 3.
        self.i2h_weight = self.params.get('i2h_weight', shape=(3*hidden_size, input_size),
                                          init=i2h_weight_initializer,
                                          allow_deferred_init=True)
        self.h2h_weight = self.params.get('h2h_weight', shape=(3*hidden_size, hidden_size),
                                          init=h2h_weight_initializer,
                                          allow_deferred_init=True)
        self.i2h_bias = self.params.get('i2h_bias', shape=(3*hidden_size,),
                                        init=i2h_bias_initializer,
                                        allow_deferred_init=True)
        self.h2h_bias = self.params.get('h2h_bias', shape=(3*hidden_size,),
                                        init=h2h_bias_initializer,
                                        allow_deferred_init=True)

    def state_info(self, batch_size=0):
        return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]

    def _alias(self):
        return 'gru'

    def __repr__(self):
        s = '{name}({mapping})'
        shape = self.i2h_weight.shape
        # A zero input dimension is shown as None.
        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
        return s.format(name=self.__class__.__name__,
                        mapping=mapping,
                        **self.__dict__)

    def hybrid_forward(self, F, inputs, states, i2h_weight,
                       h2h_weight, i2h_bias, h2h_bias):
        # pylint: disable=too-many-locals
        # Operator names are prefixed with the current step index.
        prefix = 't%d_'%self._counter
        prev_state_h = states[0]
        i2h = F.FullyConnected(data=inputs,
                               weight=i2h_weight,
                               bias=i2h_bias,
                               num_hidden=self._hidden_size * 3,
                               name=prefix+'i2h')
        h2h = F.FullyConnected(data=prev_state_h,
                               weight=h2h_weight,
                               bias=h2h_bias,
                               num_hidden=self._hidden_size * 3,
                               name=prefix+'h2h')

        # Slice order: reset gate, update gate, candidate state.
        i2h_r, i2h_z, i2h = F.SliceChannel(i2h, num_outputs=3,
                                           name=prefix+'i2h_slice')
        h2h_r, h2h_z, h2h = F.SliceChannel(h2h, num_outputs=3,
                                           name=prefix+'h2h_slice')

        reset_gate = F.Activation(i2h_r + h2h_r, act_type="sigmoid",
                                  name=prefix+'r_act')
        update_gate = F.Activation(i2h_z + h2h_z, act_type="sigmoid",
                                   name=prefix+'z_act')

        # The reset gate scales the hidden-to-hidden term after its bias is added.
        next_h_tmp = F.Activation(i2h + reset_gate * h2h, act_type="tanh",
                                  name=prefix+'h_act')

        next_h = F._internal._plus((1. - update_gate) * next_h_tmp, update_gate * prev_state_h,
                                   name=prefix+'out')

        return next_h, [next_h]
class SequentialRNNCell(RecurrentCell):
    """Sequentially stacking multiple RNN cells.

    The output of each child cell is fed as the input of the next one.
    The states of all children are kept in one flat list, in the order
    the children were added.
    """
    def __init__(self, prefix=None, params=None):
        super(SequentialRNNCell, self).__init__(prefix=prefix, params=params)

    def __repr__(self):
        s = '{name}(\n{modstr}\n)'
        return s.format(name=self.__class__.__name__,
                        modstr='\n'.join(['({i}): {m}'.format(i=i, m=_indent(m.__repr__(), 2))
                                          for i, m in enumerate(self._children)]))

    def add(self, cell):
        """Appends a cell into the stack.

        Parameters
        ----------
        cell : RecurrentCell
            The cell to add.
        """
        self.register_child(cell)

    def state_info(self, batch_size=0):
        return _cells_state_info(self._children, batch_size)

    def begin_state(self, **kwargs):
        assert not self._modified, \
            "After applying modifier cells (e.g. ZoneoutCell) the base " \
            "cell cannot be called directly. Call the modifier cell instead."
        return _cells_begin_state(self._children, **kwargs)

    def __call__(self, inputs, states):
        self._counter += 1
        next_states = []
        # p is the offset of the current child's states in the flat list.
        p = 0
        for cell in self._children:
            assert not isinstance(cell, BidirectionalCell)
            n = len(cell.state_info())
            state = states[p:p+n]
            p += n
            inputs, state = cell(inputs, state)
            next_states.append(state)
        return inputs, sum(next_states, [])

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
        self.reset()

        inputs, _, F, batch_size = _format_sequence(length, inputs, layout, None)
        num_cells = len(self._children)
        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)

        # p is the offset of the current child's states in the flat list.
        p = 0
        next_states = []
        for i, cell in enumerate(self._children):
            n = len(cell.state_info())
            states = begin_state[p:p+n]
            p += n
            # Only the last child honours the caller's merge_outputs.
            inputs, states = cell.unroll(length, inputs=inputs, begin_state=states, layout=layout,
                                         merge_outputs=None if i < num_cells-1 else merge_outputs)
            next_states.extend(states)

        return inputs, next_states

    def __getitem__(self, i):
        return self._children[i]

    def __len__(self):
        return len(self._children)

    def hybrid_forward(self, *args, **kwargs):
        raise NotImplementedError
class DropoutCell(HybridRecurrentCell):
    """Applies dropout on input.

    Parameters
    ----------
    rate : float
        Percentage of elements to drop out, which
        is 1 - percentage to retain.


    Inputs:
        - **data**: input tensor with shape `(batch_size, size)`.
        - **states**: a list of recurrent state tensors.

    Outputs:
        - **out**: output tensor with shape `(batch_size, size)`.
        - **next_states**: returns input `states` directly.
    """
    def __init__(self, rate, prefix=None, params=None):
        super(DropoutCell, self).__init__(prefix, params)
        assert isinstance(rate, numeric_types), "rate must be a number"
        self.rate = rate

    def __repr__(self):
        s = '{name}(rate = {rate})'
        return s.format(name=self.__class__.__name__,
                        **self.__dict__)

    def state_info(self, batch_size=0):
        # This cell keeps no recurrent state.
        return []

    def _alias(self):
        return 'dropout'

    def hybrid_forward(self, F, inputs, states):
        # Inputs pass through unchanged when the rate is not positive.
        if self.rate > 0:
            inputs = F.Dropout(data=inputs, p=self.rate, name='t%d_fwd'%self._counter)
        return inputs, states

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
        self.reset()

        inputs, _, F, _ = _format_sequence(length, inputs, layout, merge_outputs)
        if isinstance(inputs, tensor_types):
            # A single merged tensor is handled in one call instead of per step.
            return self.hybrid_forward(F, inputs, begin_state if begin_state else [])
        else:
            return super(DropoutCell, self).unroll(
                length, inputs, begin_state=begin_state, layout=layout,
                merge_outputs=merge_outputs)
class ModifierCell(HybridRecurrentCell):
    """Base class for modifier cells. A modifier
    cell takes a base cell, apply modifications
    on it (e.g. Zoneout), and returns a new cell.

    After applying modifiers the base cell should
    no longer be called directly. The modifier cell
    should be used instead.

    Parameters
    ----------
    base_cell : RecurrentCell
        The cell to wrap. It must not already be wrapped by another
        modifier; it is marked as modified by this constructor.
    """
    def __init__(self, base_cell):
        assert not base_cell._modified, \
            "Cell %s is already modified. One cell cannot be modified twice"%base_cell.name
        base_cell._modified = True
        super(ModifierCell, self).__init__(prefix=base_cell.prefix+self._alias(),
                                           params=None)
        self.base_cell = base_cell

    @property
    def params(self):
        # Parameters are owned by the wrapped cell.
        return self.base_cell.params

    def state_info(self, batch_size=0):
        return self.base_cell.state_info(batch_size)

    def begin_state(self, func=symbol.zeros, **kwargs):
        assert not self._modified, \
            "After applying modifier cells (e.g. DropoutCell) the base " \
            "cell cannot be called directly. Call the modifier cell instead."
        # Temporarily lift the guard so the base cell accepts the call, and
        # always restore it, even if the base cell raises.
        self.base_cell._modified = False
        try:
            begin = self.base_cell.begin_state(func=func, **kwargs)
        finally:
            self.base_cell._modified = True
        return begin

    def hybrid_forward(self, F, inputs, states):
        raise NotImplementedError

    def __repr__(self):
        s = '{name}({base_cell})'
        return s.format(name=self.__class__.__name__,
                        **self.__dict__)
class ZoneoutCell(ModifierCell):
    """Applies Zoneout on base cell.

    Parameters
    ----------
    base_cell : RecurrentCell
        The cell to wrap. Must not be a `BidirectionalCell`.
    zoneout_outputs : float, default 0.
        Dropout rate of the mask that chooses, element-wise, between the
        new output and the previous output. 0 disables output zoneout.
    zoneout_states : float, default 0.
        Dropout rate of the mask that chooses, element-wise, between each
        new state and the corresponding previous state. 0 disables state
        zoneout.
    """
    def __init__(self, base_cell, zoneout_outputs=0., zoneout_states=0.):
        assert not isinstance(base_cell, BidirectionalCell), \
            "BidirectionalCell doesn't support zoneout since it doesn't support step. " \
            "Please add ZoneoutCell to the cells underneath instead."
        # SequentialRNNCell does not necessarily define `_bidirectional`;
        # a missing attribute is treated as not bidirectional.
        assert not isinstance(base_cell, SequentialRNNCell) or \
            not getattr(base_cell, '_bidirectional', False), \
            "Bidirectional SequentialRNNCell doesn't support zoneout. " \
            "Please add ZoneoutCell to the cells underneath instead."
        super(ZoneoutCell, self).__init__(base_cell)
        self.zoneout_outputs = zoneout_outputs
        self.zoneout_states = zoneout_states
        # Output of the previous step; None until the first step has run.
        self._prev_output = None

    def __repr__(self):
        s = '{name}(p_out={zoneout_outputs}, p_state={zoneout_states}, {base_cell})'
        return s.format(name=self.__class__.__name__,
                        **self.__dict__)

    def _alias(self):
        return 'zoneout'

    def reset(self):
        super(ZoneoutCell, self).reset()
        self._prev_output = None

    def hybrid_forward(self, F, inputs, states):
        cell, p_outputs, p_states = self.base_cell, self.zoneout_outputs, self.zoneout_states
        next_output, next_states = cell(inputs, states)
        # Dropout over a tensor of ones yields the element-wise selection mask.
        mask = (lambda p, like: F.Dropout(F.ones_like(like), p=p))

        prev_output = self._prev_output
        if prev_output is None:
            # First step: there is no earlier output, so fall back to zeros.
            prev_output = F.zeros_like(next_output)

        output = (F.where(mask(p_outputs, next_output), next_output, prev_output)
                  if p_outputs != 0. else next_output)
        states = ([F.where(mask(p_states, new_s), new_s, old_s) for new_s, old_s in
                   zip(next_states, states)] if p_states != 0. else next_states)

        self._prev_output = output

        return output, states
class ResidualCell(ModifierCell):
    """
    Adds residual connection as described in Wu et al, 2016
    (https://arxiv.org/abs/1609.08144).
    Output of the cell is output of the base cell plus input.
    """
    def __init__(self, base_cell):
        super(ResidualCell, self).__init__(base_cell)

    def hybrid_forward(self, F, inputs, states):
        """Run one step of the base cell and add ``inputs`` to its output."""
        output, states = self.base_cell(inputs, states)
        output = F.elemwise_add(output, inputs, name='t%d_fwd'%self._counter)
        return output, states

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
        """Unroll the base cell and add the inputs to its outputs.

        Parameters
        ----------
        length, inputs, begin_state, layout, merge_outputs
            Forwarded to the base cell's ``unroll``. When ``merge_outputs`` is
            None, it is inferred from whether the base cell returned a single
            tensor (``tensor_types``).

        Returns
        -------
        outputs, states
            ``outputs`` is a single tensor when merging, otherwise a list with
            one summed entry per (output, input) pair; ``states`` is what the
            base cell's ``unroll`` returned.
        """
        self.reset()
        # Clear the flag for the duration of the direct call into the base cell.
        self.base_cell._modified = False
        try:
            outputs, states = self.base_cell.unroll(length, inputs=inputs,
                                                    begin_state=begin_state,
                                                    layout=layout,
                                                    merge_outputs=merge_outputs)
        finally:
            # Restore the flag even if the base cell's unroll raised.
            self.base_cell._modified = True

        if merge_outputs is None:
            merge_outputs = isinstance(outputs, tensor_types)
        # Bring the inputs into the same form (merged or not) as the outputs.
        inputs, _, F, _ = _format_sequence(length, inputs, layout, merge_outputs)
        if merge_outputs:
            outputs = F.elemwise_add(outputs, inputs)
        else:
            outputs = [F.elemwise_add(i, j) for i, j in zip(outputs, inputs)]
        return outputs, states
class BidirectionalCell(HybridRecurrentCell):
    """Bidirectional RNN cell.

    Parameters
    ----------
    l_cell : RecurrentCell
        Cell for forward unrolling
    r_cell : RecurrentCell
        Cell for backward unrolling
    output_prefix : str, default 'bi_'
        Prefix used when naming the concatenated outputs in ``unroll``.
    """
    def __init__(self, l_cell, r_cell, output_prefix='bi_'):
        super(BidirectionalCell, self).__init__(prefix='', params=None)
        # Children are kept in order: index 0 is forward, index 1 is backward.
        self.register_child(l_cell)
        self.register_child(r_cell)
        self._output_prefix = output_prefix

    def __call__(self, inputs, states):
        raise NotImplementedError("Bidirectional cannot be stepped. Please use unroll")

    def __repr__(self):
        s = '{name}(forward={l_cell}, backward={r_cell})'
        return s.format(name=self.__class__.__name__,
                        l_cell=self._children[0],
                        r_cell=self._children[1])

    def state_info(self, batch_size=0):
        """Return the concatenated ``state_info`` lists of both child cells."""
        return _cells_state_info(self._children, batch_size)

    def begin_state(self, **kwargs):
        """Return the concatenated initial states of both child cells.

        Raises
        ------
        AssertionError
            If this cell's ``_modified`` flag is set.
        """
        assert not self._modified, \
            "After applying modifier cells (e.g. DropoutCell) the base " \
            "cell cannot be called directly. Call the modifier cell instead."
        return _cells_begin_state(self._children, **kwargs)

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
        """Unroll both child cells over the sequence and concatenate their outputs.

        The forward cell receives the inputs in order and the backward cell
        receives them reversed; the backward outputs are reversed again before
        being concatenated with the forward outputs.

        Parameters
        ----------
        length, layout, merge_outputs
            Forwarded to both child cells' ``unroll``. When ``merge_outputs``
            is None, outputs are merged only if both children returned a
            single tensor (``tensor_types``).
        inputs
            Input sequence; it is first normalised with ``_format_sequence``
            using ``merge=False`` (presumably yielding a per-step list, since
            it is reversed with ``reversed`` below).
        begin_state : list or None
            Initial states for both cells, forward cell's states first.
            Obtained through ``_get_begin_state`` when None.

        Returns
        -------
        outputs, states
            Concatenated outputs (a single tensor when merging, otherwise a
            list with one entry per step) and ``l_states + r_states``.
        """
        self.reset()
        inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)
        states = begin_state
        l_cell, r_cell = self._children
        # The leading states belong to the forward cell, the rest to the
        # backward cell; compute the split point once.
        n_l_states = len(l_cell.state_info(batch_size))
        l_outputs, l_states = l_cell.unroll(length, inputs=inputs,
                                            begin_state=states[:n_l_states],
                                            layout=layout, merge_outputs=merge_outputs)
        r_outputs, r_states = r_cell.unroll(length,
                                            inputs=list(reversed(inputs)),
                                            begin_state=states[n_l_states:],
                                            layout=layout, merge_outputs=merge_outputs)
        if merge_outputs is None:
            merge_outputs = (isinstance(l_outputs, tensor_types)
                             and isinstance(r_outputs, tensor_types))
            # Bring both sides into the same (merged or unmerged) form.
            l_outputs, _, _, _ = _format_sequence(None, l_outputs, layout, merge_outputs)
            r_outputs, _, _, _ = _format_sequence(None, r_outputs, layout, merge_outputs)
        if merge_outputs:
            # Undo the input reversal along the axis reported by _format_sequence.
            r_outputs = F.reverse(r_outputs, axis=axis)
            outputs = F.concat(l_outputs, r_outputs, dim=2, name='%sout'%self._output_prefix)
        else:
            # Per-step outputs: pair step i of the forward pass with the
            # matching step of the (re-reversed) backward pass.
            outputs = [F.concat(l_o, r_o, dim=1, name='%st%d'%(self._output_prefix, i))
                       for i, (l_o, r_o) in enumerate(zip(l_outputs, reversed(r_outputs)))]
        states = l_states + r_states
        return outputs, states