mxnet.gluon.contrib.estimator.estimator

Gluon Estimator

Classes

Estimator(net, loss[, train_metrics, ...])

Estimator Class for easy model training

class mxnet.gluon.contrib.estimator.estimator.Estimator(net, loss, train_metrics=None, val_metrics=None, initializer=None, trainer=None, device=None, val_net=None, val_loss=None, batch_processor=None)[source]

Bases: object

Estimator Class for easy model training

Estimator can be used to facilitate the training and validation process.

Parameters:
  • net (gluon.Block) – The model used for training.

  • loss (gluon.loss.Loss) – Loss (objective) function to calculate during training.

  • train_metrics (EvalMetric or list of EvalMetric) – Training metrics for evaluating models on training dataset.

  • val_metrics (EvalMetric or list of EvalMetric) – Validation metrics for evaluating models on validation dataset.

  • initializer (Initializer) – Initializer to initialize the network.

  • trainer (Trainer) – Trainer to apply optimizer on network parameters.

  • device (Device or list of Device) – Device(s) to run the training on.

  • val_net (gluon.Block) –

    The model used for validation. The validation model does not have to belong to the same model class as the training model, but the two models typically share the same architecture, so the validation model can reuse the parameters of the training model.

    The following example constructs a val_net that shares its network parameters with the training net:

    >>> net = _get_train_network()
    >>> val_net = _get_test_network()
    >>> val_net.share_parameters(net.collect_params())
    >>> net.initialize(device=device)
    >>> est = Estimator(net, loss, val_net=val_net)
    

    Weight sharing between two networks requires their parameter namespaces to match. Most networks inheriting from Block can share their parameters correctly. Sequential networks are an exception: their Block scope must be specified for weight sharing to work correctly. For details on naming in the MXNet Gluon API, please refer to https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/naming.html for further information.

  • val_loss (gluon.loss.Loss) – Loss (objective) function to calculate during validation. If val_loss is None, the same loss function as self.loss is used.

  • batch_processor (BatchProcessor) – BatchProcessor provides customized fit_batch() and evaluate_batch() methods.

evaluate(val_data, batch_axis=0, event_handlers=None)[source]

Evaluate model on validation data.

This function calls evaluate_batch() on each of the batches from the validation data loader. Thus, for custom use cases, it’s possible to inherit the estimator class and override evaluate_batch().

Parameters:
  • val_data (DataLoader) – Validation data loader with data and labels.

  • batch_axis (int, default 0) – Batch axis along which to split the validation data across devices.

  • event_handlers (EventHandler or list of EventHandler) – List of EventHandlers to apply during validation. Besides event handlers specified here, a default MetricHandler and a LoggingHandler will be added if not specified explicitly.

fit(train_data, val_data=None, epochs=None, event_handlers=None, batches=None, batch_axis=0)[source]

Trains the model with a given DataLoader for a specified number of epochs or batches. The batch size is inferred from the data loader’s batch_size.

This function calls fit_batch() on each of the batches from the training data loader. Thus, for custom use cases, it’s possible to inherit the estimator class and override fit_batch().

Parameters:
  • train_data (DataLoader) – Training data loader with data and labels.

  • val_data (DataLoader, default None) – Validation data loader with data and labels.

  • epochs (int, default None) – Number of epochs to iterate on the training data. You must specify exactly one type of iteration (epochs or batches).

  • event_handlers (EventHandler or list of EventHandler) – List of EventHandlers to apply during training. Besides the event handlers specified here, a StoppingHandler, LoggingHandler and MetricHandler will be added by default if not yet specified manually. If validation data is provided, a ValidationHandler is also added if not already specified.

  • batches (int, default None) – Number of batches to iterate on the training data. You must specify exactly one type of iteration (epochs or batches).

  • batch_axis (int, default 0) – Batch axis along which to split the training data across devices.
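The control flow described for fit() can be illustrated with a toy stand-in. This is a minimal sketch in plain Python, not MXNet's implementation: ToyEstimator and RecordingEstimator are hypothetical names, the "batches" are plain lists, and the assumption that a batch budget keeps cycling over the data past the end of an epoch is an illustration choice. It only shows the two ideas from the text: exactly one of epochs or batches must be given, and per-batch behaviour is customized by overriding fit_batch().

```python
import itertools


class ToyEstimator:
    """Hypothetical stand-in for illustration; NOT mxnet's Estimator."""

    def fit(self, train_data, epochs=None, batches=None):
        # Exactly one of `epochs` / `batches` may be specified.
        if (epochs is None) == (batches is None):
            raise ValueError("Specify exactly one of epochs or batches.")

        if epochs is not None:
            # One full pass over the data per epoch.
            stream = itertools.chain.from_iterable(
                train_data for _ in range(epochs)
            )
        else:
            # Assumption: keep cycling over the data until the budget is used.
            stream = itertools.islice(itertools.cycle(train_data), batches)

        seen = 0
        for batch in stream:
            self.fit_batch(batch)  # the hook that subclasses override
            seen += 1
        return seen

    def fit_batch(self, batch):
        """Default per-batch step; does nothing in this toy version."""


class RecordingEstimator(ToyEstimator):
    """Overrides fit_batch() to customize what happens for each batch."""

    def __init__(self):
        self.log = []

    def fit_batch(self, batch):
        self.log.append(sum(batch))


toy_data = [[1, 2], [3, 4]]  # two toy "batches"

by_epochs = RecordingEstimator()
n_epoch_batches = by_epochs.fit(toy_data, epochs=2)

by_batches = RecordingEstimator()
n_budget_batches = by_batches.fit(toy_data, batches=3)
```

With the real Estimator the same choice applies at the call site: pass either epochs or batches to fit(), never both.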

logger = None

logging.Logger object associated with the Estimator.

The logger is used for all logs generated by this estimator and its handlers. A new logging.Logger is created during Estimator construction and configured to write all logs with level logging.INFO or higher to sys.stdout.

You can modify the logging settings using the standard Python methods. For example, to save logs to a file in addition to printing them to stdout, you can attach a logging.FileHandler to the logger.

>>> import logging
>>> est = Estimator(net, loss)
>>> # filename is the path of the log file to write
>>> est.logger.addHandler(logging.FileHandler(filename))
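The same pattern can be tried without MXNet installed. The sketch below uses only the Python standard library and a stand-in logger configured as described above (INFO level, writing to sys.stdout); the logger name, the temporary log path, and the formatter string are assumptions made for illustration. With a real Estimator, est.logger would take the place of logger.

```python
import logging
import os
import sys
import tempfile

# Stand-in for est.logger: INFO level, writing to stdout.
logger = logging.Logger("toy_estimator")
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))

# Attach a FileHandler so every record is also written to a file.
log_path = os.path.join(tempfile.mkdtemp(), "train.log")
file_handler = logging.FileHandler(log_path)
file_handler.setFormatter(
    logging.Formatter("%(asctime)s %(levelname)s %(message)s")
)
logger.addHandler(file_handler)

logger.info("epoch 0 finished")    # goes to stdout and to the file
logger.debug("verbose detail")     # below INFO, so dropped by the logger

file_handler.close()
with open(log_path) as f:
    contents = f.read()
```

Because the level is set on the logger itself, records below INFO never reach any handler; calling setLevel(logging.DEBUG) on the logger would let them through to both outputs.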