# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=too-many-arguments, no-member
"""Functions for constructing recurrent neural networks."""
import warnings
from ..model import save_checkpoint, load_checkpoint
from .rnn_cell import BaseRNNCell
def rnn_unroll(cell, length, inputs=None, begin_state=None, input_prefix='', layout='NTC'):
"""Deprecated. Please use cell.unroll instead"""
warnings.warn('rnn_unroll is deprecated. Please call cell.unroll directly.')
return cell.unroll(length=length, inputs=inputs, begin_state=begin_state,
input_prefix=input_prefix, layout=layout)
def save_rnn_checkpoint(cells, prefix, epoch, symbol, arg_params, aux_params):
"""Save checkpoint for model using RNN cells.
Unpacks weight before saving.
Parameters
----------
cells : mxnet.rnn.RNNCell or list of RNNCells
The RNN cells used by this symbol.
prefix : str
Prefix of model name.
epoch : int
The epoch number of the model.
symbol : Symbol
The input symbol
arg_params : dict of str to NDArray
Model parameter, dict of name to NDArray of net's weights.
aux_params : dict of str to NDArray
Model parameter, dict of name to NDArray of net's auxiliary states.
Notes
-----
- ``prefix-symbol.json`` will be saved for symbol.
- ``prefix-epoch.params`` will be saved for parameters.
"""
if isinstance(cells, BaseRNNCell):
cells = [cells]
for cell in cells:
arg_params = cell.unpack_weights(arg_params)
save_checkpoint(prefix, epoch, symbol, arg_params, aux_params)
def load_rnn_checkpoint(cells, prefix, epoch):
"""Load model checkpoint from file.
Pack weights after loading.
Parameters
----------
cells : mxnet.rnn.RNNCell or list of RNNCells
The RNN cells used by this symbol.
prefix : str
Prefix of model name.
epoch : int
Epoch number of model we would like to load.
Returns
-------
symbol : Symbol
The symbol configuration of computation network.
arg_params : dict of str to NDArray
Model parameter, dict of name to NDArray of net's weights.
aux_params : dict of str to NDArray
Model parameter, dict of name to NDArray of net's auxiliary states.
Notes
-----
- symbol will be loaded from ``prefix-symbol.json``.
- parameters will be loaded from ``prefix-epoch.params``.
"""
sym, arg, aux = load_checkpoint(prefix, epoch)
if isinstance(cells, BaseRNNCell):
cells = [cells]
for cell in cells:
arg = cell.pack_weights(arg)
return sym, arg, aux
def do_rnn_checkpoint(cells, prefix, period=1):
"""Make a callback to checkpoint Module to prefix every epoch.
unpacks weights used by cells before saving.
Parameters
----------
cells : mxnet.rnn.RNNCell or list of RNNCells
The RNN cells used by this symbol.
prefix : str
The file prefix to checkpoint to
period : int
How many epochs to wait before checkpointing. Default is 1.
Returns
-------
callback : function
The callback function that can be passed as iter_end_callback to fit.
"""
period = int(max(1, period))
# pylint: disable=unused-argument
def _callback(iter_no, sym=None, arg=None, aux=None):
"""The checkpoint function."""
if (iter_no + 1) % period == 0:
save_rnn_checkpoint(cells, prefix, iter_no+1, sym, arg, aux)
return _callback