# Source code for mxnet.name
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
"""Automatic naming support for symbolic API."""
from __future__ import absolute_import
import threading
import warnings
from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass
class NameManager(with_metaclass(_MXClassPropertyMetaClass, object)):
    """NameManager to do automatic naming.

    Developers can also inherit from this class to change naming behavior.

    The active manager is kept per thread in ``NameManager._current.value``.
    Entering a manager with a ``with`` statement makes it the active manager
    for the calling thread; leaving the block restores the manager that was
    active before.
    """
    # Thread-local slot holding the active manager. Every accessor below reads
    # and writes this same object (never a subclass's shadowing attribute) so
    # that ``with``-blocks and the deprecated ``current`` property agree.
    _current = threading.local()

    def __init__(self):
        # Per-hint counters used to generate names such as 'fc0', 'fc1', ...
        self._counter = {}
        # Manager that was active before this one was entered; restored on exit.
        self._old_manager = None

    def get(self, name, hint):
        """Get the canonical name for a symbol.

        This is the default implementation.
        If the user specifies a name,
        the user-specified name will be used.

        When user does not specify a name, we automatically generate a
        name based on the hint string.

        Parameters
        ----------
        name : str or None
            The name specified by the user. Any falsy value (``None`` or
            ``''``) requests an automatically generated name.
        hint : str
            A hint string, which can be used to generate name.

        Returns
        -------
        full_name : str
            A canonical name for the symbol. Generated names are the hint
            followed by a per-hint counter starting at 0.
        """
        if name:
            return name
        index = self._counter.get(hint, 0)
        self._counter[hint] = index + 1
        return '%s%d' % (hint, index)

    def __enter__(self):
        # Only the thread that imported the module gets a default manager at
        # import time; other threads create one lazily here.
        if not hasattr(NameManager._current, "value"):
            NameManager._current.value = NameManager()
        self._old_manager = NameManager._current.value
        NameManager._current.value = self
        return self

    def __exit__(self, ptype, value, trace):
        # Explicit check instead of ``assert`` so the guard survives
        # ``python -O``: exiting a manager that was never entered would
        # otherwise install ``None`` as the active manager.
        if self._old_manager is None:
            raise RuntimeError(
                "NameManager.__exit__ called without a matching __enter__")
        NameManager._current.value = self._old_manager
        # Returning None lets any exception from the with-body propagate.

    #pylint: disable=no-self-argument, unused-argument
    @classproperty
    def current(cls):
        """Deprecated accessor for the calling thread's active manager."""
        warnings.warn("NameManager.current has been deprecated. "
                      "It is advised to use the `with` statement with NameManager.",
                      DeprecationWarning)
        # Check, initialize and read the same thread-local that __enter__ and
        # __exit__ use, so the existence check and the write cannot diverge.
        if not hasattr(NameManager._current, "value"):
            NameManager._current.value = NameManager()
        return NameManager._current.value

    @current.setter
    def current(cls, val):
        NameManager._current.value = val
    #pylint: enable=no-self-argument, unused-argument
class Prefix(NameManager):
    """A name manager that attaches a prefix to all names.

    The prefix is prepended to every name this manager returns, both
    user-specified names and names generated from a hint.

    Examples
    --------
    >>> import mxnet as mx
    >>> data = mx.symbol.Variable('data')
    >>> with mx.name.Prefix('mynet_'):
    ...     net = mx.symbol.FullyConnected(data, num_hidden=10, name='fc1')
    >>> net.list_arguments()
    ['data', 'mynet_fc1_weight', 'mynet_fc1_bias']
    """
    def __init__(self, prefix):
        super(Prefix, self).__init__()
        # Text glued in front of every name produced by ``get``.
        self._prefix = prefix

    def get(self, name, hint):
        """Resolve the name via the parent manager, then prepend the prefix."""
        resolved = super(Prefix, self).get(name, hint)
        return self._prefix + resolved
# initialize the default name manager
# Install a default NameManager for the thread that imports this module.
# ``_current`` is a ``threading.local``, so other threads start without a
# ``value``; ``__enter__`` and ``current`` create one for them lazily.
setattr(NameManager._current, "value", NameManager())