Source code for mxnet.gluon.contrib.estimator.estimator
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=wildcard-import, unused-variable
"""Gluon Estimator"""
import copy
import logging
import sys
import warnings
from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler, GradientUpdateHandler
from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
from .event_handler import _check_event_handlers
from .utils import _check_metrics, _suggest_metric_for_loss, _check_handler_metric_ref
from ...data import DataLoader
from ...loss import Loss as gluon_loss
from ...trainer import Trainer
from ...utils import split_and_load
from ....device import Device, cpu, gpu, num_gpus
from ...metric import Loss as metric_loss
from .batch_processor import BatchProcessor
__all__ = ['Estimator']
[docs]class Estimator(object):
"""Estimator Class for easy model training
:py:class:`Estimator` can be used to facilitate the training & validation process
Parameters
----------
net : gluon.Block
The model used for training.
loss : gluon.loss.Loss
Loss (objective) function to calculate during training.
train_metrics : EvalMetric or list of EvalMetric
Training metrics for evaluating models on training dataset.
val_metrics : EvalMetric or list of EvalMetric
Validation metrics for evaluating models on validation dataset.
initializer : Initializer
Initializer to initialize the network.
trainer : Trainer
Trainer to apply optimizer on network parameters.
device : Device or list of Device
Device(s) to run the training on.
val_net : gluon.Block
The model used for validation. The validation model does not necessarily belong to
the same model class as the training model. But the two models typically share the
same architecture. Therefore the validation model can reuse parameters of the
training model.
An example of constructing a val_net that shares its network parameters with
the training net is given below:
>>> net = _get_train_network()
>>> val_net = _get_test_network()
>>> val_net.share_parameters(net.collect_params())
>>> net.initialize(device=device)
>>> est = Estimator(net, loss, val_net=val_net)
Proper namespace match is required for weight sharing between two networks. Most networks
inheriting :py:class:`Block` can share their parameters correctly. An exception is
Sequential networks, for which the Block scope must be specified for correct weight sharing.
For the naming in mxnet Gluon API, please refer to the site
(https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/naming.html)
for further information.
val_loss : gluon.loss.Loss
Loss (objective) function to calculate during validation. If val_loss is
None, the same loss function as the training loss is used.
batch_processor: BatchProcessor
BatchProcessor provides customized fit_batch() and evaluate_batch() methods
"""
logger = None
"""logging.Logger object associated with the Estimator.
The logger is used for all logs generated by this estimator and its
handlers. A new logging.Logger is created during Estimator construction and
configured to write all logs with level logging.INFO or higher to
sys.stdout.
You can modify the logging settings using the standard Python methods. For
example, to save logs to a file in addition to printing them to stdout,
you can attach a logging.FileHandler to the logger.
>>> est = Estimator(net, loss)
>>> import logging
>>> est.logger.addHandler(logging.FileHandler(filename))
"""
def __init__(self, net,
             loss,
             train_metrics=None,
             val_metrics=None,
             initializer=None,
             trainer=None,
             device=None,
             val_net=None,
             val_loss=None,
             batch_processor=None):
    """Validate the constructor arguments and set up the estimator state.

    The training loss is validated first, then the metric lists are
    prepared, the validation loss/network fall back to their training
    counterparts when not supplied, a stdout logger is attached, and
    finally devices, parameter initialization, trainer and batch
    processor are resolved through the corresponding ``_check_*`` /
    ``_initialize`` helpers.
    """
    self.net = net
    self.loss = self._check_loss(loss)

    # Metric lists are passed through _check_metrics, then completed with
    # defaults and given their "training"/"validation" name prefixes.
    self._train_metrics = _check_metrics(train_metrics)
    self._val_metrics = _check_metrics(val_metrics)
    self._add_default_training_metrics()
    self._add_validation_metrics()

    # Validation reuses the training loss and network unless overridden.
    self.val_loss = self.loss if val_loss is None else self._check_loss(val_loss)
    self.val_net = self.net if val_net is None else val_net

    # Dedicated logger writing INFO and above to stdout.
    stdout_logger = logging.Logger(name='Estimator', level=logging.INFO)
    stdout_logger.addHandler(logging.StreamHandler(sys.stdout))
    self.logger = stdout_logger

    # Devices must be resolved before initialization, which in turn must
    # happen before a default trainer collects the network parameters.
    self.device = self._check_devices(device)
    self._initialize(initializer)
    self.trainer = self._check_trainer(trainer)
    self.batch_processor = self._check_batch_processor(batch_processor)
