# Source code for mxnet.name

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Automatic naming support for symbolic API."""
import threading
import warnings
from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass

class NameManager(with_metaclass(_MXClassPropertyMetaClass, object)):
    """NameManager to do automatic naming.

    Developers can also inherit from this class to change naming behavior.

    Entering a manager with a ``with`` statement makes it the current manager
    for the calling thread; leaving the block restores whichever manager was
    current before it was entered.
    """
    # Thread-local storage; ``_current.value`` holds the calling thread's
    # active manager (seeded lazily below when a thread has none yet).
    _current = threading.local()

    def __init__(self):
        # Next numeric suffix to hand out for each hint string.
        self._counter = {}
        # Manager that was current when ``__enter__`` ran; restored on exit.
        self._old_manager = None

    def get(self, name, hint):
        """Get the canonical name for a symbol.

        This is the default implementation.
        If the user specifies a name,
        the user-specified name will be used.

        When user does not specify a name, we automatically generate a
        name based on the hint string: the hint followed by a per-hint
        counter that starts at 0 (e.g. ``fc0``, ``fc1``, ...).

        Parameters
        ----------
        name : str or None
            The name specified by the user. Any falsy value (``None`` or
            ``''``) means "not specified".
        hint : str
            A hint string, which can be used to generate name.

        Returns
        -------
        full_name : str
            A canonical name for the symbol.
        """
        if name:
            return name
        index = self._counter.get(hint, 0)
        self._counter[hint] = index + 1
        return '%s%d' % (hint, index)

    def __enter__(self):
        """Make this manager current for the calling thread and return it."""
        # Threads other than the importing one start without a value.
        if not hasattr(NameManager._current, "value"):
            NameManager._current.value = NameManager()
        self._old_manager = NameManager._current.value
        NameManager._current.value = self
        return self

    def __exit__(self, ptype, value, trace):
        """Restore the manager that was current before ``__enter__``.

        Raises
        ------
        RuntimeError
            If there is no saved manager to restore (e.g. ``__exit__`` was
            called without a matching ``__enter__``).
        """
        # An explicit check rather than ``assert``: assertions are stripped
        # under ``python -O``, which would silently install ``None`` as the
        # current manager and break automatic naming on this thread.
        if self._old_manager is None:
            raise RuntimeError(
                "NameManager.__exit__ called without a matching __enter__")
        NameManager._current.value = self._old_manager

    #pylint: disable=no-self-argument
    @classproperty
    def current(cls):
        """Deprecated accessor for the calling thread's current manager.

        Emits a ``DeprecationWarning``; prefer ``with NameManager(): ...``.
        If the thread has no manager yet, a fresh ``NameManager`` is
        installed and returned.
        """
        warnings.warn("NameManager.current has been deprecated. "
                      "It is advised to use the `with` statement with NameManager.",
                      DeprecationWarning)
        # Check and populate the same storage object that is returned, so a
        # subclass overriding ``_current`` stays self-consistent.
        if not hasattr(cls._current, "value"):
            cls._current.value = NameManager()
        return cls._current.value

    @current.setter
    def current(cls, val):
        """Replace the calling thread's current manager with ``val``."""
        cls._current.value = val
    #pylint: enable=no-self-argument
class Prefix(NameManager):
    """A name manager that attaches a prefix to all names.

    The prefix is prepended both to automatically generated names and to
    names given explicitly by the user.

    Examples
    --------
    >>> import mxnet as mx
    >>> data = mx.symbol.Variable('data')
    >>> with mx.name.Prefix('mynet_'):
    ...     net = mx.symbol.FullyConnected(data, num_hidden=10, name='fc1')
    >>> net.list_arguments()
    ['data', 'mynet_fc1_weight', 'mynet_fc1_bias']
    """
    def __init__(self, prefix):
        # String prepended to every name produced by ``get``.
        self._prefix = prefix
        super(Prefix, self).__init__()

    def get(self, name, hint):
        """Return the base manager's name for the symbol, with the prefix prepended.

        Parameters
        ----------
        name : str or None
            The name specified by the user, if any.
        hint : str
            Hint used to generate a name when ``name`` is not given.

        Returns
        -------
        str
            ``prefix + <name chosen by NameManager.get>``.
        """
        unprefixed = super(Prefix, self).get(name, hint)
        prefixed = self._prefix + unprefixed
        return prefixed
# Initialize the default name manager for the importing thread so automatic
# naming works without an explicit ``with`` block. Other threads are seeded
# lazily by ``NameManager.__enter__`` and the ``NameManager.current`` getter.
setattr(NameManager._current, "value", NameManager())