mxnet.optimizer.adam

Adam optimizer.

Classes

Adam([learning_rate, beta1, beta2, epsilon, ...])

The Adam optimizer.

class mxnet.optimizer.adam.Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, lazy_update=False, use_fused_step=True, **kwargs)

Bases: Optimizer

The Adam optimizer.

This class implements the optimizer described in Adam: A Method for Stochastic Optimization, available at http://arxiv.org/abs/1412.6980.

If the storage type of grad is row_sparse and lazy_update is True, lazy updates at step t are applied as follows:

for row in grad.indices:
    rescaled_grad[row] = clip(grad[row] * rescale_grad, clip_gradient) + wd * weight[row]
    m[row] = beta1 * m[row] + (1 - beta1) * rescaled_grad[row]
    v[row] = beta2 * v[row] + (1 - beta2) * (rescaled_grad[row]**2)
    lr = learning_rate * sqrt(1 - beta2**t) / (1 - beta1**t)
    w[row] = w[row] - lr * m[row] / (sqrt(v[row]) + epsilon)

The lazy update only updates the mean and var (m and v) of the weight rows whose indices appear in the current batch's row_sparse gradient, rather than updating them for all rows. Compared with the standard update, this can substantially improve training throughput for some applications. However, its semantics differ slightly from the standard update: rows absent from the batch are left entirely unchanged (no decay of m and v, no weight decay, no step), so it may lead to different empirical results.
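To make the lazy semantics concrete, here is a minimal pure-Python sketch of the row-wise rule above. It assumes the weight and moment buffers are lists of rows and represents the row_sparse gradient as a dict from row index to gradient row; this is a stand-in for NDArray storage, not MXNet's API. clip(x, clip_gradient) is interpreted as clipping into [-clip_gradient, clip_gradient], and lazy_adam_update is an illustrative name.

```python
import math

def lazy_adam_update(weight, grad_rows, m, v, t, learning_rate=0.001,
                     beta1=0.9, beta2=0.999, epsilon=1e-8,
                     rescale_grad=1.0, clip_gradient=None, wd=0.0):
    """Lazy Adam step on a row-sparse gradient (illustrative sketch).

    weight, m, v: lists of rows, each row a list of floats, updated in place.
    grad_rows: dict mapping row index -> gradient row. Rows missing from the
    dict are treated as absent from the batch and are not touched at all.
    """
    # Bias-corrected step size for step t (t starts at 1).
    lr = learning_rate * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    for row, g_row in grad_rows.items():
        for j, g in enumerate(g_row):
            g = g * rescale_grad
            if clip_gradient is not None:
                g = max(-clip_gradient, min(clip_gradient, g))
            g += wd * weight[row][j]
            m[row][j] = beta1 * m[row][j] + (1 - beta1) * g
            v[row][j] = beta2 * v[row][j] + (1 - beta2) * g * g
            weight[row][j] -= lr * m[row][j] / (math.sqrt(v[row][j]) + epsilon)

weight = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
m = [[0.0, 0.0] for _ in weight]
v = [[0.0, 0.0] for _ in weight]
# Only row 1 appears in this batch's gradient.
lazy_adam_update(weight, {1: [0.5, -0.5]}, m, v, t=1)
print(weight[0], weight[2])  # → [1.0, 1.0] [3.0, 3.0]
```

By contrast, the standard update would still decay m and v for rows 0 and 2 and, with wd > 0, apply weight decay to them as well.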

Otherwise, standard updates at step t are applied as follows:

rescaled_grad = clip(grad * rescale_grad, clip_gradient) + wd * weight
m = beta1 * m + (1 - beta1) * rescaled_grad
v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
lr = learning_rate * sqrt(1 - beta2**t) / (1 - beta1**t)
w = w - lr * m / (sqrt(v) + epsilon)
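The standard update can be sketched in pure Python as follows. The sketch assumes flat lists of floats in place of NDArray and interprets clip(x, clip_gradient) as clipping into [-clip_gradient, clip_gradient]. adam_step is an illustrative name, not an MXNet function.

```python
import math

def adam_step(weight, grad, m, v, t, learning_rate=0.001, beta1=0.9,
              beta2=0.999, epsilon=1e-8, rescale_grad=1.0,
              clip_gradient=None, wd=0.0):
    """One dense Adam step on flat lists of floats, updated in place."""
    # Bias-corrected step size for step t (t starts at 1).
    lr = learning_rate * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    for i, g in enumerate(grad):
        g = g * rescale_grad
        if clip_gradient is not None:
            g = max(-clip_gradient, min(clip_gradient, g))
        g += wd * weight[i]                        # weight decay folded into the gradient
        m[i] = beta1 * m[i] + (1 - beta1) * g      # first-moment estimate
        v[i] = beta2 * v[i] + (1 - beta2) * g * g  # second-moment estimate
        weight[i] -= lr * m[i] / (math.sqrt(v[i]) + epsilon)

weight, m, v = [1.0, -1.0], [0.0, 0.0], [0.0, 0.0]
adam_step(weight, [0.2, -3.0], m, v, t=1)
print(weight)  # each entry moves by about learning_rate against its gradient's sign
```

Because of the bias correction in lr, the first step moves each weight by roughly learning_rate in the direction opposite its gradient, largely independent of the gradient's magnitude: here both 0.2 and -3.0 produce a step of about 0.001.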

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

For details of the update algorithm, see adam_update.

Parameters:
  • learning_rate (float, default 0.001) – The initial learning rate. If None, the optimizer uses the learning rate from lr_scheduler. If not None, it overrides the base learning rate of lr_scheduler. If both learning_rate and lr_scheduler are None, the learning rate defaults to 0.01.

  • beta1 (float, default 0.9) – Exponential decay rate for the first moment estimates.

  • beta2 (float, default 0.999) – Exponential decay rate for the second moment estimates.

  • epsilon (float, default 1e-8) – Small value added to the denominator to avoid division by zero.

  • lazy_update (bool, default False) – If True, lazy updates are applied when the storage types of weight and grad are both row_sparse.

  • use_fused_step (bool, default True) – Whether to use fused kernels for the optimizer update. When use_fused_step=False, step is called; otherwise, fused_step is called.
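The learning_rate rule in the parameter list above distinguishes three cases. The sketch below restates them as a plain function; resolve_learning_rate and scheduler_base_lr are hypothetical names used only for illustration and are not part of MXNet. Note that the 0.01 fallback applies only when learning_rate=None is passed explicitly; if the argument is omitted, Adam's own default of 0.001 is used.

```python
def resolve_learning_rate(learning_rate, scheduler_base_lr=None):
    """Return the effective base learning rate under the rule described above.

    Hypothetical helper: scheduler_base_lr stands in for an lr_scheduler's
    base rate, and None means no scheduler was given.
    """
    if learning_rate is None:
        # Defer to the scheduler, or fall back to 0.01 when there is none.
        return scheduler_base_lr if scheduler_base_lr is not None else 0.01
    # An explicit learning_rate overrides the scheduler's base rate.
    return learning_rate

print(resolve_learning_rate(None))        # 0.01
print(resolve_learning_rate(None, 0.1))   # 0.1
print(resolve_learning_rate(0.001, 0.1))  # 0.001
```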

create_state(index, weight)

Creates auxiliary state for a given weight.

Some optimizers need additional state, such as momentum, besides the gradients in order to update weights. This function creates the state for a given weight, which is then used by the update. It is called only once for each weight.

Parameters:
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns:

state – The state associated with the weight.

Return type:

any obj
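For Adam, the state naturally holds the running estimates m and v from the update rule in the class description, one value per weight element, both starting at zero. The sketch below assumes that (mean, var) layout and uses flat Python lists in place of NDArray; get_state and its _states cache are hypothetical stand-ins for the caller that invokes create_state() only once per weight.

```python
def create_state(index, weight):
    """Return zero-initialised (mean, var) buffers shaped like a flat weight list.

    index is unused in this sketch; a real optimizer may use it for bookkeeping.
    """
    return [0.0] * len(weight), [0.0] * len(weight)

_states = {}  # hypothetical per-index cache held by the caller

def get_state(index, weight):
    """Create the state on first use only, then reuse it for later steps."""
    if index not in _states:
        _states[index] = create_state(index, weight)
    return _states[index]

first = get_state(0, [0.5, -0.5, 1.0])
again = get_state(0, [0.5, -0.5, 1.0])
print(first is again)  # True: state is created once per weight
```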

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states. A fused kernel is used for the update.

Parameters:
  • indices (list of int) – List of unique parameter indices, used to look up each parameter's individual learning rate and weight decay. Per-parameter learning-rate and weight-decay multipliers may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to the corresponding parameters.

  • states (List of any obj) – List of states returned by create_state().
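Conceptually, "fused" means every stage of the update (weight decay, moment updates, weight step) is applied to each element in a single pass, instead of one full pass and one temporary buffer per stage. The pure-Python sketch below contrasts the two styles and checks that they agree numerically. It is an analogy only: MXNet's fused path calls a native kernel, which this sketch does not reproduce. unfused_step and fused_step_sketch are hypothetical names, and lr is assumed to be the already bias-corrected rate from the update rule in the class description.

```python
import math

B1, B2, EPS = 0.9, 0.999, 1e-8

def unfused_step(w, g, m, v, lr, wd):
    """Stage-by-stage update: each stage makes a full pass and builds a temporary list."""
    g = [gi + wd * wi for gi, wi in zip(g, w)]
    m[:] = [B1 * mi + (1 - B1) * gi for mi, gi in zip(m, g)]
    v[:] = [B2 * vi + (1 - B2) * gi * gi for vi, gi in zip(v, g)]
    d = [mi / (math.sqrt(vi) + EPS) for mi, vi in zip(m, v)]
    w[:] = [wi - lr * di for wi, di in zip(w, d)]

def fused_step_sketch(w, g, m, v, lr, wd):
    """Single pass: all stages are applied per element, with no temporary lists."""
    for i in range(len(w)):
        gi = g[i] + wd * w[i]
        m[i] = B1 * m[i] + (1 - B1) * gi
        v[i] = B2 * v[i] + (1 - B2) * gi * gi
        w[i] -= lr * (m[i] / (math.sqrt(v[i]) + EPS))

w1, m1, v1 = [0.5, -1.0], [0.0, 0.0], [0.0, 0.0]
w2, m2, v2 = [0.5, -1.0], [0.0, 0.0], [0.0, 0.0]
for _ in range(3):
    unfused_step(w1, [0.3, -0.2], m1, v1, lr=0.001, wd=0.01)
    fused_step_sketch(w2, [0.3, -0.2], m2, v2, lr=0.001, wd=0.01)
print(w1, w2)  # the two styles produce matching weights
```

The benefit of fusion is fewer passes over memory and fewer temporaries, not a different result.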

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

Parameters:
  • indices (list of int) – List of unique parameter indices, used to look up each parameter's individual learning rate and weight decay. Per-parameter learning-rate and weight-decay multipliers may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to the corresponding parameters.

  • states (List of any obj) – List of states returned by create_state().
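The indices argument is what ties each parameter to its own learning rate, weight decay, and step count. The sketch below is a hypothetical pure-Python class, MiniAdam, in which the dicts lr_mult and wd_mult stand in for what set_lr_mult() and set_wd_mult() configure. It also assumes that t is counted separately for each index. It is not MXNet's implementation and omits rescale_grad and clip_gradient for brevity.

```python
import math

class MiniAdam:
    """Hypothetical stand-in for the step() contract (not MXNet's Adam class)."""

    def __init__(self, learning_rate=0.001, wd=0.0, beta1=0.9, beta2=0.999, epsilon=1e-8):
        self.lr, self.wd = learning_rate, wd
        self.beta1, self.beta2, self.epsilon = beta1, beta2, epsilon
        self.lr_mult, self.wd_mult = {}, {}  # per-index multipliers, default 1.0
        self.counts = {}                     # per-index step count t

    def create_state(self, index, weight):
        return [0.0] * len(weight), [0.0] * len(weight)  # (m, v)

    def step(self, indices, weights, grads, states):
        for index, weight, grad, (m, v) in zip(indices, weights, grads, states):
            self.counts[index] = self.counts.get(index, 0) + 1
            t = self.counts[index]
            lr = self.lr * self.lr_mult.get(index, 1.0)
            wd = self.wd * self.wd_mult.get(index, 1.0)
            lr *= math.sqrt(1 - self.beta2 ** t) / (1 - self.beta1 ** t)
            for i, g in enumerate(grad):
                g += wd * weight[i]
                m[i] = self.beta1 * m[i] + (1 - self.beta1) * g
                v[i] = self.beta2 * v[i] + (1 - self.beta2) * g * g
                weight[i] -= lr * m[i] / (math.sqrt(v[i]) + self.epsilon)

opt = MiniAdam(learning_rate=0.01)
opt.lr_mult[1] = 0.0  # freeze parameter 1 with a zero learning-rate multiplier
weights = [[1.0, 2.0], [3.0]]
states = [opt.create_state(i, w) for i, w in enumerate(weights)]
opt.step([0, 1], weights, [[0.1, -0.1], [5.0]], states)
print(weights)  # parameter 0 moves by about 0.01 per entry; parameter 1 stays [3.0]
```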