# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable= arguments-differ
"""Base container class for all neural network models."""
__all__ = ['Block', 'HybridBlock', 'SymbolBlock']
import copy
import warnings
import re
from .. import symbol, ndarray, initializer
from ..symbol import Symbol
from ..ndarray import NDArray
from .. import name as _name
from .parameter import Parameter, ParameterDict, DeferredInitializationError
from .utils import _indent
class _BlockScope(object):
"""Scope for collecting child `Block` s."""
_current = None
def __init__(self, block):
self._block = block
self._counter = {}
self._old_scope = None
self._name_scope = None
@staticmethod
def create(prefix, params, hint):
"""Creates prefix and params for new `Block`."""
current = _BlockScope._current
if current is None:
if prefix is None:
prefix = _name.NameManager.current.get(None, hint) + '_'
if params is None:
params = ParameterDict(prefix)
else:
params = ParameterDict(params.prefix, params)
return prefix, params
if prefix is None:
count = current._counter.get(hint, 0)
prefix = '%s%d_'%(hint, count)
current._counter[hint] = count + 1
if params is None:
parent = current._block.params
params = ParameterDict(parent.prefix+prefix, parent._shared)
else:
params = ParameterDict(params.prefix, params)
return current._block.prefix+prefix, params
def __enter__(self):
if self._block._empty_prefix:
return
self._old_scope = _BlockScope._current
_BlockScope._current = self
self._name_scope = _name.Prefix(self._block.prefix)
self._name_scope.__enter__()
return self
def __exit__(self, ptype, value, trace):
if self._block._empty_prefix:
return
self._name_scope.__exit__(ptype, value, trace)
self._name_scope = None
_BlockScope._current = self._old_scope
def _flatten(args):
    """Flattens a (nested) list of NDArray/Symbol into a flat list plus a format.

    Returns
    -------
    (list, int or list)
        The flat list of leaves and a format descriptor. For a single NDArray
        the format is 0; for a single Symbol it is its number of outputs, or 0
        when it has at most one output; for a list/tuple it is the list of the
        formats of its elements.
    """
    if isinstance(args, NDArray):
        return [args], 0
    if isinstance(args, Symbol):
        num_outputs = len(args.list_outputs())
        return [args], (num_outputs if num_outputs > 1 else 0)

    assert isinstance(args, (list, tuple)), \
        "HybridBlock input must be (nested) list of Symbol or NDArray, " \
        "but got %s of type %s"%(str(args), str(type(args)))

    leaves, layout = [], []
    for item in args:
        item_leaves, item_layout = _flatten(item)
        leaves += item_leaves
        layout.append(item_layout)
    return leaves, layout
