# Source code for mxnet.gluon.contrib.estimator.estimator

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import, unused-variable
"""Gluon Estimator"""

import copy
import logging
import sys
import warnings

from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler, GradientUpdateHandler
from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
from .event_handler import _check_event_handlers
from .utils import _check_metrics, _suggest_metric_for_loss, _check_handler_metric_ref
from ...data import DataLoader
from ...loss import Loss as gluon_loss
from ...trainer import Trainer
from ...utils import split_and_load
from ....device import Device, cpu, gpu, num_gpus
from ...metric import Loss as metric_loss
from .batch_processor import BatchProcessor

# Public API of this module: only the Estimator class is exported.
__all__ = ["Estimator"]


class Estimator(object):
    """Estimator Class for easy model training

    :py:class:`Estimator` can be used to facilitate the training & validation process


    Parameters
    ----------
    net : gluon.Block
        The model used for training.
    loss : gluon.loss.Loss
        Loss (objective) function to calculate during training.
    train_metrics : EvalMetric or list of EvalMetric
        Training metrics for evaluating models on training dataset.
    val_metrics : EvalMetric or list of EvalMetric
        Validation metrics for evaluating models on validation dataset.
    initializer : Initializer
        Initializer to initialize the network.
    trainer : Trainer
        Trainer to apply optimizer on network parameters.
    device : Device or list of Device
        Device(s) to run the training on. A tuple of Device is accepted as well.
    val_net : gluon.Block
        The model used for validation. The validation model does not necessarily belong to
        the same model class as the training model. But the two models typically share the
        same architecture. Therefore the validation model can reuse parameters of the
        training model.

        The code example of construction of val_net sharing the same network parameters as
        the training net is given below:

        >>> net = _get_train_network()
        >>> val_net = _get_test_network()
        >>> val_net.share_parameters(net.collect_params())
        >>> net.initialize(device=device)
        >>> est = Estimator(net, loss, val_net=val_net)

        Proper namespace match is required for weight sharing between two networks. Most
        networks inheriting :py:class:`Block` can share their parameters correctly. An
        exception is Sequential networks, for which the Block scope must be specified for
        correct weight sharing. For the naming in mxnet Gluon API, please refer to the site
        (https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/naming.html)
        for further information.
    val_loss : gluon.loss.Loss
        Loss (objective) function to calculate during validation. If val_loss is None,
        the same loss function as self.loss is used.
    batch_processor: BatchProcessor
        BatchProcessor provides customized fit_batch() and evaluate_batch() methods.
        If None, a default BatchProcessor is used.
    """

    logger = None
    """logging.Logger object associated with the Estimator.

    The logger is used for all logs generated by this estimator and its
    handlers. A new logging.Logger is created during Estimator construction and
    configured to write all logs with level logging.INFO or higher to
    sys.stdout.

    You can modify the logging settings using the standard Python methods. For
    example, to save logs to a file in addition to printing them to stdout
    output, you can attach a logging.FileHandler to the logger.

    >>> est = Estimator(net, loss)
    >>> import logging
    >>> est.logger.addHandler(logging.FileHandler(filename))
    """

    def __init__(self, net,
                 loss,
                 train_metrics=None,
                 val_metrics=None,
                 initializer=None,
                 trainer=None,
                 device=None,
                 val_net=None,
                 val_loss=None,
                 batch_processor=None):
        self.net = net
        self.loss = self._check_loss(loss)
        self._train_metrics = _check_metrics(train_metrics)
        self._val_metrics = _check_metrics(val_metrics)
        self._add_default_training_metrics()
        self._add_validation_metrics()

        # fall back to the training loss / network when no dedicated
        # validation counterpart is supplied
        self.val_loss = self.loss
        if val_loss is not None:
            self.val_loss = self._check_loss(val_loss)

        self.val_net = self.net
        if val_net is not None:
            self.val_net = val_net

        # a fresh, unregistered logger per estimator, writing INFO+ to stdout
        self.logger = logging.Logger(name='Estimator', level=logging.INFO)
        self.logger.addHandler(logging.StreamHandler(sys.stdout))

        self.device = self._check_devices(device)
        self._initialize(initializer)
        self.trainer = self._check_trainer(trainer)
        self.batch_processor = self._check_batch_processor(batch_processor)

    def _check_loss(self, loss):
        """Return ``loss`` unchanged if it is a gluon Loss; raise ValueError otherwise."""
        if not isinstance(loss, gluon_loss):
            raise ValueError("loss must be a Loss, "
                             "refer to gluon.loss.Loss:{}".format(loss))
        return loss

    def _check_context(self, context):
        """This function has been deprecated.
        Please refer to ``Estimator._check_devices``."""
        warnings.warn('Estimator._check_context has been renamed to'
                      ' Estimator._check_devices', DeprecationWarning)
        return self._check_devices(context)

    def _check_devices(self, devices):
        """Validate the requested devices, or pick a default.

        Parameters
        ----------
        devices : Device, list/tuple of Device, or None
            Requested device(s). A falsy value selects a default: gpu(0) when at
            least one GPU is visible, cpu() otherwise.

        Returns
        -------
        list of Device

        Raises
        ------
        ValueError
            If ``devices`` is neither a Device nor a list/tuple of Device.
        AssertionError
            If a requested device is neither a CPU nor one of the visible GPUs.
        """
        # infer available devices
        gpus = num_gpus()
        available_gpus = [gpu(i) for i in range(gpus)]

        if devices:
            # only accept a Device or a list/tuple of Device
            if isinstance(devices, Device):
                devices = [devices]
            elif isinstance(devices, (list, tuple)) and \
                    all(isinstance(dev, Device) for dev in devices):
                devices = list(devices)
            else:
                raise ValueError("devices must be a Device or a list of Device, "
                                 "for example mx.cpu() or [mx.gpu(0), mx.gpu(1)], "
                                 "refer to mxnet.Device:{}".format(devices))
            for device in devices:
                # Raised explicitly rather than via `assert` so the check is not
                # stripped under `python -O`; the exception type is kept for
                # callers that already catch AssertionError.
                if device not in available_gpus and not str(device).startswith('cpu'):
                    raise AssertionError(
                        "{} is not available, please make sure "
                        "your device is in one of: mx.cpu(), {}".format(
                            device, ', '.join(str(avail) for avail in available_gpus)))
        else:
            # provide default device
            if gpus > 0:
                # only use 1 GPU by default
                if gpus > 1:
                    warnings.warn("You have multiple GPUs, gpu(0) will be used by default. "
                                  "To utilize all your GPUs, specify device as a list of gpus, "
                                  "e.g. devices=[mx.gpu(0), mx.gpu(1)]")
                devices = [gpu(0)]
            else:
                devices = [cpu()]
        return devices

    def _check_batch_processor(self, batch_processor):
        """Return a usable batch processor.

        A user-supplied processor must expose callable ``fit_batch`` and
        ``evaluate_batch`` attributes, otherwise ValueError is raised. When
        ``batch_processor`` is None, a default ``BatchProcessor()`` is created.
        """
        if batch_processor is not None:
            model_fit = getattr(batch_processor, 'fit_batch', None)
            model_evaluate = getattr(batch_processor, 'evaluate_batch', None)
            if not callable(model_fit) or not callable(model_evaluate):
                raise ValueError('Customized Batch Processor must contain fit_batch()'
                                 ' and evaluate_batch() methods')
        else:
            batch_processor = BatchProcessor()
        return batch_processor

    def _initialize(self, initializer):
        """Initialize ``self.net`` on ``self.device`` unless it is fully initialized.

        If the net is only partially initialized (or not at all), it is initialized
        with ``initializer`` (or the default initializer when ``initializer`` is
        falsy). If the net is already fully initialized and an initializer was
        passed anyway, a warning is issued and nothing is re-initialized.
        """
        if not self._is_initialized():
            # net is partially or not initialized,
            # initialize with user specified initializer
            # if initializer is None, default initializer will be used
            # do not re-init layers already initialized
            if initializer:
                self.net.initialize(init=initializer, device=self.device)
            else:
                self.net.initialize(device=self.device)
        elif initializer:
            # net is fully initialized, and user passed not None initializer
            # do not force reinitialize, give warning
            warnings.warn("Network already fully initialized, skipping initialization. "
                          "You don't need to pass initializer if you already "
                          "initialized your net. "
                          "You can use net.initialize(init=your_initializer, force_reinit=True) "
                          "to force re-initialize.")

    def _check_trainer(self, trainer):
        """Return ``trainer``, or a default SGD Trainer (lr=0.001) when it is falsy.

        Raises ValueError when a non-Trainer object is supplied.
        """
        if not trainer:
            warnings.warn("No trainer specified, default SGD optimizer "
                          "with learning rate 0.001 is used.")
            trainer = Trainer(self.net.collect_params(),
                              'sgd', {'learning_rate': 0.001})
        elif not isinstance(trainer, Trainer):
            raise ValueError("Trainer must be a Gluon Trainer instance, refer to "
                             "gluon.Trainer:{}".format(trainer))
        return trainer

    def _is_initialized(self):
        """Return True iff every parameter of ``self.net`` reports its devices.

        ``list_device()`` raising RuntimeError is treated as "not initialized".
        A net without parameters counts as initialized.
        """
        param_dict = self.net.collect_params()
        for param in param_dict.values():
            try:
                param.list_device()
            except RuntimeError:
                return False
        return True

    def _get_data_and_label(self, batch, device, batch_axis=0):
        """Split ``batch[0]`` (data) and ``batch[1]`` (label) across ``device``."""
        data = batch[0]
        label = batch[1]
        data = split_and_load(data, device, batch_axis=batch_axis)
        label = split_and_load(label, device, batch_axis=batch_axis)
        return data, label

    def _add_default_training_metrics(self):
        """Fill in default training metrics and prefix all of them with 'training '.

        When no training metric was given, the metric suggested for the loss (if
        any) is used, and a Loss metric named after the loss class is appended.
        """
        if not self._train_metrics:
            suggested_metric = _suggest_metric_for_loss(self.loss)
            if suggested_metric:
                self._train_metrics = [suggested_metric]
            loss_name = type(self.loss).__name__
            self._train_metrics.append(metric_loss(loss_name))

        for metric in self._train_metrics:
            # add training prefix to the metric name
            # it is useful for event handlers to distinguish them from validation metrics
            metric.name = 'training ' + metric.name

    def _add_validation_metrics(self):
        """Fill in validation metrics and give each a 'validation' name prefix.

        When no validation metric was given, deep copies of the training metrics
        are used. The first 'training' in a name is rewritten to 'validation';
        names without it get a 'validation ' prefix.
        """
        if not self._val_metrics:
            self._val_metrics = [copy.deepcopy(metric) for metric in self._train_metrics]

        for metric in self._val_metrics:
            # add validation prefix to the metric name
            # it is useful for event handlers to distinguish them from training metrics
            if 'training' in metric.name:
                # only rewrite the prefix, not every 'training' substring
                # (e.g. 'training pretraining loss' -> 'validation pretraining loss')
                metric.name = metric.name.replace('training', 'validation', 1)
            else:
                metric.name = 'validation ' + metric.name

    @property
    def train_metrics(self):
        """List of training metrics used by this estimator."""
        return self._train_metrics

    @property
    def val_metrics(self):
        """List of validation metrics used by this estimator."""
        return self._val_metrics

    def evaluate(self,
                 val_data,
                 batch_axis=0,
                 event_handlers=None):
        """Evaluate model on validation data.

        This function calls :py:func:`evaluate_batch` on each of the batches from the
        validation data loader. Thus, for custom use cases, it's possible to inherit the
        estimator class and override :py:func:`evaluate_batch`.

        Parameters
        ----------
        val_data : DataLoader
            Validation data loader with data and labels.
        batch_axis : int, default 0
            Batch axis to split the validation data into devices.
        event_handlers : EventHandler or list of EventHandler
            List of :py:class:`EventHandlers` to apply during validation. Besides
            event handlers specified here, a default MetricHandler and a LoggingHandler
            will be added if not specified explicitly.

        Raises
        ------
        ValueError
            If ``val_data`` is not a gluon DataLoader.
        """
        if not isinstance(val_data, DataLoader):
            raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
                             "can transform your DataIter or any NDArray into Gluon DataLoader. "
                             "Refer to gluon.data.DataLoader")

        for metric in self.val_metrics:
            metric.reset()

        event_handlers = self._prepare_default_validation_handlers(event_handlers)

        _, epoch_begin, batch_begin, batch_end, \
            epoch_end, _ = self._categorize_handlers(event_handlers)

        # pass a reference to all event handlers
        estimator_ref = self

        # a validation pass is modelled as a single epoch
        for handler in epoch_begin:
            handler.epoch_begin(estimator_ref)

        for batch in val_data:
            for handler in batch_begin:
                handler.batch_begin(estimator_ref, batch=batch)

            _, label, pred, loss = \
                self.batch_processor.evaluate_batch(estimator_ref, batch, batch_axis)

            for handler in batch_end:
                handler.batch_end(estimator_ref, batch=batch, pred=pred, label=label, loss=loss)

        for handler in epoch_end:
            handler.epoch_end(estimator_ref)

    def fit(self, train_data,
            val_data=None,
            epochs=None,
            event_handlers=None,
            batches=None,
            batch_axis=0):
        """Trains the model with a given :py:class:`DataLoader` for a specified
        number of epochs or batches. The batch size is inferred from the
        data loader's batch_size.

        This function calls :py:func:`fit_batch` on each of the batches from the
        training data loader. Thus, for custom use cases, it's possible to inherit the
        estimator class and override :py:func:`fit_batch`.

        Parameters
        ----------
        train_data : DataLoader
            Training data loader with data and labels.
        val_data : DataLoader, default None
            Validation data loader with data and labels.
        epochs : int, default None
            Number of epochs to iterate on the training data.
            You can only specify one and only one type of iteration(epochs or batches).
        event_handlers : EventHandler or list of EventHandler
            List of :py:class:`EventHandlers` to apply during training.
            Besides the event handlers specified here, a StoppingHandler,
            LoggingHandler and MetricHandler will be added by default if not
            yet specified manually. If validation data is provided, a
            ValidationHandler is also added if not already specified.
        batches : int, default None
            Number of batches to iterate on the training data.
            You can only specify one and only one type of iteration(epochs or batches).
        batch_axis : int, default 0
            Batch axis to split the training data into devices.

        Raises
        ------
        ValueError
            If ``train_data`` is not a gluon DataLoader, or if not exactly one of
            ``epochs`` / ``batches`` is given.
        """
        if not isinstance(train_data, DataLoader):
            raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
                             "can transform your DataIter or any NDArray into Gluon DataLoader. "
                             "Refer to gluon.data.dataloader")

        # must specify one and only one of epochs or batches
        if (not epochs) == (not batches):
            raise ValueError(
                "Fit only support exactly one type of iteration, "
                "train by number of epochs or number of batches. "
                "Please specify one and only one of: epochs or batches.")

        self.max_epoch = epochs
        self.max_batch = batches
        self.batch_axis = batch_axis

        # provide default handlers
        event_handlers = self._prepare_default_handlers(val_data, event_handlers)

        train_begin, epoch_begin, batch_begin, \
            batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)

        # pass a reference to all event handlers
        estimator_ref = self
        # training begin
        for handler in train_begin:
            handler.train_begin(estimator_ref)

        # loop until an epoch_end handler (e.g. the default StoppingHandler)
        # signals to stop
        while True:
            # epoch begin
            for handler in epoch_begin:
                handler.epoch_begin(estimator_ref)

            for batch in train_data:
                # batch begin
                for handler in batch_begin:
                    handler.batch_begin(estimator_ref, batch=batch)

                _, label, pred, loss = self.batch_processor.fit_batch(estimator_ref,
                                                                      batch, batch_axis)

                # batch end
                batch_end_result = []
                for handler in batch_end:
                    batch_end_result.append(handler.batch_end(estimator_ref, batch=batch,
                                                              pred=pred, label=label, loss=loss))
                # if any handler signaled to stop
                if any(batch_end_result):
                    break

            # epoch end
            epoch_end_result = []
            for handler in epoch_end:
                epoch_end_result.append(handler.epoch_end(estimator_ref))
            # if any handler signaled to stop
            if any(epoch_end_result):
                break

        # train end
        for handler in train_end:
            handler.train_end(estimator_ref)

    def _prepare_default_handlers(self, val_data, event_handlers):
        """Add default training handlers not supplied by the user, then sort by priority.

        A StoppingHandler is always added. GradientUpdateHandler, MetricHandler,
        ValidationHandler (only when ``val_data`` is given) and LoggingHandler are
        added if no handler of that type is present. When user and default handlers
        are mixed, every handler's metric references are checked against the known
        training and validation metrics.
        """
        event_handlers = _check_event_handlers(event_handlers)
        added_default_handlers = []

        # no need to add to default handler check as StoppingHandler does not use metrics
        added_default_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))

        if not any(isinstance(handler, GradientUpdateHandler) for handler in event_handlers):
            added_default_handlers.append(GradientUpdateHandler())

        if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
            added_default_handlers.append(MetricHandler(metrics=self.train_metrics))

        if not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
            # no validation handler
            if val_data:
                # add default validation handler if validation data found
                added_default_handlers.append(ValidationHandler(val_data=val_data,
                                                                eval_fn=self.evaluate))

        if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
            added_default_handlers.append(LoggingHandler(metrics=self.train_metrics))

        # if there is a mix of user defined event handlers and default event handlers
        # they should have the same set of metrics
        mixing_handlers = event_handlers and added_default_handlers

        event_handlers.extend(added_default_handlers)

        if mixing_handlers:
            # check if all handlers have the same set of references to metrics
            known_metrics = set(self.train_metrics + self.val_metrics)
            for handler in event_handlers:
                _check_handler_metric_ref(handler, known_metrics)

        event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
        return event_handlers

    def _prepare_default_validation_handlers(self, event_handlers):
        """Add default MetricHandler/LoggingHandler for validation, then sort by priority.

        When user and default handlers are mixed, every handler's metric
        references are checked against the validation metrics.
        """
        event_handlers = _check_event_handlers(event_handlers)
        added_default_handlers = []

        # add default logging handler and metric handler for validation
        if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
            added_default_handlers.append(MetricHandler(metrics=self.val_metrics))

        if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
            added_default_handlers.append(LoggingHandler(metrics=self.val_metrics))

        mixing_handlers = event_handlers and added_default_handlers
        event_handlers.extend(added_default_handlers)

        # check if all handlers refer to well-defined validation metrics
        if mixing_handlers:
            known_metrics = set(self.val_metrics)
            for handler in event_handlers:
                _check_handler_metric_ref(handler, known_metrics)

        event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
        return event_handlers

    def _categorize_handlers(self, event_handlers):
        """
        categorize handlers into 6 event lists to avoid calling empty methods
        for example, only event handlers with train_begin method
        implemented will be called at train begin
        """

        train_begin = []
        epoch_begin = []
        batch_begin = []
        batch_end = []
        epoch_end = []
        train_end = []
        # a handler may implement several event interfaces, so each check is
        # independent (no elif)
        for handler in event_handlers:
            if isinstance(handler, TrainBegin):
                train_begin.append(handler)
            if isinstance(handler, EpochBegin):
                epoch_begin.append(handler)
            if isinstance(handler, BatchBegin):
                batch_begin.append(handler)
            if isinstance(handler, BatchEnd):
                batch_end.append(handler)
            if isinstance(handler, EpochEnd):
                epoch_end.append(handler)
            if isinstance(handler, TrainEnd):
                train_end.append(handler)
        return train_begin, epoch_begin, batch_begin, batch_end, epoch_end, train_end