def _check_loss(self, loss):
    """Return ``loss`` unchanged when it is a gluon ``Loss`` instance.

    Raises
    ------
    ValueError
        If ``loss`` is not an instance of ``gluon.loss.Loss``.
    """
    if isinstance(loss, gluon_loss):
        return loss
    raise ValueError("loss must be a Loss, "
                     "refer to gluon.loss.Loss:{}".format(loss))
def _check_context(self, context):
"""This function has been deprecated. Please refer to ``Estimator._check_devices``."""
warnings.warn('Estimator._check_context has been renamed to'
' Estimator._check_devices', DeprecationWarning)
return self._check_devices(context)
def _check_devices(self, devices):
    """Normalize ``devices`` into a list of usable devices.

    Parameters
    ----------
    devices : Device or list of Device or None
        Requested device(s). A falsy value selects a default:
        ``[gpu(0)]`` when ``num_gpus()`` reports at least one GPU,
        otherwise ``[cpu()]``.

    Returns
    -------
    list of Device

    Raises
    ------
    ValueError
        If ``devices`` is neither a Device nor a list of Device.
    AssertionError
        If a requested device is not one of the detected GPUs and its
        string form does not start with ``'cpu'``.
    """
    # infer available devices
    gpu_count = num_gpus()
    detected_gpus = [gpu(index) for index in range(gpu_count)]

    if not devices:
        # nothing requested: provide a default device
        if gpu_count <= 0:
            return [cpu()]
        # only use 1 GPU by default
        if gpu_count > 1:
            warnings.warn("You have multiple GPUs, gpu(0) will be used by default."
                          "To utilize all your GPUs, specify device as a list of gpus, "
                          "e.g. devices=[mx.gpu(0), mx.gpu(1)] ")
        return [gpu(0)]

    # only accept a Device or a list of Device
    if isinstance(devices, Device):
        devices = [devices]
    elif not (isinstance(devices, list)
              and all(isinstance(entry, Device) for entry in devices)):
        raise ValueError("devices must be a Device or a list of Device, "
                         "for example mx.cpu() or [mx.gpu(0), mx.gpu(1)], "
                         "refer to mxnet.Device:{}".format(devices))

    # NOTE(review): availability is enforced with ``assert``, so this check
    # is skipped when Python runs with -O.
    for requested in devices:
        is_usable = requested in detected_gpus or str(requested).startswith('cpu')
        assert is_usable, \
            "{} is not available, please make sure " \
            "your device is in one of: mx.cpu(), {}".format(
                requested, ', '.join(str(found) for found in detected_gpus))
    return devices
def _check_batch_processor(self, batch_processor):
# check whether the batch processor contains fit_batch() and evaluate_batch() methods
if batch_processor is not None:
model_fit = getattr(batch_processor, 'fit_batch', None)
model_evaluate = getattr(batch_processor, 'evaluate_batch', None)
if not callable(model_fit) or not callable(model_evaluate):
raise ValueError('Customized Batch Processor must contain fit_batch()'
' and evaluate_batch() methods')
else:
batch_processor = BatchProcessor()
return batch_processor
def _initialize(self, initializer):
# initialize the network
if not self._is_initialized():
# net is partially or not initialized,
# initialize with user specified initializer
# if initializer is None, default initializer will be used
# do not re-init layers already initialized
if initializer:
self.net.initialize(init=initializer, device=self.device)
else:
self.net.initialize(device=self.device)
elif initializer:
# net is fully initialized, and user passed not None initializer
# do not force reinitialize, give warning
warnings.warn("Network already fully initialized, skipping initialization. "
"You don't need to pass initializer if you already "
"initialized your net. "
"You can use net.initialize(init=your_initializer, force_reinit=True)"
"to force re-initialize.")
def _check_trainer(self, trainer):
    """Return the trainer to use for optimization.

    A falsy ``trainer`` triggers a warning and yields a default
    ``Trainer`` built over ``self.net.collect_params()`` with the
    ``'sgd'`` optimizer and a learning rate of 0.001.

    Raises
    ------
    ValueError
        If a truthy ``trainer`` is not a gluon ``Trainer`` instance.
    """
    if trainer:
        if not isinstance(trainer, Trainer):
            raise ValueError("Trainer must be a Gluon Trainer instance, refer to "
                             "gluon.Trainer:{}".format(trainer))
        return trainer

    # no trainer given: fall back to plain SGD
    warnings.warn("No trainer specified, default SGD optimizer "
                  "with learning rate 0.001 is used.")
    return Trainer(self.net.collect_params(),
                   'sgd', {'learning_rate': 0.001})
