mxnet.gluon.trainer
Parameter optimizer.
Classes
Trainer: Applies an Optimizer on a set of Parameters.
- class mxnet.gluon.trainer.Trainer(params, optimizer, optimizer_params=None, kvstore='device', compression_params=None, update_on_kvstore=None)
Bases: object
Applies an Optimizer on a set of Parameters. Trainer should be used together with autograd.
Note
In the following cases, updates always happen on kvstore, i.e., you cannot set update_on_kvstore=False:
- dist kvstore with sparse weights or sparse gradients
- dist async kvstore
- optimizer.lr_scheduler is not None
- Parameters:
params (Dict) – The set of parameters to optimize.
optimizer (str or Optimizer) – The optimizer to use. See help on Optimizer for a list of available optimizers.
optimizer_params (dict) – Keyword arguments to be passed to the optimizer constructor. For example, {'learning_rate': 0.1}. All optimizers accept learning_rate, wd (weight decay), clip_gradient, and lr_scheduler. See each optimizer’s constructor for a list of additional supported arguments.
kvstore (str or KVStore) – kvstore type for multi-gpu and distributed training. See help on mxnet.kvstore.create() for more information.
compression_params (dict) – Specifies the type of gradient compression and additional arguments depending on the type of compression being used. For example, 2bit compression requires a threshold. Arguments would then be {'type': '2bit', 'threshold': 0.5}. See the mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
update_on_kvstore (bool, default None) – Whether to perform parameter updates on kvstore. If None and optimizer.aggregate_num <= 1, then trainer will choose the more suitable option depending on the type of kvstore. If None and optimizer.aggregate_num > 1, update_on_kvstore is set to False. If the update_on_kvstore argument is provided, environment variable MXNET_UPDATE_ON_KVSTORE will be ignored.
- Properties:
learning_rate (float) – The current learning rate of the optimizer. Given an Optimizer object optimizer, its learning rate can be accessed as optimizer.learning_rate.
- allreduce_grads()
For each parameter, reduce the gradients from different devices.
Should be called after autograd.backward(), outside of record() scope, and before trainer.update().
For normal parameter updates, step() should be used, which internally calls allreduce_grads() and then update(). However, if you need the reduced gradients in order to apply a transformation to them, such as gradient clipping, you may want to call allreduce_grads() and update() separately.
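The split described above can be sketched as a small helper. This is an illustration only: clipped_step, clip_fn and RecordingTrainer are hypothetical names, not part of MXNet. The helper relies solely on the two methods documented here, so it accepts any object that exposes allreduce_grads() and update(batch_size). With a real Trainer, clip_fn would typically wrap something like gluon.utils.clip_global_norm over the parameters' gradient arrays, and the trainer would probably need to be created with update_on_kvstore=False so that the update happens locally; treat both points as assumptions to verify against your MXNet version.

```python
def clipped_step(trainer, batch_size, clip_fn):
    """Reduce gradients, transform them in place, then apply the update.

    `trainer` is any object exposing allreduce_grads() and update(batch_size),
    such as a gluon Trainer. `clip_fn` is a hypothetical zero-argument callable
    that modifies the reduced gradients in place.
    """
    trainer.allreduce_grads()    # gradients from all devices are reduced first
    clip_fn()                    # transform the reduced gradients (e.g. clipping)
    trainer.update(batch_size)   # the optimizer then uses the transformed gradients


class RecordingTrainer:
    """Stand-in (not an MXNet class) that only records the call order."""

    def __init__(self):
        self.calls = []

    def allreduce_grads(self):
        self.calls.append("allreduce_grads")

    def update(self, batch_size, ignore_stale_grad=False):
        self.calls.append(("update", batch_size))


demo = RecordingTrainer()
clipped_step(demo, batch_size=32, clip_fn=lambda: demo.calls.append("clip"))
print(demo.calls)
```

The ordering is the point of the sketch: the transformation has to sit between the reduction and the update, which is exactly what step() does not allow.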
- load_states(fname)
Loads trainer states (e.g. optimizer, momentum) from a file.
- Parameters:
fname (str) – Path to input states file.
Note
optimizer.param_dict, which contains Parameter information (such as lr_mult and wd_mult), will not be loaded from the file, but rather set based on the current Trainer’s parameters.
- save_states(fname)
Saves trainer states (e.g. optimizer, momentum) to a file.
- Parameters:
fname (str) – Path to output states file.
Note
optimizer.param_dict, which contains Parameter information (such as lr_mult and wd_mult), will not be saved.
- set_learning_rate(lr)
Sets a new learning rate of the optimizer.
- Parameters:
lr (float) – The new learning rate of the optimizer.
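One common use of set_learning_rate() is a hand-written schedule applied once per epoch. The sketch below is a minimal illustration, not an MXNet facility: step_decay, apply_schedule and FakeTrainer are hypothetical names, and the decay factor and interval are arbitrary. apply_schedule only calls set_learning_rate(lr), so it should work with any object exposing that method, including a Trainer.

```python
def step_decay(base_lr, epoch, factor=0.5, every=10):
    """Hypothetical schedule: multiply base_lr by `factor` once per `every` epochs."""
    return base_lr * factor ** (epoch // every)


def apply_schedule(trainer, epoch, base_lr=0.1):
    """Compute the rate for this epoch and push it to the trainer."""
    lr = step_decay(base_lr, epoch)
    trainer.set_learning_rate(lr)
    return lr


class FakeTrainer:
    """Stand-in (not an MXNet class) that just remembers the last rate set."""

    def __init__(self):
        self.learning_rate = None

    def set_learning_rate(self, lr):
        self.learning_rate = lr


demo = FakeTrainer()
rates = [apply_schedule(demo, epoch) for epoch in (0, 9, 10, 25)]
print(rates, demo.learning_rate)
```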
- step(batch_size, ignore_stale_grad=False)
Makes one step of parameter update. Should be called after autograd.backward() and outside of record() scope.
For normal parameter updates, step() should be used, which internally calls allreduce_grads() and then update(). However, if you need the reduced gradients in order to apply a transformation to them, such as gradient clipping, you may want to call allreduce_grads() and update() separately.
Compared to torch.optim.SGD.step()
PyTorch’s optimizer.step() takes no batch_size argument and applies the raw gradient produced by loss.backward(). Gluon’s trainer.step(batch_size) rescales the gradient by 1 / batch_size before applying it, so for the same nominal lr and the same per-sample loss formula, MXNet’s effective update is batch_size times smaller than PyTorch’s. Concretely:

    # PyTorch
    loss(pred, y).sum().backward()  # loss has shape (N,)
    optimizer.step()                # applies full gradient

    # MXNet, PyTorch-equivalent
    loss(pred, y).backward()        # MXNet implicitly sums for backward
    trainer.step(1)                 # NO rescale -> matches PyTorch

    # MXNet, common d2l convention (per-sample mean over batch)
    loss(pred, y).backward()
    trainer.step(N)                 # rescale by 1/N -> mean-over-batch

If your loss returns a per-sample value and you call trainer.step(N), the effective learning rate is N times smaller than in the equivalent PyTorch code with the same nominal lr. When porting hyper-parameters between frameworks, either (a) use trainer.step(1) to match PyTorch’s no-rescale behaviour, or (b) multiply lr by N to compensate.
- Parameters:
batch_size (int) – Batch size of data processed. Gradient will be normalized by 1 / batch_size. Set this to 1 to disable rescaling: use 1 when your loss is already a per-batch mean, or when you are matching PyTorch’s optimizer.step() convention.
ignore_stale_grad (bool, optional, default=False) – If true, ignores Parameters with stale gradient (gradient that has not been updated by backward after last step) and skips their update.
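The effect of the 1 / batch_size normalization can be checked with plain arithmetic. The sketch below is a simplified model, not MXNet's implementation: sgd_step is a hypothetical helper for a single scalar weight with plain SGD, with no momentum or weight decay, and the gradient values are made up. It assumes, as described above, that backward on a per-sample loss yields the sum of the per-sample gradients.

```python
def sgd_step(weight, per_sample_grads, lr, batch_size):
    """Model of one plain-SGD update with the gradient rescaled by 1 / batch_size."""
    summed = sum(per_sample_grads)    # gradient of the summed per-sample loss
    rescaled = summed / batch_size    # the normalization controlled by batch_size
    return weight - lr * rescaled


grads = [0.5, 1.5, 1.0, 1.0]  # made-up per-sample gradients, N = 4
w0, lr = 1.0, 0.1

w_raw = sgd_step(w0, grads, lr, batch_size=1)              # like step(1): no rescale
w_mean = sgd_step(w0, grads, lr, batch_size=4)             # like step(N): mean over batch
w_compensated = sgd_step(w0, grads, lr * 4, batch_size=4)  # step(N) with lr scaled by N

print(w_raw, w_mean, w_compensated)
```

With these numbers the step(1)-style update moves the weight four times as far as the step(N)-style update, and multiplying lr by N recovers the unscaled result.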
- update(batch_size, ignore_stale_grad=False)
Makes one step of parameter update.
Should be called after autograd.backward() and outside of record() scope, and after trainer.allreduce_grads().
For normal parameter updates, step() should be used, which internally calls allreduce_grads() and then update(). However, if you need the reduced gradients in order to apply a transformation to them, such as gradient clipping, you may want to call allreduce_grads() and update() separately.
- Parameters:
batch_size (int) – Batch size of data processed. Gradient will be normalized by 1/batch_size. Set this to 1 if you normalized loss manually with loss = mean(loss).
ignore_stale_grad (bool, optional, default=False) – If true, ignores Parameters with stale gradient (gradient that has not been updated by backward after last step) and skip update.