Source code for mxnet.registry
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=no-member
"""Registry for serializable objects."""
from __future__ import absolute_import
import json
import warnings
from .base import string_types
# Module-wide registry state: one entry per base class, each entry being a
# dict that maps a lower-cased registration name to the registered class.
_REGISTRY = {}
def get_registry(base_class):
    """Return a snapshot of the classes registered under ``base_class``.

    Parameters
    ----------
    base_class : type
        Base class whose registry is requested.

    Returns
    -------
    dict
        A shallow copy of the registry entry for ``base_class``. Modifying
        the returned dict does not modify the registry itself.
    """
    # Asking about a base class that has no entry yet creates an empty one.
    return _REGISTRY.setdefault(base_class, {}).copy()
def get_register_func(base_class, nickname):
    """Build a function that registers subclasses of ``base_class``.

    Parameters
    ----------
    base_class : type
        Base class that every registered class must derive from.
    nickname : str
        Human-readable name of ``base_class``, used in messages.

    Returns
    -------
    callable
        ``register(klass, name=None)``. It stores ``klass`` in the registry of
        ``base_class`` under the lower-cased ``name`` (the class name when
        ``name`` is None) and returns ``klass`` unchanged.
    """
    if base_class not in _REGISTRY:
        _REGISTRY[base_class] = {}
    registry = _REGISTRY[base_class]

    def register(klass, name=None):
        """Register functions"""
        assert issubclass(klass, base_class), \
            "Can only register subclass of %s"%base_class.__name__
        if name is None:
            name = klass.__name__
        # Names are stored lower-cased, so registration is case-insensitive.
        name = name.lower()
        if name in registry:
            # The adjacent literals are concatenated, so the first one must
            # end with a space to keep "is overriding" as two words.
            # \033[91m ... \033[0m wraps the message in ANSI colour codes.
            warnings.warn(
                "\033[91mNew %s %s.%s registered with name %s is "
                "overriding existing %s %s.%s\033[0m"%(
                    nickname, klass.__module__, klass.__name__, name,
                    nickname, registry[name].__module__, registry[name].__name__),
                UserWarning, stacklevel=2)
        # The newest registration always replaces an earlier one.
        registry[name] = klass
        return klass

    register.__doc__ = "Register %s to the %s factory"%(nickname, nickname)
    return register
def get_alias_func(base_class, nickname):
    """Build a decorator factory that registers a class under several names.

    Parameters
    ----------
    base_class : type
        Base class that every registered class must derive from.
    nickname : str
        Human-readable name of ``base_class``, used in messages.

    Returns
    -------
    callable
        ``alias(*aliases)``, which returns a class decorator. The decorator
        registers the class once per given alias and returns the class
        unchanged; with no aliases it registers nothing.
    """
    register_one = get_register_func(base_class, nickname)

    def alias(*aliases):
        """Return a decorator that registers a class under each alias."""
        def reg(klass):
            """Register ``klass`` under every alias, then hand it back."""
            for alias_name in aliases:
                register_one(klass, alias_name)
            return klass
        return reg

    return alias
def get_create_func(base_class, nickname):
    """Build a factory function for classes registered under ``base_class``.

    Parameters
    ----------
    base_class : type
        Base class whose registry is used to look classes up.
    nickname : str
        Human-readable name of ``base_class``. It is also the keyword under
        which the factory looks for its target when no positional argument
        is given.

    Returns
    -------
    callable
        ``create(*args, **kwargs)``. Its target (first positional argument,
        or the ``nickname`` keyword) may be:

        * an instance of ``base_class``, returned as is;
        * a dict, expanded into keyword arguments for another ``create`` call;
        * a JSON string ``'[name, kwargs]'`` or ``'{...}'``, decoded and
          passed to another ``create`` call;
        * a registered name (case-insensitive), whose class is called with
          the remaining arguments.
    """
    registry = _REGISTRY.setdefault(base_class, {})

    def create(*args, **kwargs):
        """Create instance from config"""
        if args:
            name, args = args[0], args[1:]
        else:
            name = kwargs.pop(nickname)

        # An existing instance is passed straight through.
        if isinstance(name, base_class):
            assert not args and not kwargs, \
                "%s is already an instance. Additional arguments are invalid"%(nickname)
            return name

        # A dict is treated as the full keyword configuration; anything else
        # passed alongside it is not forwarded.
        if isinstance(name, dict):
            return create(**name)

        assert isinstance(name, string_types), "%s must be of string type"%nickname

        # JSON list form: '["name", {kwargs}]'.
        if name.startswith('['):
            assert not args and not kwargs
            name, kwargs = json.loads(name)
            return create(name, **kwargs)

        # JSON object form: '{"<nickname>": "name", ...}'.
        if name.startswith('{'):
            assert not args and not kwargs
            return create(**json.loads(name))

        # Registration stores lower-cased names, so look up the same way.
        key = name.lower()
        assert key in registry, \
            "%s is not registered. Please register with %s.register first"%(
                str(key), nickname)
        return registry[key](*args, **kwargs)

    create.__doc__ = """Create a %s instance from config.
Parameters
----------
%s : str or %s instance
class name of desired instance. If is a instance,
it will be returned directly.
**kwargs : dict
arguments to be passed to constructor"""%(nickname, nickname, base_class.__name__)
    return create