def _regroup(args, fmt):
if isinstance(fmt, int):
if fmt == 0:
return args[0], args[1:]
return args[:fmt], args[fmt:]
assert isinstance(args, (list, tuple)), \
"HybridBlock output must be (nested) list of Symbol or NDArray, " \
"but got %s of type %s"%(str(args), str(type(args)))
ret = []
for i in fmt:
res, args = _regroup(args, i)
ret.append(res)
return ret, args
class Block(object):
    """Base class for all neural network layers and models. Your models should
    subclass this class.

    :py:class:`Block` can be nested recursively in a tree structure. You can create and
    assign child :py:class:`Block` as regular attributes::

        from mxnet.gluon import Block, nn
        from mxnet import ndarray as F

        class Model(Block):
            def __init__(self, **kwargs):
                super(Model, self).__init__(**kwargs)
                # use name_scope to give child Blocks appropriate names.
                # It also allows sharing Parameters between Blocks recursively.
                with self.name_scope():
                    self.dense0 = nn.Dense(20)
                    self.dense1 = nn.Dense(20)

            def forward(self, x):
                x = F.relu(self.dense0(x))
                return F.relu(self.dense1(x))

        model = Model()
        model.initialize(ctx=mx.cpu(0))
        model(F.zeros((10, 10), ctx=mx.cpu(0)))


    Child :py:class:`Block` assigned this way will be registered and :py:meth:`collect_params`
    will collect their Parameters recursively.

    Parameters
    ----------
    prefix : str
        Prefix acts like a name space. It will be prepended to the names of all
        Parameters and child :py:class:`Block` s in this :py:class:`Block` 's
        :py:meth:`name_scope` .
        Prefix should be unique within one model to prevent name collisions.
    params : ParameterDict or None
        :py:class:`ParameterDict` for sharing weights with the new :py:class:`Block`. For example,
        if you want ``dense1`` to share ``dense0``'s weights, you can do::

            dense0 = nn.Dense(20)
            dense1 = nn.Dense(20, params=dense0.collect_params())
    """
    def __init__(self, prefix=None, params=None):
        # An explicitly empty prefix turns this block's name scope into a no-op.
        self._empty_prefix = prefix == ''
        self._prefix, self._params = _BlockScope.create(prefix, params, self._alias())
        self._name = self._prefix[:-1] if self._prefix.endswith('_') else self._prefix
        self._scope = _BlockScope(self)
        self._children = []

    def __repr__(self):
        s = '{name}(\n{modstr}\n)'
        modstr = '\n'.join(['  ({key}): {block}'.format(key=key,
                                                        block=_indent(block.__repr__(), 2))
                            for key, block in self.__dict__.items() if isinstance(block, Block)])
        return s.format(name=self.__class__.__name__,
                        modstr=modstr)

    def __setattr__(self, name, value):
        """Registers parameters.

        Assigning a :py:class:`Block` registers it as a child. Re-assigning an
        attribute that already holds a Parameter or Block is only allowed with
        a value of the same type; a Block replaced this way is swapped in place
        in the list of children.

        Raises
        ------
        TypeError
            If an existing Parameter or Block attribute is overwritten with a
            value of a different type.
        """
        if hasattr(self, name):
            existing = getattr(self, name)
            if isinstance(existing, (Parameter, Block)) and not isinstance(value, type(existing)):
                raise TypeError('Changing attribute type for {name} from {type1} to {type2} '
                                'is not allowed.'.format(name=name,
                                                         type1=type(existing),
                                                         type2=type(value)))
            if isinstance(existing, Block):
                # Keep the child's position: replace the old block in place.
                for i, c in enumerate(self._children):
                    if c is existing:
                        self._children[i] = value
            elif isinstance(value, Block):
                self.register_child(value)
        elif isinstance(value, Block):
            self.register_child(value)

        super(Block, self).__setattr__(name, value)

    def _check_container_with_block(self):
        """Warns about attributes that are plain containers holding Blocks."""
        def _find_block_in_container(data):
            # Find whether a nested container structure contains Blocks
            if isinstance(data, (list, tuple)):
                return any(_find_block_in_container(ele) for ele in data)
            if isinstance(data, dict):
                return any(_find_block_in_container(v) for v in data.values())
            return isinstance(data, Block)

        for k, v in self.__dict__.items():
            if isinstance(v, (list, tuple, dict)) and not (k.startswith('__') or k == '_children'):
                if _find_block_in_container(v):
                    warnings.warn('"{name}" is a container with Blocks. '
                                  'Note that Blocks inside the list, tuple or dict will not be '
                                  'registered automatically. Make sure to register them using '
                                  'register_child() or switching to '
                                  'nn.Sequential/nn.HybridSequential instead. '
                                  .format(name=self.__class__.__name__ + "." + k))

    def _alias(self):
        # Naming hint for this block: the lower-cased class name.
        return self.__class__.__name__.lower()

    @property
    def prefix(self):
        """Prefix of this :py:class:`Block`."""
        return self._prefix

    @property
    def name(self):
        """Name of this :py:class:`Block`, without '_' in the end."""
        return self._name

    def name_scope(self):
        """Returns a name space object managing a child :py:class:`Block` and parameter
        names. Should be used within a ``with`` statement::

            with self.name_scope():
                self.dense = nn.Dense(20)
        """
        return self._scope

    @property
    def params(self):
        """Returns this :py:class:`Block`'s parameter dictionary (does not include its
        children's parameters)."""
        return self._params

    def collect_params(self, select=None):
        """Returns a :py:class:`ParameterDict` containing the Parameters of this
        :py:class:`Block` and all of its children (the default). It can also return
        only the Parameters whose names match a given regular expression.

        For example, collect the specified parameters in ['conv1_weight', 'conv1_bias', 'fc_weight',
        'fc_bias']::

            model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias')

        or collect all parameters whose names end with 'weight' or 'bias', this can be done
        using regular expressions::

            model.collect_params('.*weight|.*bias')

        Parameters
        ----------
        select : str
            Regular expression; it is matched from the start of each parameter name.

        Returns
        -------
        The selected :py:class:`ParameterDict`
        """
        # We need to check here because blocks inside containers are not supported.
        self._check_container_with_block()
        ret = ParameterDict(self._params.prefix)
        if not select:
            ret.update(self.params)
        else:
            pattern = re.compile(select)
            ret.update({name:value for name, value in self.params.items() if pattern.match(name)})
        for cld in self._children:
            ret.update(cld.collect_params(select=select))
        return ret

    def save_params(self, filename):
        """Save parameters to file.

        Parameters
        ----------
        filename : str
            Path to file.
        """
        self.collect_params().save(filename, strip_prefix=self.prefix)

    def load_params(self, filename, ctx, allow_missing=False,
                    ignore_extra=False):
        """Load parameters from file.

        Parameters
        ----------
        filename : str
            Path to parameter file.
        ctx : Context or list of Context
            Context(s) to initialize the loaded parameters on.
        allow_missing : bool, default False
            Whether to silently skip loading parameters not present in the file.
        ignore_extra : bool, default False
            Whether to silently ignore parameters from the file that are not
            present in this Block.
        """
        self.collect_params().load(filename, ctx, allow_missing, ignore_extra,
                                   self.prefix)

    def register_child(self, block):
        """Registers block as a child of self. :py:class:`Block` s assigned to self as
        attributes will be registered automatically."""
        self._children.append(block)

    def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False):
        """Initializes :py:class:`Parameter` s of this :py:class:`Block` and its children.

        Equivalent to ``block.collect_params().initialize(...)``
        """
        self.collect_params().initialize(init, ctx, verbose)

    def hybridize(self, active=True, **kwargs):
        """Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on
        non-hybrid children.

        Parameters
        ----------
        active : bool, default True
            Whether to turn hybrid on or off.
        **kwargs : string
            Additional flags for hybridized operator.
        """
        for cld in self._children:
            cld.hybridize(active, **kwargs)

    def cast(self, dtype):
        """Cast this Block to use another data type.

        Parameters
        ----------
        dtype : str or numpy.dtype
            The new data type.
        """
        for child in self._children:
            child.cast(dtype)
        for _, param in self.params.items():
            param.cast(dtype)

    def __call__(self, *args):
        """Calls forward. Only accepts positional arguments."""
        return self.forward(*args)

    def forward(self, *args):
        """Overrides to implement forward computation using :py:class:`NDArray`. Only
        accepts positional arguments.

        Parameters
        ----------
        *args : list of NDArray
            Input tensors.
        """
        # pylint: disable= invalid-name
        raise NotImplementedError