def _is_initialized(self):
param_dict = self.net.collect_params()
for param in param_dict:
try:
param_dict[param].list_device()
except RuntimeError:
return False
return True
def _get_data_and_label(self, batch, device, batch_axis=0):
    """Split a batch's data and label across ``device``.

    ``batch[0]`` is taken as the data and ``batch[1]`` as the label; each
    is passed through ``split_and_load`` with the given ``batch_axis``.

    Returns
    -------
    tuple
        ``(data, label)`` as returned by ``split_and_load``.
    """
    raw_data, raw_label = batch[0], batch[1]
    split_data = split_and_load(raw_data, device, batch_axis=batch_axis)
    split_label = split_and_load(raw_label, device, batch_axis=batch_axis)
    return split_data, split_label
def _add_default_training_metrics(self):
if not self._train_metrics:
suggested_metric = _suggest_metric_for_loss(self.loss)
if suggested_metric:
self._train_metrics = [suggested_metric]
loss_name = type(self.loss).__name__
self._train_metrics.append(metric_loss(loss_name))
for metric in self._train_metrics:
# add training prefix to the metric name
# it is useful for event handlers to distinguish them from validation metrics
metric.name = 'training ' + metric.name
def _add_validation_metrics(self):
if not self._val_metrics:
self._val_metrics = [copy.deepcopy(metric) for metric in self._train_metrics]
for metric in self._val_metrics:
# add validation prefix to the metric name
# it is useful for event handlers to distinguish them from training metrics
if 'training' in metric.name:
metric.name = metric.name.replace('training', 'validation')
else:
metric.name = 'validation ' + metric.name
@property
def train_metrics(self):
return self._train_metrics
@property
def val_metrics(self):
return self._val_metrics
def evaluate(self,
             val_data,
             batch_axis=0,
             event_handlers=None):
    """Evaluate model on validation data.

    Every validation metric is reset, then each batch from the validation
    data loader is passed to ``self.batch_processor.evaluate_batch``.
    Custom evaluation logic is therefore supplied through the
    ``batch_processor`` given to the estimator.

    Parameters
    ----------
    val_data : DataLoader
        Validation data loader with data and labels.
    batch_axis : int, default 0
        Batch axis to split the validation data into devices.
    event_handlers : EventHandler or list of EventHandler
        List of :py:class:`EventHandlers` to apply during validation. Besides
        event handlers specified here, a default MetricHandler and a LoggingHandler
        will be added if not specified explicitly.

    Raises
    ------
    ValueError
        If ``val_data`` is not a Gluon DataLoader.
    """
    if not isinstance(val_data, DataLoader):
        raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
                         "can transform your DataIter or any NDArray into Gluon DataLoader. "
                         "Refer to gluon.data.DataLoader")

    # start every evaluation pass from clean metric state
    for metric in self.val_metrics:
        metric.reset()

    event_handlers = self._prepare_default_validation_handlers(event_handlers)
    # train_begin / train_end handlers are not used during evaluation
    _, epoch_begin, batch_begin, batch_end, \
        epoch_end, _ = self._categorize_handlers(event_handlers)

    # pass a reference to all event handlers
    estimator_ref = self

    for handler in epoch_begin:
        handler.epoch_begin(estimator_ref)

    for batch in val_data:
        for handler in batch_begin:
            handler.batch_begin(estimator_ref, batch=batch)

        _, label, pred, loss = \
            self.batch_processor.evaluate_batch(estimator_ref, batch,
                                                batch_axis)

        for handler in batch_end:
            handler.batch_end(estimator_ref, batch=batch, pred=pred, label=label, loss=loss)

    for handler in epoch_end:
        handler.epoch_end(estimator_ref)
