Source code for mxnet.gluon.contrib.estimator.event_handler

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import, unused-argument, too-many-ancestors
"""Gluon EventHandlers for Estimators"""

import os
import time
import warnings

import numpy as np

from ...metric import CompositeEvalMetric, EvalMetric
from ...metric import Loss as metric_loss
from .utils import _check_metrics

__all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd',
           'StoppingHandler', 'MetricHandler', 'ValidationHandler',
           'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler', 'GradientUpdateHandler']


class EventHandler(object):
    pass


def _check_event_handlers(handlers):
    if isinstance(handlers, EventHandler):
        handlers = [handlers]
    else:
        handlers = handlers or []
        if not all([isinstance(handler, EventHandler) for handler in handlers]):
            raise ValueError("handlers must be an EventHandler or a list of EventHandler, "
                             "got: {}".format(handlers))
    return handlers


class TrainBegin(EventHandler):
    def train_begin(self, estimator, *args, **kwargs):
        pass


class TrainEnd(EventHandler):
    def train_end(self, estimator, *args, **kwargs):
        pass


class EpochBegin(EventHandler):
    def epoch_begin(self, estimator, *args, **kwargs):
        pass


class EpochEnd(EventHandler):
    def epoch_end(self, estimator, *args, **kwargs):
        return False


class BatchBegin(EventHandler):
    def batch_begin(self, estimator, *args, **kwargs):
        pass


class BatchEnd(EventHandler):
    def batch_end(self, estimator, *args, **kwargs):
        return False
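

# A minimal sketch of a custom handler built from the mixin classes above. During
# Estimator.fit, the estimator invokes train_begin/epoch_begin/batch_begin/batch_end/
# epoch_end/train_end on every registered handler, and batch_end/epoch_end may return
# True to request that training stop. The BatchCounter handler below is illustrative
# only and not part of the library API.
def _example_custom_handler():
    class BatchCounter(TrainBegin, BatchEnd):
        def __init__(self):
            self.num_batches = 0

        def train_begin(self, estimator, *args, **kwargs):
            self.num_batches = 0

        def batch_end(self, estimator, *args, **kwargs):
            self.num_batches += 1
            return False  # never ask the estimator to stop early

    counter = BatchCounter()
    # _check_event_handlers accepts a single handler or a list and always returns a list
    handlers = _check_event_handlers(counter)
    assert handlers == [counter]
    return handlers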


class StoppingHandler(TrainBegin, BatchEnd, EpochEnd):
    """Stop conditions to stop training

    Stop training if the maximum number of batches or epochs is reached.

    Parameters
    ----------
    max_epoch : int, default None
        Number of maximum epochs to train.
    max_batch : int, default None
        Number of maximum batches to train.
    """

    def __init__(self, max_epoch=None, max_batch=None):
        self.max_epoch = max_epoch
        self.max_batch = max_batch
        self.current_batch = 0
        self.current_epoch = 0
        self.stop_training = False

    def train_begin(self, estimator, *args, **kwargs):
        self.max_epoch = estimator.max_epoch
        self.max_batch = estimator.max_batch
        self.current_batch = 0
        self.current_epoch = 0

    def batch_end(self, estimator, *args, **kwargs):
        self.current_batch += 1
        if self.current_batch == self.max_batch:
            self.stop_training = True
        return self.stop_training

    def epoch_end(self, estimator, *args, **kwargs):
        self.current_epoch += 1
        if self.current_epoch == self.max_epoch:
            self.stop_training = True
        return self.stop_training
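

# An illustrative sketch of how StoppingHandler flips its stop flag. The estimator is a
# types.SimpleNamespace stand-in because train_begin only reads max_epoch/max_batch
# (which Estimator.fit derives from its epochs/batches arguments); it is not the real
# Estimator.
def _example_stopping_handler():
    import types
    est = types.SimpleNamespace(max_epoch=None, max_batch=3)
    handler = StoppingHandler()
    handler.train_begin(est)
    stops = [handler.batch_end(est) for _ in range(3)]
    assert stops == [False, False, True]  # stop is requested on the 3rd batch
    return handler.stop_training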


class MetricHandler(EpochBegin, BatchEnd):
    """Metric handler that updates metric values at batch end

    :py:class:`MetricHandler` takes model predictions and true labels
    and updates the metrics. It also updates the metric wrapper for loss
    with the loss values. Validation loss and metrics are handled by
    :py:class:`ValidationHandler`.

    Parameters
    ----------
    metrics : List of EvalMetrics
        Metrics to be updated at batch end.
    priority : scalar
        Priority level of the MetricHandler. Priority levels are sorted in
        ascending order: the lower the number, the higher the handler's priority.
    """

    def __init__(self, metrics, priority=-1000):
        self.metrics = _check_metrics(metrics)
        # order to be called among all callbacks
        # metrics need to be calculated before other callbacks can access them
        self.priority = priority

    def epoch_begin(self, estimator, *args, **kwargs):
        for metric in self.metrics:
            metric.reset()

    def batch_end(self, estimator, *args, **kwargs):
        pred = kwargs['pred']
        label = kwargs['label']
        loss = kwargs['loss']
        for metric in self.metrics:
            if isinstance(metric, metric_loss):
                # metric wrapper for loss values
                metric.update(0, loss)
            else:
                metric.update(label, pred)
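

# A wiring sketch for MetricHandler. Inside Estimator.fit the handler receives pred/label/
# loss through the batch_end keyword arguments, and the metric objects normally come from
# estimator.train_metrics. The Accuracy import below assumes the same metric module this
# file already imports from (``...metric``); treat the whole function as illustrative.
def _example_metric_handler():
    from ...metric import Accuracy
    train_metrics = [Accuracy(), metric_loss('training loss')]
    handler = MetricHandler(metrics=train_metrics)
    # Metrics are reset at every epoch_begin and updated at every batch_end, e.g.
    #   handler.batch_end(estimator, pred=pred, label=label, loss=loss)
    # where pred/label/loss are the arrays produced by the training loop.
    return handler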


class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
    """Validation handler that evaluates the model on a validation dataset

    :py:class:`ValidationHandler` takes a validation dataset, an evaluation
    function, metrics to be evaluated, and how often to run the validation.
    You can provide a custom evaluation function or use the one provided by
    :py:class:`Estimator`.

    Parameters
    ----------
    val_data : DataLoader
        Validation data set to run evaluation.
    eval_fn : function
        A function that defines how to run evaluation and
        calculate loss and metrics.
    epoch_period : int, default 1
        How often to run validation at epoch end, by default
        :py:class:`ValidationHandler` validates every epoch.
    batch_period : int, default None
        How often to run validation at batch end, by default
        :py:class:`ValidationHandler` does not validate at batch end.
    priority : scalar, default -1000
        Priority level of the ValidationHandler. Priority levels are sorted in
        ascending order: the lower the number, the higher the handler's priority.
    event_handlers : EventHandler or list of EventHandlers
        List of :py:class:`EventHandler` to apply during validation. This argument
        is used by the self.eval_fn function in order to process customized event handlers.
    """

    def __init__(self,
                 val_data,
                 eval_fn,
                 epoch_period=1,
                 batch_period=None,
                 priority=-1000,
                 event_handlers=None):
        self.val_data = val_data
        self.eval_fn = eval_fn
        self.epoch_period = epoch_period
        self.batch_period = batch_period
        self.current_batch = 0
        self.current_epoch = 0
        # order to be called among all callbacks
        # validation metrics need to be calculated before other callbacks can access them
        self.priority = priority
        self.event_handlers = event_handlers

    def train_begin(self, estimator, *args, **kwargs):
        # reset epoch and batch counter
        self.current_batch = 0
        self.current_epoch = 0

    def batch_end(self, estimator, *args, **kwargs):
        self.current_batch += 1
        if self.batch_period and self.current_batch % self.batch_period == 0:
            self.eval_fn(val_data=self.val_data,
                         batch_axis=estimator.batch_axis,
                         event_handlers=self.event_handlers)

    def epoch_end(self, estimator, *args, **kwargs):
        self.current_epoch += 1
        if self.epoch_period and self.current_epoch % self.epoch_period == 0:
            self.eval_fn(val_data=self.val_data,
                         batch_axis=estimator.batch_axis,
                         event_handlers=self.event_handlers)
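

# A sketch showing when ValidationHandler triggers its eval_fn. The estimator stand-in only
# needs a batch_axis attribute here; in real use val_data is a DataLoader and eval_fn is
# usually the Estimator's own evaluation routine, so the names below are assumptions.
def _example_validation_handler():
    import types
    calls = []

    def fake_eval_fn(val_data, batch_axis, event_handlers=None):
        calls.append(batch_axis)

    est = types.SimpleNamespace(batch_axis=0)
    handler = ValidationHandler(val_data=None, eval_fn=fake_eval_fn, batch_period=2)
    handler.train_begin(est)
    for _ in range(4):
        handler.batch_end(est)
    assert len(calls) == 2  # batch_period=2 -> validation after batches 2 and 4
    handler.epoch_end(est)  # epoch_period defaults to 1 -> validation at every epoch end
    assert len(calls) == 3
    return handler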


class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
    """Basic logging handler that applies to every Gluon estimator by default.

    :py:class:`LoggingHandler` logs hyper-parameters, training statistics,
    and other useful information during training.

    Parameters
    ----------
    log_interval : int or str, default 'epoch'
        Logging interval during training.
        log_interval='epoch': display metrics every epoch
        log_interval=integer k: display metrics every interval of k batches
    metrics : list of EvalMetrics
        Metrics to be logged at batch end, epoch end, and train end.
    priority : scalar, default np.Inf
        Priority level of the LoggingHandler. Priority levels are sorted in
        ascending order: the lower the number, the higher the handler's priority.
    """

    def __init__(self, log_interval='epoch',
                 metrics=None,
                 priority=np.Inf):
        super(LoggingHandler, self).__init__()
        if not isinstance(log_interval, int) and log_interval != 'epoch':
            raise ValueError("log_interval must be either an integer or string 'epoch'")
        self.metrics = _check_metrics(metrics)
        self.batch_index = 0
        self.current_epoch = 0
        self.processed_samples = 0
        # the logging handler needs to be called last so that all states are updated
        # it will also shut down logging at train end
        self.priority = priority
        self.log_interval = log_interval
        self.log_interval_time = 0

    def train_begin(self, estimator, *args, **kwargs):
        self.train_start = time.time()
        trainer = estimator.trainer
        optimizer = trainer.optimizer.__class__.__name__
        lr = trainer.learning_rate
        estimator.logger.info("Training begin: using optimizer %s "
                              "with current learning rate %.4f ",
                              optimizer, lr)
        if estimator.max_epoch:
            estimator.logger.info("Train for %d epochs.", estimator.max_epoch)
        else:
            estimator.logger.info("Train for %d batches.", estimator.max_batch)
        # reset all counters
        self.current_epoch = 0
        self.batch_index = 0
        self.processed_samples = 0
        self.log_interval_time = 0

    def train_end(self, estimator, *args, **kwargs):
        train_time = time.time() - self.train_start
        msg = f'Train finished using total {train_time}s with {self.current_epoch} epochs. '
        # log every result in train stats including train/validation loss & metrics
        for metric in self.metrics:
            name, value = metric.get()
            msg += f'{name}: {value:.4f}, '
        estimator.logger.info(msg.rstrip(', '))

    def batch_begin(self, estimator, *args, **kwargs):
        if isinstance(self.log_interval, int):
            self.batch_start = time.time()

    def batch_end(self, estimator, *args, **kwargs):
        if isinstance(self.log_interval, int):
            batch_time = time.time() - self.batch_start
            msg = f'[Epoch {self.current_epoch}][Batch {self.batch_index}]'
            self.processed_samples += kwargs['batch'][0].shape[0]
            msg += f'[Samples {self.processed_samples}] '
            self.log_interval_time += batch_time
            if self.batch_index % self.log_interval == 0:
                msg += f'time/interval: {self.log_interval_time:.3f}s '
                self.log_interval_time = 0
                for metric in self.metrics:
                    # only log current training loss & metric after each interval
                    name, value = metric.get()
                    msg += f'{name}: {value:.4f}, '
                estimator.logger.info(msg.rstrip(', '))
        self.batch_index += 1

    def epoch_begin(self, estimator, *args, **kwargs):
        if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
            is_training = False
            # use the name hack defined in __init__() of estimator class
            for metric in self.metrics:
                if 'training' in metric.name:
                    is_training = True
            self.epoch_start = time.time()
            if is_training:
                estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
                                      self.current_epoch, estimator.trainer.learning_rate)
            else:
                estimator.logger.info("Validation Begin")

    def epoch_end(self, estimator, *args, **kwargs):
        if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
            epoch_time = time.time() - self.epoch_start
            msg = f'[Epoch {self.current_epoch}] Finished in {epoch_time:.3f}s, '
            for monitor in self.metrics:
                name, value = monitor.get()
                msg += f'{name}: {value:.4f}, '
            estimator.logger.info(msg.rstrip(', '))
        self.current_epoch += 1
        self.batch_index = 0
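

# A sketch of the two LoggingHandler cadences. Logging once per epoch is the default;
# passing an integer logs every k batches together with interval timing and processed
# sample counts. Estimator.fit adds a LoggingHandler for you when no handlers are given,
# so constructing one manually is only needed to customize the cadence or the metrics.
def _example_logging_handler(train_metrics=None):
    epoch_logger = LoggingHandler()                 # log once per epoch
    batch_logger = LoggingHandler(log_interval=50,  # log every 50 batches
                                  metrics=train_metrics)
    return [epoch_logger, batch_logger]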


class CheckpointHandler(TrainBegin, BatchEnd, EpochEnd):
    """Save the model after a user-defined period

    :py:class:`CheckpointHandler` saves the network architecture after the first batch
    if the model can be fully hybridized, and saves model parameters and trainer states
    after a user-defined period; by default it saves every epoch.

    Parameters
    ----------
    model_dir : str
        File directory to save all the model related files including model architecture,
        model parameters, and trainer states.
    model_prefix : str, default 'model'
        Prefix to add for all checkpoint file names.
    monitor : EvalMetric, default None
        The metric to monitor and determine if the model has improved.
    verbose : int, default 0
        Verbosity mode, 1 means inform user every time a checkpoint is saved.
    save_best : bool, default False
        If True, monitor must not be None, :py:class:`CheckpointHandler` will save the
        model parameters and trainer states with the best monitored value.
    mode : str, default 'auto'
        One of {auto, min, max}, if `save_best=True`, the comparison to make
        and determine if the monitored value has improved. In 'auto' mode,
        :py:class:`CheckpointHandler` will try to use min or max based on
        the monitored metric name.
    epoch_period : int, default 1
        Epoch intervals between saving the network. By default, checkpoints are
        saved every epoch.
    batch_period : int, default None
        Batch intervals between saving the network.
        By default, checkpoints are not saved based on the number of batches.
    max_checkpoints : int, default 5
        Maximum number of checkpoint files to keep in the model_dir, older checkpoints
        will be removed. The best checkpoint file is not counted.
    resume_from_checkpoint : bool, default False
        Whether to resume training from a checkpoint in model_dir. If True and checkpoints
        are found, :py:class:`CheckpointHandler` will load net parameters and trainer states,
        and train the remaining epochs and batches.
    """

    def __init__(self,
                 model_dir,
                 model_prefix='model',
                 monitor=None,
                 verbose=0,
                 save_best=False,
                 mode='auto',
                 epoch_period=1,
                 batch_period=None,
                 max_checkpoints=5,
                 resume_from_checkpoint=False):
        self.monitor = monitor
        self.verbose = verbose
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        self.model_dir = model_dir
        self.model_prefix = model_prefix
        self.save_best = save_best
        if self.save_best and not isinstance(self.monitor, EvalMetric):
            raise ValueError("To save best model only, please provide one of the metric objects "
                             "from estimator.train_metrics and estimator.val_metrics as monitor.")
        self.epoch_period = epoch_period
        self.batch_period = batch_period
        self.current_batch = 0
        self.current_epoch = 0
        self.max_checkpoints = max_checkpoints
        self.resume_from_checkpoint = resume_from_checkpoint
        self.saved_checkpoints = []
        if self.save_best:
            if mode not in ['auto', 'min', 'max']:
                warnings.warn(f'ModelCheckpoint mode {mode} is unknown, '
                              'fallback to auto mode. CheckpointHandler will use '
                              'max mode for f1 and accuracy metric comparison and '
                              'use min mode otherwise',
                              RuntimeWarning)
                mode = 'auto'

            if mode == 'min':
                self.monitor_op = np.less
                self.best = np.Inf
            elif mode == 'max':
                self.monitor_op = np.greater
                self.best = -np.Inf
            else:
                # use greater for accuracy and f1 and less otherwise
                monitor_name = self.monitor.get()[0]
                if 'acc' in monitor_name.lower() or 'f1' in monitor_name.lower():
                    warnings.warn("`greater` operator will be used to determine if {} has improved. "
                                  "Please specify `mode='min'` to use the `less` operator. "
                                  "Specify `mode='max'` to disable this warning."
                                  .format(monitor_name))
                    self.monitor_op = np.greater
                else:
                    warnings.warn("`less` operator will be used to determine if {} has improved. "
                                  "Please specify `mode='max'` to use the `greater` operator. "
                                  "Specify `mode='min'` to disable this warning."
                                  .format(monitor_name))
                    self.monitor_op = np.less

    def train_begin(self, estimator, *args, **kwargs):
        # reset all counters
        self.current_epoch = 0
        self.current_batch = 0
        if self.save_best:
            self.best = np.Inf if self.monitor_op == np.less else -np.Inf  # pylint: disable=comparison-with-callable
        if self.resume_from_checkpoint:
            error_msg = "To use resume from checkpoint, you must only specify " \
                        "the same type of period you used for training. " \
                        "For example, if you are training based on number of epochs, " \
                        "you must save only based on epochs, and set batch_period to None."
            if estimator.max_batch:
                assert self.batch_period, error_msg
                assert not self.epoch_period, error_msg
            if estimator.max_epoch:
                assert self.epoch_period, error_msg
                assert not self.batch_period, error_msg

            self._resume_from_checkpoint(estimator)

    def batch_end(self, estimator, *args, **kwargs):
        # only save symbol once after first batch
        if self.current_batch == 0:
            self._save_symbol(estimator)
        if self.batch_period and (self.current_batch + 1) % self.batch_period == 0:
            self._save_checkpoint(estimator)
        self.current_batch += 1

    def epoch_end(self, estimator, *args, **kwargs):
        if self.epoch_period and (self.current_epoch + 1) % self.epoch_period == 0:
            self._save_checkpoint(estimator)
        self.current_epoch += 1

    def _save_checkpoint(self, estimator):
        # if resumed from checkpoint, increment checkpoint number
        if self.resume_from_checkpoint:
            save_epoch_number = self.current_epoch + self.trained_epoch + 1
            if estimator.max_epoch:
                # checkpoint saved at epoch end, batch number already incremented
                save_batch_number = self.current_batch + self.trained_batch
            else:
                save_batch_number = self.current_batch + self.trained_batch + 1
        else:
            save_epoch_number = self.current_epoch
            save_batch_number = self.current_batch
        prefix = f"{self.model_prefix}-epoch{save_epoch_number}batch{save_batch_number}"
        self._save_params_and_trainer(estimator, prefix)
        if self.verbose > 0:
            estimator.logger.info(f'[Epoch {self.current_epoch}] CheckpointHandler: '
                                  f'trained total {self.current_batch + 1} batches, '
                                  f'saving model at {self.model_dir} with prefix: {prefix}')

        if self.save_best:
            monitor_name, monitor_value = self.monitor.get()
            # check if monitor exists in train stats
            if np.isnan(monitor_value):
                warnings.warn(RuntimeWarning(
                    'Skipping save best because %s is not updated, make sure you pass one of the '
                    'metric objects from estimator.train_metrics and estimator.val_metrics as monitor'
                    % monitor_name))
            else:
                if self.monitor_op(monitor_value, self.best):
                    prefix = self.model_prefix + '-best'
                    self._save_params_and_trainer(estimator, prefix)
                    if self.verbose > 0:
                        estimator.logger.info('[Epoch %d] CheckpointHandler: '
                                              '%s improved from %0.5f to %0.5f, '
                                              'updating best model at %s with prefix: %s',
                                              self.current_epoch, monitor_name,
                                              self.best, monitor_value, self.model_dir, prefix)
                    self.best = monitor_value
                else:
                    if self.verbose > 0:
                        estimator.logger.info('[Epoch %d] CheckpointHandler: '
                                              '%s did not improve from %0.5f, '
                                              'skipping updating best model',
                                              self.current_epoch, monitor_name,
                                              self.best)

    def _save_symbol(self, estimator):
        symbol_file = os.path.join(self.model_dir, self.model_prefix + '-symbol.json')
        if hasattr(estimator.net, '_cached_graph') and estimator.net._cached_graph:
            sym = estimator.net._cached_graph[1]
            sym.save(symbol_file)
        else:
            estimator.logger.info(
                "Model architecture(symbol file) is not saved, please use HybridBlock "
                "to construct your model, and call net.hybridize() before passing to "
                "Estimator in order to save model architecture as %s.", symbol_file)

    def _save_params_and_trainer(self, estimator, file_prefix):
        param_file = os.path.join(self.model_dir, file_prefix + '.params')
        trainer_file = os.path.join(self.model_dir, file_prefix + '.states')
        estimator.net.save_parameters(param_file)
        estimator.trainer.save_states(trainer_file)

        # only count checkpoints with epoch or batch number in file name
        if 'best' not in file_prefix:
            self.saved_checkpoints.append(file_prefix)
        # remove old checkpoint when max number of checkpoints reached
        if len(self.saved_checkpoints) > self.max_checkpoints:
            prefix = self.saved_checkpoints.pop(0)
            for fname in os.listdir(self.model_dir):
                if fname.startswith(prefix):
                    os.remove(os.path.join(self.model_dir, fname))

    def _resume_from_checkpoint(self, estimator):
        prefix = self.model_prefix + '-epoch'
        self.trained_epoch = self._find_max_iteration(
            dir=self.model_dir,
            prefix=prefix,
            start='epoch',
            end='batch',
            saved_checkpoints=self.saved_checkpoints)
        prefix += str(self.trained_epoch)
        self.trained_batch = self._find_max_iteration(
            dir=self.model_dir,
            prefix=prefix,
            start='batch',
            end='.params')

        if self.trained_epoch == -1:
            msg = "CheckpointHandler: No checkpoint found, training from scratch for "
            if estimator.max_batch:
                msg += f"{estimator.max_batch} batches"
            else:
                msg += f"{estimator.max_epoch} epochs"
            estimator.logger.info(msg)
        else:
            msg = f"CheckpointHandler: Checkpoint resumed from epoch {self.trained_epoch} " \
                  f"batch {self.trained_batch}, continue to train for "
            # change maximum number of epoch or batch to train if resumed from epoch checkpoint
            if estimator.max_epoch:
                if self.trained_epoch >= estimator.max_epoch - 1:
                    raise ValueError(f"Found checkpoint with maximum number of epochs "
                                     f"{estimator.max_epoch} reached, please specify "
                                     "resume_from_checkpoint=False (default value) "
                                     "if you want to train from scratch.")
                estimator.max_epoch = estimator.max_epoch - self.trained_epoch - 1
                msg += f"{estimator.max_epoch} epochs "
            if estimator.max_batch:
                if self.trained_batch >= estimator.max_batch - 1:
                    raise ValueError(f"Found checkpoint with maximum number of batches "
                                     f"{self.trained_batch} reached, please specify "
                                     "resume_from_checkpoint=False (default value) "
                                     "if you want to train from scratch.")
                estimator.max_batch = estimator.max_batch - self.trained_batch - 1
                msg += f"{estimator.max_batch} batches "
            # load checkpoint
            param_file = "{}-epoch{}batch{}.params".format(self.model_prefix,
                                                           self.trained_epoch,
                                                           self.trained_batch)
            param_file = os.path.join(self.model_dir, param_file)
            trainer_file = "{}-epoch{}batch{}.states".format(self.model_prefix,
                                                             self.trained_epoch,
                                                             self.trained_batch)
            trainer_file = os.path.join(self.model_dir, trainer_file)
            assert os.path.exists(param_file), f"Failed to load checkpoint, {param_file} does not exist"
            assert os.path.exists(trainer_file), f"Failed to load checkpoint, {trainer_file} does not exist"
            estimator.net.load_parameters(param_file, ctx=estimator.device)
            estimator.trainer.load_states(trainer_file)
            estimator.logger.warning(msg)

    def _find_max_iteration(self, dir, prefix, start, end, saved_checkpoints=None):
        error_msg = "Error parsing checkpoint file, please check your " \
                    "checkpoints have the format: " \
                    "{model_name}-epoch{epoch_number}batch{batch_number}.params, " \
                    "there should also be a .states file for each .params file "
        max_iter = -1
        for fname in os.listdir(dir):
            if fname.startswith(prefix) and '.params' in fname:
                if saved_checkpoints:
                    # save prefix of existing checkpoints
                    saved_checkpoints.append(fname[:fname.find('.params')])
                try:
                    # find trained number of epoch
                    iteration = int(fname[fname.find(start) + len(start): fname.find(end)])
                    if iteration > max_iter:
                        max_iter = iteration
                except ValueError:
                    raise ValueError(error_msg)
        return max_iter
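

# A sketch of a "keep the best model" checkpoint setup. The monitor must be one of the
# metric objects owned by the estimator (estimator.train_metrics / estimator.val_metrics);
# the val_loss_metric argument and the './checkpoints'/'resnet' names below are assumptions
# for illustration. Files are written as resnet-epoch{e}batch{b}.params/.states, plus
# resnet-best.params/.states for the best snapshot when save_best=True.
def _example_checkpoint_handler(val_loss_metric):
    # note: the constructor creates model_dir if it does not already exist
    handler = CheckpointHandler(model_dir='./checkpoints',
                                model_prefix='resnet',
                                monitor=val_loss_metric,  # e.g. an estimator.val_metrics entry
                                mode='min',               # lower validation loss is better
                                save_best=True,
                                max_checkpoints=3,
                                verbose=1)
    return handler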


class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd):
    """Early stop training if the monitored value is not improving

    Parameters
    ----------
    monitor : EvalMetric
        The metric to monitor, and stop training if this metric does not improve.
    min_delta : float, default 0
        Minimal change in the monitored value to be considered an improvement.
    patience : int, default 0
        Number of epochs to wait for improvement before terminating training.
    mode : str, default 'auto'
        One of {auto, min, max}, the comparison to make and determine if the
        monitored value has improved. In 'auto' mode, :py:class:`EarlyStoppingHandler`
        will try to use min or max based on the monitored metric name.
    baseline : float
        Baseline value to compare the monitored value with.
    """

    def __init__(self,
                 monitor,
                 min_delta=0,
                 patience=0,
                 mode='auto',
                 baseline=None):
        super(EarlyStoppingHandler, self).__init__()

        if not isinstance(monitor, EvalMetric):
            raise ValueError(
                "Please provide one of the metric objects from estimator.train_metrics and "
                "estimator.val_metrics as monitor.")
        if isinstance(monitor, CompositeEvalMetric):
            raise ValueError("CompositeEvalMetric is not supported for EarlyStoppingHandler, "
                             "please specify a simple metric instead.")

        self.monitor = monitor
        self.baseline = baseline
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_epoch = 0
        self.current_epoch = 0
        self.stop_training = False

        if mode not in ['auto', 'min', 'max']:
            warnings.warn(f'EarlyStopping mode {mode} is unknown, '
                          'fallback to auto mode. EarlyStoppingHandler will use '
                          'max mode for f1 and accuracy metric comparison and '
                          'use min mode otherwise',
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            # use greater for accuracy and f1 and less otherwise
            monitor_name = self.monitor.get()[0]
            if 'acc' in monitor_name.lower() or 'f1' in monitor_name.lower():
                warnings.warn("`greater` operator will be used to determine if {} has improved. "
                              "Please specify `mode='min'` to use the `less` operator. "
                              "Specify `mode='max'` to disable this warning."
                              .format(monitor_name))
                self.monitor_op = np.greater
            else:
                warnings.warn("`less` operator will be used to determine if {} has improved. "
                              "Please specify `mode='max'` to use the `greater` operator. "
                              "Specify `mode='min'` to disable this warning."
                              .format(monitor_name))
                self.monitor_op = np.less

        if self.monitor_op == np.greater:  # pylint: disable=comparison-with-callable
            self.min_delta *= 1
        else:
            self.min_delta *= -1

    def train_begin(self, estimator, *args, **kwargs):
        self.wait = 0
        self.stopped_epoch = 0
        self.current_epoch = 0
        self.stop_training = False
        if self.baseline is not None:
            self.best = self.baseline
        else:
            self.best = np.Inf if self.monitor_op == np.less else -np.Inf  # pylint: disable=comparison-with-callable

    def epoch_end(self, estimator, *args, **kwargs):
        monitor_name, monitor_value = self.monitor.get()
        if np.isnan(monitor_value):
            warnings.warn(RuntimeWarning(
                '%s is not updated, make sure you pass one of the metric objects from '
                'estimator.train_metrics and estimator.val_metrics as monitor.'
                % monitor_name))
        else:
            if self.monitor_op(monitor_value - self.min_delta, self.best):
                self.best = monitor_value
                self.wait = 0
            else:
                self.wait += 1
                if self.wait >= self.patience:
                    self.stopped_epoch = self.current_epoch
                    self.stop_training = True
        self.current_epoch += 1
        return self.stop_training

    def train_end(self, estimator, *args, **kwargs):
        if self.stopped_epoch > 0:
            estimator.logger.info('[Epoch %d] EarlyStoppingHandler: '
                                  'early stopping due to %s not improving',
                                  self.stopped_epoch, self.monitor.get()[0])
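

# A sketch of early stopping on a validation metric. The monitor has the same requirement
# as for CheckpointHandler (a non-composite metric object owned by the estimator); the
# val_accuracy_metric argument is an assumption for illustration. With patience=2, training
# stops after two consecutive epochs without at least min_delta improvement.
def _example_early_stopping_handler(val_accuracy_metric):
    handler = EarlyStoppingHandler(monitor=val_accuracy_metric,
                                   mode='max',      # higher accuracy is better
                                   min_delta=1e-3,  # ignore improvements smaller than this
                                   patience=2)
    return handler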


class GradientUpdateHandler(BatchEnd):
    """Gradient update handler that applies gradients to the network weights

    :py:class:`GradientUpdateHandler` takes a priority level and updates the weight
    parameters at the end of each batch.

    Parameters
    ----------
    priority : scalar, default -2000
        Priority level of the gradient update handler. Priority levels are sorted in
        ascending order: the lower the number, the higher the handler's priority.
    """

    def __init__(self, priority=-2000):
        self.priority = priority

    def batch_end(self, estimator, *args, **kwargs):
        loss = kwargs['loss']
        batch_size = 0
        # normalize to a list of per-device losses and sum their batch sizes
        if not isinstance(loss, list):
            loss = [loss]
        for l in loss:
            batch_size += l.shape[0]

        estimator.trainer.step(batch_size)
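

# A sketch of what GradientUpdateHandler does at batch_end: it sums the leading dimension
# of each per-device loss array and calls trainer.step with the total, so gradients are
# rescaled by the combined batch size. The trainer and loss objects below are simple
# stand-ins, not real Gluon objects.
def _example_gradient_update_handler():
    import types
    stepped = []
    trainer = types.SimpleNamespace(step=stepped.append)
    est = types.SimpleNamespace(trainer=trainer)
    # pretend we trained on two devices with 16 samples each
    per_device_losses = [types.SimpleNamespace(shape=(16,)),
                         types.SimpleNamespace(shape=(16,))]
    handler = GradientUpdateHandler()
    handler.batch_end(est, loss=per_device_losses)
    assert stepped == [32]
    return handler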