mxnet.optimizer.sgd
SGD optimizer
Classes
SGD: The SGD optimizer with momentum and weight decay.
- class mxnet.optimizer.sgd.SGD(learning_rate=0.1, momentum=0.0, lazy_update=False, multi_precision=False, use_fused_step=True, aggregate_num=1, **kwargs)
Bases: Optimizer
The SGD optimizer with momentum and weight decay.
If the storage type of grad is row_sparse and lazy_update is True, lazy updates are applied by:

    for row in grad.indices:
        rescaled_grad[row] = clip(rescale_grad * grad[row] + wd * weight[row], clip_gradient)
        state[row] = momentum * state[row] + lr * rescaled_grad[row]
        weight[row] = weight[row] - state[row]
The sparse update updates the momentum only for the weights whose row_sparse gradient indices appear in the current batch, rather than for all indices. Compared with the standard update, it can substantially improve training throughput for some applications. However, its semantics differ slightly from those of the standard update, and it may lead to different empirical results.
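To make the difference from the dense update concrete, here is a minimal pure-Python sketch of the lazy rule above. It is not MXNet's implementation: plain lists stand in for NDArrays, the row_sparse gradient is modeled as a dict, the helper names are hypothetical, and treating a clip_gradient of None (or a negative value) as "no clipping" is an assumption. Rows absent from the gradient keep their old momentum.

```python
# Sketch of the lazy (row_sparse) momentum update.
# Assumption: lists stand in for NDArrays; the row_sparse gradient is a
# dict {row_index: list_of_values}. Helper names are hypothetical.


def clip(x, bound):
    """Clip x to [-bound, bound]; a bound of None or < 0 disables clipping."""
    if bound is None or bound < 0:
        return x
    return max(-bound, min(bound, x))


def lazy_sgd_mom_update(weight, state, sparse_grad, lr, momentum, wd,
                        rescale_grad=1.0, clip_gradient=None):
    """Update, in place, only the rows present in sparse_grad."""
    for row, grad_row in sparse_grad.items():
        for j, g in enumerate(grad_row):
            # Weight decay is folded in before clipping, as in the lazy rule.
            rescaled = clip(rescale_grad * g + wd * weight[row][j], clip_gradient)
            state[row][j] = momentum * state[row][j] + lr * rescaled
            weight[row][j] -= state[row][j]


# Three rows, all with a non-zero momentum buffer; only row 1 has a gradient.
weight = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
state = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
lazy_sgd_mom_update(weight, state, {1: [1.0, 2.0]}, lr=0.1, momentum=0.9, wd=0.0)
print(weight[0], state[0])  # rows 0 and 2 are untouched: [1.0, 1.0] [0.5, 0.5]
print(weight[1])            # approximately [0.45, 0.35]
```

Under the dense rule, rows 0 and 2 would still have their momentum decayed to 0.45 and their weights moved to 0.55 despite a zero gradient; skipping that work is the source of both the speedup and the semantic difference.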
When update_on_kvstore is set to False (either globally via the MXNET_UPDATE_ON_KVSTORE=0 environment variable or as a parameter in Trainer), the SGD optimizer can perform aggregated updates of parameters, which may improve performance. The aggregation size is controlled by aggregate_num, which defaults to 1.

When lazy updates do not apply, standard updates are applied by:

    rescaled_grad = clip(rescale_grad * grad, clip_gradient) + wd * weight
    state = momentum * state + lr * rescaled_grad
    weight = weight - state

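A minimal pure-Python sketch of the standard (dense) rule, applied elementwise to lists, with a two-step worked example. The function name sgd_mom_step is hypothetical (MXNet performs this update in its own kernels on NDArrays), and treating a clip_gradient of None or a negative value as "no clipping" is an assumption. Note that with momentum set to 0, state equals lr * rescaled_grad, so the rule reduces to plain SGD.

```python
def clip(x, bound):
    """Clip x to [-bound, bound]; None or a negative bound disables clipping (assumption)."""
    if bound is None or bound < 0:
        return x
    return max(-bound, min(bound, x))


def sgd_mom_step(weight, state, grad, lr, momentum, wd,
                 rescale_grad=1.0, clip_gradient=None):
    """Return new (weight, state) lists after one dense momentum-SGD step."""
    new_weight, new_state = [], []
    for w, s, g in zip(weight, state, grad):
        # Clip the rescaled gradient first, then add weight decay.
        rescaled = clip(rescale_grad * g, clip_gradient) + wd * w
        s = momentum * s + lr * rescaled
        new_weight.append(w - s)
        new_state.append(s)
    return new_weight, new_state


# Worked example: w = 1.0, g = 0.5, lr = 0.1, momentum = 0.9, wd = 0.01.
# Step 1: rescaled = 0.5 + 0.01 * 1.0 = 0.51, state = 0.051, weight = 0.949.
w, s = sgd_mom_step([1.0], [0.0], [0.5], lr=0.1, momentum=0.9, wd=0.01)
# Step 2: rescaled = 0.5 + 0.01 * 0.949 = 0.50949,
# state = 0.9 * 0.051 + 0.1 * 0.50949 = 0.096849, weight = 0.852151.
w, s = sgd_mom_step(w, s, [0.5], lr=0.1, momentum=0.9, wd=0.01)
```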
For details of the update algorithm, see sgd_update and sgd_mom_update.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

- Parameters:
learning_rate (float, default 0.1) – The initial learning rate. If None, the optimizer uses the learning rate from lr_scheduler. If not None, it overwrites the learning rate in lr_scheduler. If both learning_rate and lr_scheduler are None, the learning rate is set to 0.01.
momentum (float, default 0.) – The momentum value.
lazy_update (bool, default False) – If True, lazy updates are applied when the storage types of weight and grad are both row_sparse.
multi_precision (bool, default False) – Controls the internal precision of the optimizer. If False (default), the optimizer uses the same precision as the weights. If True, it keeps an internal 32-bit copy of the weights and applies gradients in 32-bit precision, even if the weights used in the model have lower precision. Turning this on can improve convergence and accuracy when training with float16.
aggregate_num (int, default 1) – Number of weights aggregated into a list and passed to the optimizer together for a single optimization step.
use_fused_step (bool, default True) – Whether to use fused kernels for the optimizer. If use_fused_step=False, step is called; otherwise, fused_step is called.
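Why the 32-bit copy behind multi_precision matters: in float16 the spacing between representable values just below 1.0 is 2^-11 (about 4.9e-4), so any update smaller than half of that is rounded away. The sketch below emulates float16 and float32 rounding with Python's standard struct module (format codes 'e' and 'f') and uses plain SGD without momentum for brevity. It illustrates the idea only; it is not MXNet's implementation, and the helper names are hypothetical.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack('<e', struct.pack('<e', x))[0]


def to_fp32(x):
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack('<f', struct.pack('<f', x))[0]


def train(steps, lr, grad, multi_precision):
    """Apply `steps` plain SGD updates (no momentum) to a float16 weight."""
    weight16 = to_fp16(1.0)
    master32 = to_fp32(weight16)  # 32-bit copy, used only if multi_precision
    for _ in range(steps):
        if multi_precision:
            # Accumulate the update in float32, then cast back for the model.
            master32 = to_fp32(master32 - lr * grad)
            weight16 = to_fp16(master32)
        else:
            # Update directly in float16: a tiny step can round away entirely.
            weight16 = to_fp16(weight16 - to_fp16(lr * grad))
    return weight16


print(train(100, lr=0.01, grad=0.01, multi_precision=False))  # 1.0: every step lost
print(train(100, lr=0.01, grad=0.01, multi_precision=True))   # about 0.99
```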
- create_state(index, weight)
Creates auxiliary state for a given weight.
Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which is then used in the update. It is called only once for each weight.
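A sketch of the create-once contract, assuming that for SGD the state is a zero-initialized momentum buffer shaped like the weight and that no state is kept when momentum is 0. Lists stand in for NDArrays, and the multi_precision case (which would also keep a 32-bit weight copy) is omitted. The names create_momentum_state, get_state and the _states cache are hypothetical.

```python
def create_momentum_state(weight, momentum):
    """Return a zero momentum buffer shaped like weight, or None if momentum is 0."""
    if momentum == 0.0:
        return None  # plain SGD needs no auxiliary state
    return [0.0] * len(weight)


_states = {}  # hypothetical cache: parameter index -> state


def get_state(index, weight, momentum):
    """Create the state on first use only, then reuse it for every later update."""
    if index not in _states:
        _states[index] = create_momentum_state(weight, momentum)
    return _states[index]


buf = get_state(0, [1.0, 2.0, 3.0], momentum=0.9)
print(buf is get_state(0, [1.0, 2.0, 3.0], momentum=0.9))  # True: same object reused
```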
- fused_step(indices, weights, grads, states)
Perform a fused optimization step using gradients and states. A fused kernel is used for the update.
- Parameters:
indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.
weights (list of NDArray) – List of parameters to be updated.
grads (list of NDArray) – List of gradients of the objective with respect to these parameters.
states (list of any obj) – List of states returned by create_state().
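fused_step takes lists so that several parameters can be updated in one call, and the aggregate_num parameter bounds how many are grouped together. The sketch below shows one way a caller could split parameters into such groups. chunk and run_aggregated are hypothetical helpers, not part of MXNet, and step_fn stands in for fused_step.

```python
def chunk(seq, size):
    """Split seq into consecutive lists of at most `size` items."""
    if size < 1:
        raise ValueError("size must be at least 1")
    return [list(seq[i:i + size]) for i in range(0, len(seq), size)]


def run_aggregated(step_fn, indices, weights, grads, states, aggregate_num):
    """Call step_fn once per group of up to aggregate_num parameters.

    Returns the number of calls made.
    """
    groups = zip(chunk(indices, aggregate_num), chunk(weights, aggregate_num),
                 chunk(grads, aggregate_num), chunk(states, aggregate_num))
    calls = 0
    for idx, ws, gs, ss in groups:
        step_fn(idx, ws, gs, ss)
        calls += 1
    return calls


# Five parameters with aggregate_num=4 -> two calls: indices [0, 1, 2, 3] and [4].
seen = []
n_calls = run_aggregated(lambda i, w, g, s: seen.append(i),
                         list(range(5)), [None] * 5, [None] * 5, [None] * 5,
                         aggregate_num=4)
print(n_calls, seen)  # 2 [[0, 1, 2, 3], [4]]
```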
- step(indices, weights, grads, states)
Perform an optimization step using gradients and states.
- Parameters:
indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.
weights (list of NDArray) – List of parameters to be updated.
grads (list of NDArray) – List of gradients of the objective with respect to these parameters.
states (list of any obj) – List of states returned by create_state().
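For comparison, an unfused step can be read as a plain loop over the four lists. The sketch below assumes that each index selects per-parameter multipliers of a base learning rate and weight decay, modeled here as hypothetical lr_mult and wd_mult dicts standing in for what set_lr_mult() and set_wd_mult() configure. It omits rescale_grad and clipping for brevity and is illustrative, not MXNet's step.

```python
def step(indices, weights, grads, states, base_lr, base_wd, momentum,
         lr_mult=None, wd_mult=None):
    """Update each weight (a list of floats) in place, one parameter at a time."""
    lr_mult = lr_mult or {}
    wd_mult = wd_mult or {}
    for index, weight, grad, state in zip(indices, weights, grads, states):
        # Per-parameter hyperparameters looked up through the parameter index.
        lr = base_lr * lr_mult.get(index, 1.0)
        wd = base_wd * wd_mult.get(index, 1.0)
        for j, g in enumerate(grad):
            rescaled = g + wd * weight[j]
            if state is None:  # no momentum buffer: plain SGD
                weight[j] -= lr * rescaled
            else:
                state[j] = momentum * state[j] + lr * rescaled
                weight[j] -= state[j]


# Parameter 1 is frozen by giving it a learning-rate multiplier of 0.
w0, w1 = [1.0], [1.0]
step([0, 1], [w0, w1], [[0.5], [0.5]], [None, None],
     base_lr=0.1, base_wd=0.0, momentum=0.0, lr_mult={1: 0.0})
print(w0, w1)  # w0 is about [0.95]; w1 stays [1.0]
```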