def fit(self, train_data,
        val_data=None,
        epochs=None,
        event_handlers=None,
        batches=None,
        batch_axis=0):
    """Train the model on a :py:class:`DataLoader` for a given number of
    epochs or batches.

    Each batch from the training data loader is passed to
    ``self.batch_processor.fit_batch``. The arguments ``epochs``,
    ``batches`` and ``batch_axis`` are stored on the estimator as
    ``max_epoch``, ``max_batch`` and ``batch_axis`` before the event
    handlers are prepared.

    Parameters
    ----------
    train_data : DataLoader
        Training data loader with data and labels.
    val_data : DataLoader, default None
        Validation data loader with data and labels.
    epochs : int, default None
        Number of epochs to iterate on the training data.
        Exactly one of ``epochs`` and ``batches`` must be specified.
    event_handlers : EventHandler or list of EventHandler
        List of :py:class:`EventHandlers` to apply during training. Besides
        the event handlers specified here, a StoppingHandler,
        LoggingHandler and MetricHandler will be added by default if not
        yet specified manually. If validation data is provided, a
        ValidationHandler is also added if not already specified.
    batches : int, default None
        Number of batches to iterate on the training data.
        Exactly one of ``epochs`` and ``batches`` must be specified.
    batch_axis : int, default 0
        Batch axis to split the training data into devices.

    Raises
    ------
    ValueError
        If ``train_data`` is not a Gluon DataLoader, or if ``epochs`` and
        ``batches`` are both given or both omitted.
    """
    if not isinstance(train_data, DataLoader):
        raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
                         "can transform your DataIter or any NDArray into Gluon DataLoader. "
                         "Refer to gluon.data.dataloader")

    # exactly one of epochs / batches may be truthy
    if bool(epochs) == bool(batches):
        raise ValueError(
            "Fit only support exactly one type of iteration, "
            "train by number of epochs or number of batches."
            "Please specify one and only one of: epochs or batches.")

    self.max_epoch = epochs
    self.max_batch = batches
    self.batch_axis = batch_axis

    # complete the user's handlers with defaults, then group them by event
    all_handlers = self._prepare_default_handlers(val_data, event_handlers)
    (train_begin, epoch_begin, batch_begin,
     batch_end, epoch_end, train_end) = self._categorize_handlers(all_handlers)

    # every handler receives a reference to this estimator
    estimator_ref = self

    for handler in train_begin:
        handler.train_begin(estimator_ref)

    stop_training = False
    while not stop_training:
        for handler in epoch_begin:
            handler.epoch_begin(estimator_ref)

        for batch in train_data:
            for handler in batch_begin:
                handler.batch_begin(estimator_ref, batch=batch)

            _, label, pred, loss = self.batch_processor.fit_batch(estimator_ref,
                                                                  batch, batch_axis)

            # Run every batch-end handler before inspecting the results so
            # that a stop signal from one handler does not skip the others.
            batch_signals = [handler.batch_end(estimator_ref, batch=batch,
                                               pred=pred, label=label, loss=loss)
                             for handler in batch_end]
            if any(batch_signals):
                # leaves the batch loop only; epoch-end handlers still run
                break

        # Same rule at epoch end: all handlers run, then any truthy result
        # terminates training.
        epoch_signals = [handler.epoch_end(estimator_ref) for handler in epoch_end]
        stop_training = any(epoch_signals)

    for handler in train_end:
        handler.train_end(estimator_ref)