class HybridBlock(Block):
    """`HybridBlock` supports forwarding with both Symbol and NDArray.

    Forward computation in :py:class:`HybridBlock` must be static to work with :py:class:`Symbol` s,
    i.e. you cannot call :py:meth:`NDArray.asnumpy`, :py:attr:`NDArray.shape`,
    :py:attr:`NDArray.dtype`, etc on tensors.
    Also, you cannot use branching or loop logic that bases on non-constant
    expressions like random numbers or intermediate results, since they change
    the graph structure for each iteration.

    Before activating with :py:meth:`hybridize()`, :py:class:`HybridBlock` works just like normal
    :py:class:`Block`. After activation, :py:class:`HybridBlock` will create a symbolic graph
    representing the forward computation and cache it. On subsequent forwards,
    the cached graph will be used instead of :py:meth:`hybrid_forward`.

    Refer to the Hybrid tutorial to see the end-to-end usage.
    """
    def __init__(self, prefix=None, params=None):
        super(HybridBlock, self).__init__(prefix=prefix, params=params)
        # Parameters assigned as attributes, keyed by attribute name.
        self._reg_params = {}
        # Either () or the pair (input variables, grouped output symbol).
        self._cached_graph = ()
        self._cached_op = None
        # One (is_input, index-or-Parameter) entry per input of the cached graph.
        self._cached_op_args = None
        self._out_format = None
        self._in_format = None
        self._active = False
        # (key, value) pairs handed to the cached operator; see hybridize().
        self._flags = []

    def __setattr__(self, name, value):
        """Registers parameters.

        Assigning a :py:class:`HybridBlock` invalidates the cached graph, and
        assigning a :py:class:`Parameter` records it so that it is passed to
        :py:meth:`hybrid_forward` as a keyword argument of the same name.
        """
        super(HybridBlock, self).__setattr__(name, value)
        if isinstance(value, HybridBlock):
            self._clear_cached_op()
        if isinstance(value, Parameter):
            assert name not in self._reg_params or \
                not isinstance(self._reg_params[name], Parameter), \
                "Overriding Parameter attribute %s is not allowed. " \
                "Please pass in Parameters by specifying `params` at " \
                "Block construction instead." % name
            self._reg_params[name] = value

    def _get_graph(self, *args):
        """Returns the cached (inputs, output symbol) pair, tracing it on first use."""
        if not self._cached_graph:
            args, self._in_format = _flatten(args)
            if len(args) > 1:
                inputs = [symbol.var('data%d'%i) for i in range(len(args))]
            else:
                inputs = [symbol.var('data')]
            grouped_inputs = _regroup(inputs, self._in_format)[0]

            params = {i: j.var() for i, j in self._reg_params.items()}
            with self.name_scope():
                out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
            out, self._out_format = _flatten(out)

            self._cached_graph = inputs, symbol.Group(out)

        return self._cached_graph

    def _build_cache(self, *args):
        """Creates the cached operator and the mapping of its inputs."""
        inputs, out = self._get_graph(*args)
        input_idx = {var.name: i for i, var in enumerate(inputs)}
        self._cached_op = ndarray.CachedOp(out, self._flags)
        params = dict(self.collect_params().items())

        # verify graph inputs
        expected_inputs = set(out.list_inputs())
        for name in expected_inputs:
            assert name in params or name in input_idx, \
                "Unknown input to HybridBlock: %s"%name
        for name, i in input_idx.items():
            if name not in expected_inputs:
                warnings.warn("The %d-th input to HybridBlock is not used by any "
                              "computation. Is this intended?"%i)
        for name in params:
            if name not in expected_inputs:
                warnings.warn("Parameter %s is not used by any computation. "
                              "Is this intended?"%name)

        # For each graph input remember whether it is fed from the call
        # arguments (True, position) or from a Parameter (False, Parameter).
        self._cached_op_args = [(False, params[name]) if name in params
                                else (True, input_idx[name])
                                for name in out.list_inputs()]

    def _finish_deferred_init(self, hybrid, *args):
        """Infers parameter shapes from `args` and completes deferred initialization."""
        self.infer_shape(*args)
        if hybrid:
            for is_arg, i in self._cached_op_args:
                if not is_arg:
                    i._finish_deferred_init()
        else:
            for _, i in self.params.items():
                i._finish_deferred_init()

    def _call_cached_op(self, *args):
        """Runs the cached operator on `args`, building it first if needed."""
        if self._cached_op is None:
            self._build_cache(*args)

        args, fmt = _flatten(args)
        assert fmt == self._in_format, "Invalid input format"
        cargs = [args[i] if is_arg else i.data()
                 for is_arg, i in self._cached_op_args]
        out = self._cached_op(*cargs)
        if isinstance(out, NDArray):
            out = [out]
        return _regroup(out, self._out_format)[0]

    def _clear_cached_op(self):
        """Drops the cached graph and operator so they are rebuilt on next use."""
        self._cached_graph = ()
        self._cached_op = None
        self._cached_op_args = None

    def register_child(self, block):
        """Registers a child, which must itself be a :py:class:`HybridBlock`.

        Raises
        ------
        ValueError
            If `block` is not a :py:class:`HybridBlock`.
        """
        if not isinstance(block, HybridBlock):
            raise ValueError(
                "Children of HybridBlock must also be HybridBlock, " \
                "but %s has type %s. If you are using Sequential, " \
                "please try HybridSequential instead"%(
                    str(block), str(type(block))))
        super(HybridBlock, self).register_child(block)
        self._clear_cached_op()

    def hybridize(self, active=True, **kwargs):
        """Turns hybrid mode on or off for this block and its children.

        Parameters
        ----------
        active : bool, default True
            Whether to turn hybrid on or off.
        **kwargs : string
            Additional flags for hybridized operator.
        """
        self._active = active
        # Snapshot the pairs instead of keeping a live view of `kwargs`.
        self._flags = list(kwargs.items())
        self._clear_cached_op()
        super(HybridBlock, self).hybridize(active, **kwargs)

    def cast(self, dtype):
        """Casts this block to another data type and invalidates the cache."""
        self._clear_cached_op()
        super(HybridBlock, self).cast(dtype)

    def _infer_attrs(self, infer_fn, attr, *args):
        """Generic infer attributes."""
        inputs, out = self._get_graph(*args)
        args, _ = _flatten(args)
        arg_attrs, _, aux_attrs = getattr(out, infer_fn)(
            **{i.name: getattr(j, attr) for i, j in zip(inputs, args)})
        sdict = {i: j for i, j in zip(out.list_arguments(), arg_attrs)}
        sdict.update({name: aux_attr for name, aux_attr in \
                      zip(out.list_auxiliary_states(), aux_attrs)})
        for i in self.collect_params().values():
            setattr(i, attr, sdict[i.name])

    def infer_shape(self, *args):
        """Infers shape of Parameters from inputs."""
        self._infer_attrs('infer_shape', 'shape', *args)

    def infer_type(self, *args):
        """Infers data type of Parameters from inputs."""
        self._infer_attrs('infer_type', 'dtype', *args)

    def export(self, path, epoch=0):
        """Export HybridBlock to json format that can be loaded by `mxnet.mod.Module`
        or the C++ interface.

        .. note:: When there is only one input, it will have the name `data`. When there
                  is more than one input, they will be named `data0`, `data1`, etc.

        Parameters
        ----------
        path : str
            Path to save model. Two files `path-symbol.json` and `path-xxxx.params`
            will be created, where xxxx is the 4 digits epoch number.
        epoch : int
            Epoch number of saved model.

        Raises
        ------
        RuntimeError
            If no graph has been cached yet.
        """
        if not self._cached_graph:
            raise RuntimeError(
                "Please first call block.hybridize() and then run forward with "
                "this block at least once before calling export.")
        sym = self._cached_graph[1]
        sym.save('%s-symbol.json'%path)

        arg_names = set(sym.list_arguments())
        aux_names = set(sym.list_auxiliary_states())
        arg_dict = {}
        for name, param in self.collect_params().items():
            if name in arg_names:
                arg_dict['arg:%s'%name] = param._reduce()
            else:
                assert name in aux_names
                arg_dict['aux:%s'%name] = param._reduce()
        ndarray.save('%s-%04d.params'%(path, epoch), arg_dict)

    def forward(self, x, *args):
        """Defines the forward computation. Arguments can be either
        :py:class:`NDArray` or :py:class:`Symbol`."""
        if isinstance(x, NDArray):
            with x.context as ctx:
                try:
                    if self._active:
                        return self._call_cached_op(x, *args)
                    params = {i: j.data(ctx) for i, j in self._reg_params.items()}
                except DeferredInitializationError:
                    # Parameters were waiting for shapes: infer them from the
                    # inputs, finish initialization and retry.
                    self._finish_deferred_init(self._active, x, *args)

                    if self._active:
                        return self._call_cached_op(x, *args)
                    params = {i: j.data(ctx) for i, j in self._reg_params.items()}
                return self.hybrid_forward(ndarray, x, *args, **params)

        assert isinstance(x, Symbol), \
            "HybridBlock requires the first argument to forward be either " \
            "Symbol or NDArray, but got %s"%type(x)
        params = {i: j.var() for i, j in self._reg_params.items()}
        with self.name_scope():
            return self.hybrid_forward(symbol, x, *args, **params)

    def hybrid_forward(self, F, x, *args, **kwargs):
        """Overrides to construct symbolic graph for this `Block`.

        Parameters
        ----------
        F : module
            The `ndarray` or `symbol` module, depending on the type of `x`.
        x : Symbol or NDArray
            The first input tensor.
        *args : list of Symbol or list of NDArray
            Additional input tensors.
        **kwargs
            The Parameters registered on this block, keyed by attribute name.
        """
        # pylint: disable= invalid-name
        raise NotImplementedError
