mxnet.gluon.contrib.estimator.estimator
Gluon Estimator
Classes
Estimator – Estimator class for easy model training
- class mxnet.gluon.contrib.estimator.estimator.Estimator(net, loss, train_metrics=None, val_metrics=None, initializer=None, trainer=None, device=None, val_net=None, val_loss=None, batch_processor=None)
Bases: object
Estimator class for easy model training.
Estimator can be used to facilitate the training and validation process.
- Parameters:
net (gluon.Block) – The model used for training.
loss (gluon.loss.Loss) – Loss (objective) function to calculate during training.
train_metrics (EvalMetric or list of EvalMetric) – Training metrics for evaluating models on training dataset.
val_metrics (EvalMetric or list of EvalMetric) – Validation metrics for evaluating models on validation dataset.
initializer (Initializer) – Initializer to initialize the network.
trainer (Trainer) – Trainer to apply optimizer on network parameters.
device (Device or list of Device) – Device(s) to run the training on.
val_net (gluon.Block) –
The model used for validation. The validation model need not belong to the same model class as the training model, but the two models typically share the same architecture, so the validation model can reuse the parameters of the training model.
The following example constructs a val_net that shares its network parameters with the training net:
>>> net = _get_train_network()
>>> val_net = _get_test_network()
>>> val_net.share_parameters(net.collect_params())
>>> net.initialize(device=device)
>>> est = Estimator(net, loss, val_net=val_net)
Weight sharing between two networks requires matching namespaces. Most networks inheriting from Block can share their parameters correctly. An exception is Sequential networks, for which the Block scope must be specified for correct weight sharing. For naming in the MXNet Gluon API, see https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/naming.html for further information.
val_loss (gluon.loss.Loss) – Loss (objective) function to calculate during validation. If val_loss is None, the same loss function as self.loss is used.
batch_processor (BatchProcessor) – BatchProcessor provides customized fit_batch() and evaluate_batch() methods
- evaluate(val_data, batch_axis=0, event_handlers=None)
Evaluate model on validation data.
This function calls evaluate_batch() on each of the batches from the validation data loader. Thus, for custom use cases, it’s possible to inherit the estimator class and override evaluate_batch().
- Parameters:
val_data (DataLoader) – Validation data loader with data and labels.
batch_axis (int, default 0) – Batch axis along which the validation data is split across devices.
event_handlers (EventHandler or list of EventHandler) – List of EventHandlers to apply during validation. Besides the event handlers specified here, a default MetricHandler and a LoggingHandler will be added if not specified explicitly.
- fit(train_data, val_data=None, epochs=None, event_handlers=None, batches=None, batch_axis=0)
Trains the model with a given DataLoader for a specified number of epochs or batches. The batch size is inferred from the data loader’s batch_size.
This function calls fit_batch() on each of the batches from the training data loader. Thus, for custom use cases, it’s possible to inherit the estimator class and override fit_batch().
- Parameters:
train_data (DataLoader) – Training data loader with data and labels.
val_data (DataLoader, default None) – Validation data loader with data and labels.
epochs (int, default None) – Number of epochs to iterate on the training data. Specify exactly one type of iteration (epochs or batches).
event_handlers (EventHandler or list of EventHandler) – List of EventHandlers to apply during training. Besides the event handlers specified here, a StoppingHandler, LoggingHandler and MetricHandler will be added by default if not yet specified manually. If validation data is provided, a ValidationHandler is also added if not already specified.
batches (int, default None) – Number of batches to iterate on the training data. Specify exactly one type of iteration (epochs or batches).
batch_axis (int, default 0) – Batch axis along which the training data is split across devices.
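The rule that exactly one of epochs or batches may be given can be checked before calling fit(). The following is a minimal sketch in plain Python; resolve_stop_condition is a hypothetical helper that is not part of MXNet, and the positive-integer check is an assumption added for illustration rather than documented behavior.

```python
def resolve_stop_condition(epochs=None, batches=None):
    """Hypothetical helper (not an MXNet API).

    Returns ("epochs", n) or ("batches", n), and raises ValueError
    unless exactly one of the two arguments is provided.
    """
    # Both None or both set: the iteration type is ambiguous.
    if (epochs is None) == (batches is None):
        raise ValueError("Specify exactly one of `epochs` or `batches`.")

    if epochs is not None:
        kind, count = "epochs", epochs
    else:
        kind, count = "batches", batches

    # Assumption for this sketch: the count must be a positive integer.
    if not isinstance(count, int) or count <= 0:
        raise ValueError(f"`{kind}` must be a positive integer, got {count!r}.")
    return kind, count


# Validate user input first, then forward only the chosen keyword.
kind, count = resolve_stop_condition(epochs=3)
fit_kwargs = {kind: count}  # e.g. est.fit(train_data, **fit_kwargs)
```

Forwarding a single keyword this way keeps the other argument at its default of None.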
- logger = None
logging.Logger object associated with the Estimator.
The logger is used for all logs generated by this estimator and its handlers. A new logging.Logger is created during Estimator construction and configured to write all logs with level logging.INFO or higher to sys.stdout.
You can modify the logging settings using the standard Python methods. For example, to save logs to a file in addition to printing them to stdout, you can attach a logging.FileHandler to the logger:
>>> import logging
>>> est = Estimator(net, loss)
>>> est.logger.addHandler(logging.FileHandler(filename))
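Because the logger is a standard logging.Logger, the same pattern can be tried without a model. The sketch below uses only the Python standard library; demo_logger is a stand-in for est.logger (configured here by hand to write INFO and above to stdout), and attach_file_log, the logger name, and the log format are hypothetical choices for illustration, not MXNet APIs.

```python
import logging
import os
import sys
import tempfile


def attach_file_log(logger, path, level=logging.INFO):
    """Hypothetical helper: add a FileHandler so records also go to `path`."""
    handler = logging.FileHandler(path, mode="a", encoding="utf-8")
    handler.setLevel(level)
    handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(message)s")
    )
    logger.addHandler(handler)
    return handler


# Stand-in for est.logger: INFO and above, written to stdout.
demo_logger = logging.getLogger("estimator_logger_demo")
demo_logger.setLevel(logging.INFO)
demo_logger.propagate = False  # avoid duplicate output via the root logger
demo_logger.addHandler(logging.StreamHandler(sys.stdout))

# Send the same records to a file as well.
log_path = os.path.join(tempfile.mkdtemp(), "train.log")
file_handler = attach_file_log(demo_logger, log_path)

demo_logger.info("epoch 0 finished")             # goes to stdout and the file
demo_logger.debug("dropped: below INFO level")   # filtered by the logger

# Detach and close the handler once training is done.
demo_logger.removeHandler(file_handler)
file_handler.close()
```

Keeping a reference to the handler makes it possible to remove and close it later, which matters when several training runs in one process would otherwise keep appending to the same file.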