mxnet.gluon.contrib.estimator.event_handler
Gluon EventHandlers for Estimators
Classes
CheckpointHandler | Save the model after a user-defined period |
EarlyStoppingHandler | Stop training early if the monitored value is not improving |
GradientUpdateHandler | Gradient update handler that applies gradients to network weights |
LoggingHandler | Basic logging handler that applies to every Gluon estimator by default |
MetricHandler | Metric handler that updates metric values at batch end |
StoppingHandler | Stop training when the maximum number of batches or epochs is reached |
ValidationHandler | Validation handler that evaluates the model on a validation dataset |
- class mxnet.gluon.contrib.estimator.event_handler.CheckpointHandler(model_dir, model_prefix='model', monitor=None, verbose=0, save_best=False, mode='auto', epoch_period=1, batch_period=None, max_checkpoints=5, resume_from_checkpoint=False)
Bases: TrainBegin, BatchEnd, EpochEnd
Save the model after a user-defined period.
CheckpointHandler saves the network architecture after the first batch if the model can be fully hybridized, and saves the model parameters and trainer states after a user-defined period. By default, it saves every epoch.
- Parameters:
model_dir (str) – Directory in which to save all model-related files, including the model architecture, model parameters, and trainer states.
model_prefix (str, default 'model') – Prefix added to all checkpoint file names.
monitor (EvalMetric, default None) – The metric to monitor to determine whether the model has improved.
verbose (int, default 0) – Verbosity mode; 1 means the user is informed every time a checkpoint is saved.
save_best (bool, default False) – If True, monitor must not be None, and CheckpointHandler will save the model parameters and trainer states with the best monitored value.
mode (str, default 'auto') – One of {auto, min, max}. If save_best=True, this is the comparison used to determine whether the monitored value has improved. In 'auto' mode, CheckpointHandler will try to use min or max based on the monitored metric name.
epoch_period (int, default 1) – Epoch interval between checkpoints. By default, checkpoints are saved every epoch.
batch_period (int, default None) – Batch interval between checkpoints. By default, checkpoints are not saved based on the number of batches.
max_checkpoints (int, default 5) – Maximum number of checkpoint files to keep in model_dir; older checkpoints are removed. The best checkpoint file is not counted.
resume_from_checkpoint (bool, default False) – Whether to resume training from a checkpoint in model_dir. If True and checkpoints are found, CheckpointHandler will load the network parameters and trainer states, then train the remaining epochs and batches.
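To make the interaction of max_checkpoints and save_best concrete, here is a minimal pure-Python sketch of the bookkeeping described above. It is not MXNet's implementation: the class CheckpointRotator, the checkpoint names, and the fixed 'min'/'max' mode are hypothetical, and no files are written.

```python
from collections import deque


class CheckpointRotator:
    """Hypothetical sketch of max_checkpoints / save_best bookkeeping.

    NOT MXNet's CheckpointHandler: it only tracks checkpoint names
    (a made-up naming scheme) instead of writing files to model_dir.
    """

    def __init__(self, model_prefix="model", max_checkpoints=5,
                 save_best=False, mode="min"):
        self.model_prefix = model_prefix
        self.max_checkpoints = max_checkpoints
        self.save_best = save_best
        self.mode = mode          # 'min' or 'max'; 'auto' is not modeled here
        self.saved = deque()      # regular checkpoints, oldest first
        self.removed = []         # names that would be deleted from model_dir
        self.best_value = None
        self.best_name = None

    def _improved(self, value):
        if self.best_value is None:
            return True
        if self.mode == "min":
            return value < self.best_value
        return value > self.best_value

    def on_epoch_end(self, epoch, monitored_value=None):
        # Record a regular checkpoint for this epoch (epoch_period=1 assumed).
        self.saved.append(f"{self.model_prefix}-epoch{epoch}")
        # Evict the oldest regular checkpoints once the limit is exceeded.
        while len(self.saved) > self.max_checkpoints:
            self.removed.append(self.saved.popleft())
        # The best checkpoint is tracked separately and not counted toward the limit.
        if (self.save_best and monitored_value is not None
                and self._improved(monitored_value)):
            self.best_value = monitored_value
            self.best_name = f"{self.model_prefix}-best"


rotator = CheckpointRotator(max_checkpoints=2, save_best=True, mode="min")
for epoch, val_loss in enumerate([0.9, 0.5, 0.7, 0.6]):
    rotator.on_epoch_end(epoch, val_loss)
print(list(rotator.saved), rotator.removed, rotator.best_value)
```

With max_checkpoints=2, only the two most recent regular checkpoints survive, while the best one (validation loss 0.5 at epoch 1) is retained separately even though its regular checkpoint was evicted.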
- class mxnet.gluon.contrib.estimator.event_handler.EarlyStoppingHandler(monitor, min_delta=0, patience=0, mode='auto', baseline=None)
Bases: TrainBegin, EpochEnd, TrainEnd
Stop training early if the monitored value is not improving.
- Parameters:
monitor (EvalMetric) – The metric to monitor; training stops if this metric does not improve.
min_delta (float, default 0) – Minimum change in the monitored value to qualify as an improvement.
patience (int, default 0) – Number of epochs to wait for an improvement before terminating training.
mode (str, default 'auto') – One of {auto, min, max}: the comparison used to determine whether the monitored value has improved. In 'auto' mode, EarlyStoppingHandler will try to use min or max based on the monitored metric name.
baseline (float, default None) – Baseline value to compare the monitored value with.
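The following pure-Python sketch illustrates how min_delta, patience, mode, and baseline might combine at each epoch end. It is an illustration under assumptions, not EarlyStoppingHandler's source: the class EarlyStopper is hypothetical, 'auto' mode is omitted, and the counting convention (stop once the number of non-improving epochs reaches patience) is assumed.

```python
import math


class EarlyStopper:
    """Hypothetical sketch of patience / min_delta / mode / baseline logic."""

    def __init__(self, min_delta=0.0, patience=0, mode="min", baseline=None):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max' in this sketch")
        self.min_delta = abs(min_delta)
        self.patience = patience
        self.mode = mode
        self.wait = 0                 # epochs since the last improvement
        self.stop_training = False
        # Start from the baseline if given, otherwise from the worst possible value.
        if baseline is not None:
            self.best = baseline
        else:
            self.best = math.inf if mode == "min" else -math.inf

    def _is_improvement(self, current):
        # An improvement must beat the best value by more than min_delta.
        if self.mode == "min":
            return current < self.best - self.min_delta
        return current > self.best + self.min_delta

    def on_epoch_end(self, current):
        if self._is_improvement(current):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stop_training = True
        return self.stop_training


stopper = EarlyStopper(min_delta=0.01, patience=2, mode="min")
decisions = [stopper.on_epoch_end(loss) for loss in [1.0, 0.8, 0.799, 0.795]]
print(decisions)
```

The last two losses improve on 0.8 by less than min_delta, so they count as non-improvements, and training stops once two such epochs have accumulated.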
- class mxnet.gluon.contrib.estimator.event_handler.GradientUpdateHandler(priority=-2000)
Bases: BatchEnd
Gradient update handler that applies gradients to network weights.
GradientUpdateHandler takes a priority level and updates the weight parameters at the end of each batch.
- Parameters:
priority (scalar, default -2000) – Priority level of the gradient update handler. Handlers are sorted by priority in ascending order: the lower the number, the higher the priority.
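As a rough illustration of "sorted in ascending order", this sketch orders handlers by the default priorities listed on this page. The list of (name, priority) pairs and the function execution_order are hypothetical; the assumption that ties keep registration order comes from Python's stable sorted(), not from the Estimator's documented behavior.

```python
import math

# Default priorities as given in the signatures on this page.
handlers = [
    ("LoggingHandler", math.inf),
    ("ValidationHandler", -1000),
    ("GradientUpdateHandler", -2000),
    ("MetricHandler", -1000),
]


def execution_order(handlers):
    """Sort handlers by ascending priority; lower numbers run first.

    sorted() is stable, so handlers with equal priority keep their
    registration order (an assumption about tie-breaking in this sketch).
    """
    return [name for name, _ in sorted(handlers, key=lambda h: h[1])]


print(execution_order(handlers))
```

Under these defaults, gradients are applied before metrics and validation run, and logging (priority inf) comes last.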
- class mxnet.gluon.contrib.estimator.event_handler.LoggingHandler(log_interval='epoch', metrics=None, priority=inf)
Bases: TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd
Basic logging handler that applies to every Gluon estimator by default.
LoggingHandler logs hyperparameters, training statistics, and other useful information during training.
- Parameters:
log_interval (int or str, default 'epoch') – Logging interval during training. log_interval='epoch': display metrics every epoch. log_interval=k (an integer): display metrics every k batches.
metrics (list of EvalMetrics) – Metrics to be logged at batch end, epoch end, and train end.
priority (scalar, default np.inf) – Priority level of the LoggingHandler. Handlers are sorted by priority in ascending order: the lower the number, the higher the priority.
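The two forms of log_interval can be read as a simple decision rule. The helper should_log below is hypothetical and only models that rule (with 1-based batch counters assumed); it does not reproduce what LoggingHandler actually prints.

```python
def should_log(log_interval, batch_index=None, epoch_end=False):
    """Hypothetical helper: decide whether a log line would be emitted.

    log_interval == 'epoch' -> log only at epoch end.
    log_interval == k (int) -> log after every k batches (batch_index is 1-based here).
    """
    if log_interval == "epoch":
        return epoch_end
    if isinstance(log_interval, int) and log_interval > 0:
        return batch_index is not None and batch_index % log_interval == 0
    raise ValueError("log_interval must be 'epoch' or a positive int")


logged_batches = [b for b in range(1, 11) if should_log(3, batch_index=b)]
print(logged_batches)
```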
- class mxnet.gluon.contrib.estimator.event_handler.MetricHandler(metrics, priority=-1000)
Bases: EpochBegin, BatchEnd
Metric handler that updates metric values at batch end.
MetricHandler takes model predictions and true labels and updates the metrics; it also updates the metric wrapper for the loss with the loss values. Validation loss and metrics are handled by ValidationHandler.
- Parameters:
metrics (list of EvalMetrics) – Metrics to be updated at batch end.
priority (scalar, default -1000) – Priority level of the MetricHandler. Handlers are sorted by priority in ascending order: the lower the number, the higher the priority.
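The sketch below shows the batch-end update pattern with a toy accuracy metric. RunningAccuracy and SketchMetricHandler are hypothetical stand-ins, not MXNet classes; the reset at epoch begin is an assumption suggested by MetricHandler's EpochBegin base, and update(labels, preds) mirrors the argument order of EvalMetric.update.

```python
class RunningAccuracy:
    """Hypothetical stand-in for an EvalMetric that accumulates correct/total."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.correct = 0
        self.total = 0

    def update(self, labels, preds):
        # Count exact matches between labels and predicted classes.
        self.correct += sum(int(l == p) for l, p in zip(labels, preds))
        self.total += len(labels)

    def get(self):
        value = self.correct / self.total if self.total else float("nan")
        return "accuracy", value


class SketchMetricHandler:
    """Assumed behavior: reset metrics at epoch begin, update them at batch end."""

    def __init__(self, metrics):
        self.metrics = metrics

    def epoch_begin(self):
        for metric in self.metrics:
            metric.reset()

    def batch_end(self, labels, preds):
        for metric in self.metrics:
            metric.update(labels, preds)


acc = RunningAccuracy()
handler = SketchMetricHandler([acc])
handler.epoch_begin()
handler.batch_end([1, 0, 1, 1], [1, 1, 1, 0])   # 2 of 4 correct
handler.batch_end([0, 0], [0, 0])               # 2 of 2 correct
print(acc.get())
```

After the two batches the running accuracy is 4/6, accumulated across batches within the epoch.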
- class mxnet.gluon.contrib.estimator.event_handler.StoppingHandler(max_epoch=None, max_batch=None)
Bases: TrainBegin, BatchEnd, EpochEnd
Stop conditions for training: stop training when the maximum number of batches or epochs is reached.
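Whichever limit is hit first ends training. The following sketch models that with a toy loop; SketchStopper and run are hypothetical names, and the counters are simplified compared with whatever state StoppingHandler keeps internally.

```python
class SketchStopper:
    """Hypothetical sketch: stop when max_epoch or max_batch is reached."""

    def __init__(self, max_epoch=None, max_batch=None):
        if max_epoch is None and max_batch is None:
            raise ValueError("set max_epoch and/or max_batch")
        self.max_epoch = max_epoch
        self.max_batch = max_batch
        self.current_epoch = 0
        self.current_batch = 0      # counted across all epochs
        self.stop_training = False

    def batch_end(self):
        self.current_batch += 1
        if self.max_batch is not None and self.current_batch >= self.max_batch:
            self.stop_training = True
        return self.stop_training

    def epoch_end(self):
        self.current_epoch += 1
        if self.max_epoch is not None and self.current_epoch >= self.max_epoch:
            self.stop_training = True
        return self.stop_training


def run(stopper, batches_per_epoch):
    # Toy training loop: break out of the epoch as soon as a limit is hit.
    while not stopper.stop_training:
        for _ in range(batches_per_epoch):
            if stopper.batch_end():
                break
        if not stopper.stop_training:
            stopper.epoch_end()
    return stopper.current_epoch, stopper.current_batch


print(run(SketchStopper(max_epoch=5, max_batch=7), batches_per_epoch=3))
print(run(SketchStopper(max_epoch=2), batches_per_epoch=3))
```

In the first run the batch limit (7) triggers partway through the third epoch; in the second only the epoch limit applies.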
- class mxnet.gluon.contrib.estimator.event_handler.ValidationHandler(val_data, eval_fn, epoch_period=1, batch_period=None, priority=-1000, event_handlers=None)
Bases: TrainBegin, BatchEnd, EpochEnd
Validation handler that evaluates the model on a validation dataset.
ValidationHandler takes a validation dataset, an evaluation function, and how often to run validation. You can provide a custom evaluation function or use the one provided by Estimator.
- Parameters:
val_data (DataLoader) – Validation dataset on which to run evaluation.
eval_fn (function) – A function that defines how to run evaluation and calculate loss and metrics.
epoch_period (int, default 1) – How often to run validation at epoch end. By default, ValidationHandler validates every epoch.
batch_period (int, default None) – How often to run validation at batch end. By default, ValidationHandler does not validate at batch end.
priority (scalar, default -1000) – Priority level of the ValidationHandler. Handlers are sorted by priority in ascending order: the lower the number, the higher the priority.
event_handlers (EventHandler or list of EventHandlers, default None) – EventHandlers to apply during validation. This argument is used by the self.eval_fn function to process customized event handlers.
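To see how epoch_period and batch_period schedule validation, this sketch lists the points at which validation would run. The helper validation_points is hypothetical; counting batch_period against a global (cross-epoch) batch counter with 1-based indices is an assumption, not something this page specifies.

```python
def validation_points(num_epochs, batches_per_epoch, epoch_period=1, batch_period=None):
    """Hypothetical helper: list when validation would run.

    Returns ('batch', global_batch) and ('epoch', epoch) tuples with 1-based
    counters. batch_period=None disables batch-end validation, matching the
    documented default.
    """
    points = []
    global_batch = 0
    for epoch in range(1, num_epochs + 1):
        for _ in range(batches_per_epoch):
            global_batch += 1
            if batch_period and global_batch % batch_period == 0:
                points.append(("batch", global_batch))
        if epoch_period and epoch % epoch_period == 0:
            points.append(("epoch", epoch))
    return points


print(validation_points(4, 3, epoch_period=2))
print(validation_points(2, 3, epoch_period=1, batch_period=4))
```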