class SymbolBlock(HybridBlock):
    """Construct block from symbol. This is useful for using pre-trained models
    as feature extractors. For example, you may want to extract the output
    from fc2 layer in AlexNet.

    Parameters
    ----------
    outputs : Symbol or list of Symbol
        The desired output for SymbolBlock.
    inputs : Symbol or list of Symbol
        The Variables in output's argument that should be used as inputs.
    params : ParameterDict
        Parameter dictionary for arguments and auxiliary states of outputs
        that are not inputs.

    Examples
    --------
    >>> # To extract the feature from fc1 and fc2 layers of AlexNet:
    >>> alexnet = gluon.model_zoo.vision.alexnet(pretrained=True, ctx=mx.cpu(),
                                                 prefix='model_')
    >>> inputs = mx.sym.var('data')
    >>> out = alexnet(inputs)
    >>> internals = out.get_internals()
    >>> print(internals.list_outputs())
    ['data', ..., 'model_dense0_relu_fwd_output', ..., 'model_dense1_relu_fwd_output', ...]
    >>> outputs = [internals['model_dense0_relu_fwd_output'],
                   internals['model_dense1_relu_fwd_output']]
    >>> # Create SymbolBlock that shares parameters with alexnet
    >>> feat_model = gluon.SymbolBlock(outputs, inputs, params=alexnet.collect_params())
    >>> x = mx.nd.random.normal(shape=(16, 3, 224, 224))
    >>> print(feat_model(x))
    """
    def __init__(self, outputs, inputs, params=None):
        super(SymbolBlock, self).__init__(prefix=None, params=None)
        # Parameter names are used exactly as they appear in the symbol,
        # so the block itself carries no prefix.
        self._prefix = ''
        self._params = ParameterDict('', params)

        # Normalise the arguments: a lone single-output input symbol becomes a
        # one-element list, a one-element output list becomes its element.
        single_input = isinstance(inputs, symbol.Symbol) and len(inputs.list_outputs()) == 1
        if single_input:
            inputs = [inputs]
        single_output = isinstance(outputs, (list, tuple)) and len(outputs) == 1
        if single_output:
            outputs = outputs[0]

        input_syms, self._in_format = _flatten(inputs)
        flat_outputs, self._out_format = _flatten(outputs)
        graph = symbol.Group(flat_outputs)

        input_names = set()
        for sym in input_syms:
            assert len(sym.get_internals().list_outputs()) == 1, \
                "Input symbols must be variable, but %s is an output of operators"%str(sym)
            input_names.add(sym.name)

        # Every argument / auxiliary state that is not an input is a Parameter.
        for arg_name in graph.list_arguments():
            if arg_name in input_names:
                continue
            self.params.get(arg_name, allow_deferred_init=True)

        for aux_name in graph.list_auxiliary_states():
            if aux_name in input_names:
                continue
            self.params.get(aux_name, grad_req='null', allow_deferred_init=True)

        self._cached_graph = input_syms, graph
        self._build_cache()

    def forward(self, x, *args):
        """Runs the wrapped symbol on NDArray inputs, or composes it with
        Symbol inputs and returns the resulting symbol(s)."""
        if isinstance(x, NDArray):
            with x.context:
                try:
                    return self._call_cached_op(x, *args)
                except DeferredInitializationError:
                    self._finish_deferred_init(True, x, *args)
                    return self._call_cached_op(x, *args)

        assert isinstance(x, Symbol), \
            "HybridBlock requires the first argument to forward be either " \
            "Symbol or NDArray, but got %s"%type(x)
        flat_args, in_fmt = _flatten([x] + list(args))
        assert in_fmt == self._in_format, "Invalid input format"

        # Compose a copy so that the cached graph itself stays untouched.
        composed = copy.copy(self._cached_graph[1])
        bindings = {}
        for var, replacement in zip(self._cached_graph[0], flat_args):
            bindings[var.name] = replacement
        composed._compose(**bindings)
        regrouped, _ = _regroup(list(composed), self._out_format)
        return regrouped

    def _clear_cached_op(self):
        # The graph is supplied at construction and cannot be re-traced, so it
        # is preserved while the parent class resets the rest of the cache.
        preserved_graph = self._cached_graph
        super(SymbolBlock, self)._clear_cached_op()
        self._cached_graph = preserved_graph

    def hybrid_forward(self, F, x, *args, **kwargs):
        """Not used by this class: :py:meth:`forward` works on the stored graph."""
        raise NotImplementedError