mxnet.gluon.contrib.estimator.batch_processor

Gluon Batch Processor for Estimators

Classes

BatchProcessor()

BatchProcessor class for plug-and-play fit_batch and evaluate_batch.

class mxnet.gluon.contrib.estimator.batch_processor.BatchProcessor

Bases: object

BatchProcessor class for plug-and-play fit_batch and evaluate_batch.

During training or validation, data is divided into minibatches for processing. This class provides hooks for training or validating on a single minibatch. Users can customize fit_batch() and evaluate_batch() by inheriting from this class and overriding those methods.

A BatchProcessor can be used to replace the fit_batch() and evaluate_batch() logic of the base Estimator class.
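The plug-and-play idea can be sketched in plain Python so that it runs without MXNet installed. Everything below is a hypothetical stand-in, not the MXNet implementation: the echoing default hooks, `CountingProcessor`, and `ToyEstimator` are invented for illustration, and the `batch_processor=` keyword on `ToyEstimator` is only modeled on the idea of handing a processor to an estimator (check your MXNet version's `Estimator` signature). In real code you would subclass `mxnet.gluon.contrib.estimator.batch_processor.BatchProcessor` instead of the local class.

```python
class BatchProcessor:
    """Hypothetical pure-Python stand-in for the hook interface.

    The default hooks just unpack the batch; the real MXNet class also
    shards the batch across devices and runs the network.
    """

    def fit_batch(self, estimator, train_batch, batch_axis=0):
        data, label = train_batch
        return data, label

    def evaluate_batch(self, estimator, val_batch, batch_axis=0):
        data, label = val_batch
        return data, label


class CountingProcessor(BatchProcessor):
    """Custom processor: overrides both hooks, then delegates to the defaults."""

    def __init__(self):
        self.train_calls = 0
        self.val_calls = 0

    def fit_batch(self, estimator, train_batch, batch_axis=0):
        self.train_calls += 1
        return super().fit_batch(estimator, train_batch, batch_axis)

    def evaluate_batch(self, estimator, val_batch, batch_axis=0):
        self.val_calls += 1
        return super().evaluate_batch(estimator, val_batch, batch_axis)


class ToyEstimator:
    """Hypothetical estimator that delegates all per-batch work to a processor."""

    def __init__(self, batch_processor=None):
        self.batch_processor = (
            batch_processor if batch_processor is not None else BatchProcessor()
        )

    def fit(self, train_data, val_data):
        for batch in train_data:
            self.batch_processor.fit_batch(self, batch)
        for batch in val_data:
            self.batch_processor.evaluate_batch(self, batch)


proc = CountingProcessor()
est = ToyEstimator(batch_processor=proc)
est.fit(train_data=[([1, 2], [0, 1])] * 3, val_data=[([3], [1])])
print(proc.train_calls, proc.val_calls)  # 3 1
```

The estimator never needs to know which processor it holds, so swapping in a subclass changes per-batch behavior without touching the training loop.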

evaluate_batch(estimator, val_batch, batch_axis=0)

Evaluate the estimator model on a batch of validation data.

Parameters:
  • estimator (Estimator) – Reference to the estimator.

  • val_batch (tuple) – Data and label of a batch from the validation data loader.

  • batch_axis (int, default 0) – Axis along which the validation data is split across devices.
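The batch_axis parameter controls how a batch is cut into one shard per device. The pure-Python sketch below models only the slicing step for batch_axis=0 on a list of samples; it does not copy anything to a device. The helper `split_batch` is hypothetical, and it assumes an even split: it rejects batches whose size is not divisible by the number of devices rather than guessing how a remainder should be distributed.

```python
def split_batch(batch, num_devices):
    """Hypothetical helper: split a list of samples along axis 0 into one
    contiguous shard per device, assuming the batch divides evenly."""
    if len(batch) % num_devices != 0:
        raise ValueError(
            f"batch of size {len(batch)} cannot be split evenly "
            f"across {num_devices} devices"
        )
    step = len(batch) // num_devices
    # Shard i holds samples [i * step, (i + 1) * step).
    return [batch[i * step:(i + 1) * step] for i in range(num_devices)]


shards = split_batch(list(range(8)), num_devices=2)
print(shards)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Each shard is then processed independently, which is why the documented hooks return lists with one entry per shard.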

fit_batch(estimator, train_batch, batch_axis=0)

Trains the estimator model on a batch of training data.

Parameters:
  • estimator (Estimator) – Reference to the estimator.

  • train_batch (tuple) – Data and label of a batch from the training data loader.

  • batch_axis (int, default 0) – Axis along which the training data is split across devices.

Returns:

  • data (List of NDArray) – Sharded data from the batch. Data is sharded with gluon.split_and_load.

  • label (List of NDArray) – Sharded label from the batch. Labels are sharded with gluon.split_and_load.

  • pred (List of NDArray) – Prediction on each of the sharded inputs.

  • loss (List of NDArray) – Loss on each of the sharded inputs.
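To make the four return values concrete, the sketch below imitates a training step in plain Python: it shards the batch, runs a forward pass and a loss per shard, accumulates a gradient as a stand-in for backward(), and returns data, label, pred, and loss as lists with one entry per shard. It is an illustration under stated assumptions, not MXNet code: `ToyNet`, `ToyEstimator`, the `num_devices` attribute, and the scalar model pred = w * x are hypothetical, the free function drops the real method's batch_axis parameter (only axis 0 makes sense for plain lists), and the batch is assumed to divide evenly. Like a typical fit step, it only computes gradients; we assume the optimizer update happens elsewhere in the training loop.

```python
class ToyNet:
    """Hypothetical scalar model: pred = w * x, with a gradient buffer."""

    def __init__(self, w):
        self.w = w
        self.grad = 0.0


class ToyEstimator:
    """Hypothetical estimator exposing a net, a loss, and a device count."""

    def __init__(self, net, num_devices):
        self.net = net
        self.num_devices = num_devices

    @staticmethod
    def loss(pred, label):
        # Per-sample squared error for one shard.
        return [(p - y) ** 2 for p, y in zip(pred, label)]


def fit_batch(estimator, train_batch):
    data, label = train_batch
    n = estimator.num_devices
    step = len(data) // n  # assumes the batch divides evenly
    # Shard data and label along axis 0, one contiguous slice per device.
    data = [data[i * step:(i + 1) * step] for i in range(n)]
    label = [label[i * step:(i + 1) * step] for i in range(n)]
    # Forward pass and loss, one entry per shard.
    pred = [[estimator.net.w * x for x in xs] for xs in data]
    loss = [estimator.loss(p, y) for p, y in zip(pred, label)]
    # Stand-in for backward(): accumulate d(loss)/dw = 2 * (p - y) * x.
    for xs, ps, ys in zip(data, pred, label):
        for x, p, y in zip(xs, ps, ys):
            estimator.net.grad += 2 * (p - y) * x
    return data, label, pred, loss


est = ToyEstimator(ToyNet(w=2.0), num_devices=2)
data, label, pred, loss = fit_batch(est, ([1, 2, 3, 4], [2, 4, 6, 9]))
print(pred)          # [[2.0, 4.0], [6.0, 8.0]]
print(loss)          # [[0.0, 0.0], [0.0, 1.0]]
print(est.net.grad)  # -8.0
```

Returning the per-shard lists, rather than a single reduced loss, lets the caller update metrics and handlers from the same outputs without recomputing the forward pass.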