def _prepare_default_handlers(self, val_data, event_handlers):
    """Complete the user's training handlers with defaults and sort them.

    A StoppingHandler is always added. A GradientUpdateHandler,
    MetricHandler and LoggingHandler are added when the user supplied no
    handler of that type, and a ValidationHandler is added when none was
    supplied and ``val_data`` is truthy.

    Parameters
    ----------
    val_data : DataLoader or None
        Validation data handed to the default ValidationHandler.
    event_handlers : EventHandler or list of EventHandler
        User handlers, normalized by ``_check_event_handlers``.

    Returns
    -------
    list
        User and default handlers, sorted by their ``priority`` attribute
        (0 when a handler has none).
    """
    event_handlers = _check_event_handlers(event_handlers)

    def user_supplied(handler_type):
        return any(isinstance(handler, handler_type) for handler in event_handlers)

    # no need to add to default handler check as StoppingHandler does not use metrics
    added_default_handlers = [StoppingHandler(self.max_epoch, self.max_batch)]

    if not user_supplied(GradientUpdateHandler):
        added_default_handlers.append(GradientUpdateHandler())

    if not user_supplied(MetricHandler):
        added_default_handlers.append(MetricHandler(metrics=self.train_metrics))

    # add default validation handler only if validation data was found
    if not user_supplied(ValidationHandler) and val_data:
        added_default_handlers.append(ValidationHandler(val_data=val_data,
                                                        eval_fn=self.evaluate))

    if not user_supplied(LoggingHandler):
        added_default_handlers.append(LoggingHandler(metrics=self.train_metrics))

    # if there is a mix of user defined event handlers and default event handlers
    # they should have the same set of metrics.
    # The flag must be a real bool computed BEFORE extending: `a and b`
    # returns one of its operands, so an empty user list would be aliased
    # and turn truthy once the defaults are appended to it.
    mixing_handlers = bool(event_handlers) and bool(added_default_handlers)

    event_handlers.extend(added_default_handlers)

    if mixing_handlers:
        # check if all handlers have the same set of references to metrics
        known_metrics = set(self.train_metrics + self.val_metrics)
        for handler in event_handlers:
            _check_handler_metric_ref(handler, known_metrics)

    event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
    return event_handlers
def _prepare_default_validation_handlers(self, event_handlers):
    """Complete the user's validation handlers with defaults and sort them.

    A MetricHandler and a LoggingHandler over ``self.val_metrics`` are
    added when the user supplied no handler of the respective type.

    Parameters
    ----------
    event_handlers : EventHandler or list of EventHandler
        User handlers, normalized by ``_check_event_handlers``.

    Returns
    -------
    list
        User and default handlers, sorted by their ``priority`` attribute
        (0 when a handler has none).
    """
    event_handlers = _check_event_handlers(event_handlers)
    added_default_handlers = []

    # add default logging handler and metric handler for validation
    if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
        added_default_handlers.append(MetricHandler(metrics=self.val_metrics))
    if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
        added_default_handlers.append(LoggingHandler(metrics=self.val_metrics))

    # The flag must be a real bool computed BEFORE extending: `a and b`
    # returns one of its operands, so an empty user list would be aliased
    # and turn truthy once the defaults are appended to it.
    mixing_handlers = bool(event_handlers) and bool(added_default_handlers)

    event_handlers.extend(added_default_handlers)

    # check if all handlers refer to well-defined validation metrics
    if mixing_handlers:
        known_metrics = set(self.val_metrics)
        for handler in event_handlers:
            _check_handler_metric_ref(handler, known_metrics)

    event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
    return event_handlers
def _categorize_handlers(self, event_handlers):
    """
    Group handlers into 6 event lists to avoid calling empty methods.

    A handler is placed in every list whose event base class it is an
    instance of, so one handler may appear in several lists; for example,
    only handlers deriving from TrainBegin are called at train begin.
    Within each list the handlers keep the order of ``event_handlers``.

    Returns
    -------
    tuple of list
        ``(train_begin, epoch_begin, batch_begin, batch_end, epoch_end,
        train_end)``.
    """
    # (event base class, handlers of that class), in return order
    buckets = (
        (TrainBegin, []),
        (EpochBegin, []),
        (BatchBegin, []),
        (BatchEnd, []),
        (EpochEnd, []),
        (TrainEnd, []),
    )
    for handler in event_handlers:
        for event_type, members in buckets:
            if isinstance(handler, event_type):
                members.append(handler)
    return tuple(members for _, members in buckets)
Did this page help you?
Yes
No
Thanks for your feedback!