
API Documentation

singd.optim.optimizer.SINGD

SINGD(model: Module, params: Union[None, Iterable[Parameter], List[Dict[str, Any]]] = None, lr: float = 0.001, momentum: float = 0.9, damping: float = 0.001, alpha1: float = 0.5, weight_decay: float = 0.0, T: int = 10, loss_average: Union[None, str] = 'batch', lr_cov: Union[float, Callable[[int], float]] = 0.01, structures: Tuple[str, str] = ('dense', 'dense'), kfac_approx: str = 'expand', warn_unsupported: bool = True, kfac_like: bool = False, preconditioner_dtype: Tuple[Union[dtype, None], Union[dtype, None]] = (None, None), init_grad_scale: float = 1.0, normalize_lr_cov: bool = False)

Bases: Optimizer

Structured inverse-free natural gradient descent.

The algorithm is introduced in the SINGD paper (arXiv:2312.05705) and extends the inverse-free KFAC algorithm of Lin et al. (ICML 2023) with structured pre-conditioner matrices.

Note

Uses the empirical Fisher.

Note

(Implementation concept) The optimizer installs a single forward hook on known modules. During a forward pass, this hook installs a tensor hook on the layer's output which computes the quantities required for the pre-conditioner. During .step(), these quantities are flushed to update the pre-conditioner, compute the approximate natural gradient, and update the network parameters.
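
The quantities needed by the optimizer are therefore gathered during ordinary forward and backward passes; no extra calls are required beyond a standard training loop. Below is a minimal, hedged usage sketch (the toy model, random data, and hyperparameter values are illustrative only, not recommendations):

    import torch
    from torch import nn

    from singd.optim.optimizer import SINGD

    # toy setup; any combination of supported layers works
    model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
    loss_func = nn.CrossEntropyLoss()  # a mean over the batch, hence loss_average="batch"
    optimizer = SINGD(model, lr=1e-3, loss_average="batch")

    inputs = torch.rand(32, 1, 28, 28)
    targets = torch.randint(0, 10, (32,))

    for _ in range(5):  # toy loop on a fixed batch
        optimizer.zero_grad()
        loss = loss_func(model(inputs), targets)  # forward hooks fire here
        loss.backward()  # tensor hooks accumulate pre-conditioner quantities
        optimizer.step()  # flush them, update pre-conditioner and parameters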

Attributes:

  • SUPPORTED_STRUCTURES (Dict[str, Type[StructuredMatrix]]) –

    A string-to-class mapping of supported structures.

  • SUPPORTED_MODULES (Tuple[Type[Module], ...]) –

    Supported layers.

  • STATE_ATTRIBUTES (List[str]) –

    Attributes that belong to the optimizer's state but are not stored inside the self.state attribute. They will be saved and restored when the optimizer is checkpointed by calling .state_dict() and .load_state_dict() (see the checkpointing sketch after this list).

  • SUPPORTED_LOSS_AVERAGE (Tuple[Union[None, str], ...]) –

    Supported loss averaging schemes.

  • _step_supports_amp_scaling

    Indicates that .step() handles gradient scaling internally when the optimizer is used together with a torch.cuda.amp.GradScaler. Before calling this class's .step(), the gradient scaler stores the current gradient scale inside .grad_scale and whether infs occurred in the gradients inside .found_inf. For details, see the implementation of torch.cuda.amp.GradScaler.step (an AMP usage sketch follows this list).
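
Because of this flag, SINGD composes with automatic mixed precision in the standard GradScaler pattern. A hedged sketch assuming a CUDA device; setting init_grad_scale to 2**16 mirrors the scaler's default initial scale, as suggested by the init_grad_scale parameter below:

    import torch
    from torch import nn

    from singd.optim.optimizer import SINGD

    device = "cuda"
    model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)).to(device)
    loss_func = nn.CrossEntropyLoss()

    scaler = torch.cuda.amp.GradScaler()  # default init_scale is 2**16
    optimizer = SINGD(model, init_grad_scale=2**16)  # same magnitude as the scaler's scale

    inputs = torch.rand(32, 1, 28, 28, device=device)
    targets = torch.randint(0, 10, (32,), device=device)

    for _ in range(5):
        optimizer.zero_grad()
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = loss_func(model(inputs), targets)
        scaler.scale(loss).backward()  # hooks see the scaled gradients
        scaler.step(optimizer)  # hands grad_scale and found_inf to SINGD's .step()
        scaler.update()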
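
Similarly, since the pre-conditioner matrices, their momenta, and the accumulated quantities travel with the optimizer state, checkpointing follows the usual PyTorch pattern; a minimal sketch (the file name is a placeholder):

    import torch
    from torch import nn

    from singd.optim.optimizer import SINGD

    model = nn.Linear(784, 10)
    optimizer = SINGD(model)

    # ... train for a while ...

    # save model and optimizer state (the latter includes the pre-conditioner)
    torch.save(
        {"model": model.state_dict(), "optim": optimizer.state_dict()},
        "checkpoint.pt",  # placeholder path
    )

    # restore into freshly created objects
    model = nn.Linear(784, 10)
    optimizer = SINGD(model)
    checkpoint = torch.load("checkpoint.pt")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optim"])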

Structured inverse-free natural gradient descent optimizer.

Uses the empirical Fisher. See the paper (arXiv:2312.05705) for the notation.

Parameters:

  • model (Module) –

    The neural network whose parameters (or a subset thereof) will be trained.

  • params (Union[None, Iterable[Parameter], List[Dict[str, Any]]], default: None ) –

    Used to specify the trainable parameters or parameter groups. If unspecified, all parameters of model which are supported by the optimizer will be trained. If a list of Parameters is passed, only these parameters will be trained. If a list of dictionaries is passed, these will be used as parameter groups.

  • lr (float, default: 0.001 ) –

    (\(\beta_2\) in the paper) Learning rate for the parameter updates. Default: 0.001.

  • momentum (float, default: 0.9 ) –

    (\(\alpha_2\) in the paper) Momentum on the parameter updates. Default: 0.9.

  • damping (float, default: 0.001 ) –

    (\(\lambda\) in the paper) Damping strength used in the updates of the pre-conditioner momenta \(\mathbf{m}_\mathbf{K}\) and \(\mathbf{m}_\mathbf{C}\). Default: 0.001.

  • alpha1 (float, default: 0.5 ) –

    (\(\alpha_1\) in the paper) Momentum used in the updates of the pre-conditioner momenta \(\mathbf{m}_\mathbf{K}\) and \(\mathbf{m}_\mathbf{C}\). Default: 0.5.

  • weight_decay (float, default: 0.0 ) –

    (\(\gamma\) in the paper) Weight decay on the parameters. Default: 0.0.

  • T (int, default: 10 ) –

    Pre-conditioner update frequency. Default: 10.

  • loss_average (Union[None, str], default: 'batch' ) –

    Whether the loss function is a mean over per-sample losses and if yes, over which dimensions the mean is taken. If "batch", the loss function is a mean over as many terms as the size of the mini-batch. If "batch+sequence", the loss function is a mean over as many terms as the size of the mini-batch times the sequence length, e.g. in the case of language modeling. If None, the loss function is a sum. This argument is used to ensure that the preconditioner is scaled consistently with the loss and the gradient. Default: "batch".

  • lr_cov (Union[float, Callable[[int], float]], default: 0.01 ) –

    (\(\beta_1\) in the paper) Learning rate for the updates of the pre-conditioner momenta \(\mathbf{m}_\mathbf{K}\) and \(\mathbf{m}_\mathbf{C}\). Default: 1e-2. Also accepts a callable which takes the current step and returns the value of lr_cov to use at that step. Using too large a value during the first few steps can lead to instabilities because the pre-conditioner is still warming up. In that case, try a schedule which gradually ramps up lr_cov, or use a constant value and turn on normalize_lr_cov, which caps the effective learning rate at lr_cov throughout training (see the configuration sketch after the parameter list).

  • structures (Tuple[str, str], default: ('dense', 'dense') ) –

    A 2-tuple of strings specifying the structure of the pre-conditioner matrices \(\mathbf{K}, \mathbf{C}\) and their momenta \(\mathbf{m}_\mathbf{K}, \mathbf{m}_\mathbf{C}\). Possible values for each entry are 'dense', 'diagonal', 'block30diagonal', 'hierarchical15_15', 'triltoeplitz', and 'triutoeplitz'. Default is ('dense', 'dense').

  • kfac_approx (str, default: 'expand' ) –

    A string specifying the KFAC approximation that should be used for linear weight-sharing layers, e.g. Conv2d modules or Linear modules that process matrix- or higher-dimensional features. Possible values are 'expand' and 'reduce'. See Eschenhagen et al., 2023 (arXiv:2311.00636) for an explanation of the two approximations.

  • warn_unsupported (bool, default: True ) –

    Only relevant if params is unspecified. Whether to warn if model contains parameters of layers that are not supported. These parameters will not be trained by the optimizer. Default: True.

  • kfac_like (bool, default: False ) –

    Whether to use the modified update rule which results in an update close to the KFAC optimizer (IKFAC). Default: False. Please see the theorem in the paper for more details.

  • preconditioner_dtype (Tuple[Union[dtype, None], Union[dtype, None]], default: (None, None) ) –

    Data types used to store the structured pre-conditioner matrices \(\mathbf{K}, \mathbf{C}\) and their momenta \(\mathbf{m}_\mathbf{K}, \mathbf{m}_\mathbf{C}\). If None, will use the same data type as the parameter for both pre-conditioner matrices and momenta. If (float32, None), will use float32 for \(\mathbf{K}, \mathbf{m}_\mathbf{K}\) and the same data type as the weight for \(\mathbf{C}, \mathbf{m}_\mathbf{C}\). Default: (None, None).

  • init_grad_scale (float, default: 1.0 ) –

    Only relevant if using a torch.amp.GradScaler. Initial gradient scale of the scaler or a number of similar magnitude. If unspecified, the optimizer will still work correctly but the pre-conditioner computation in the first backpropagation might be numerically unstable. Default: 1.0.

  • normalize_lr_cov (bool, default: False ) –

    Use normalized gradient descent to update the pre-conditioner factors. Enabling this is a good alternative to scheduling lr_cov, as we found it improves SINGD's stability in the early phase when the pre-conditioners are still warming up. Default: False. Requires an additional matrix norm computation which is used to adapt lr_cov. (Details: To update the pre-conditioner, SINGD performs Riemannian gradient descent (RGD) on the pre-conditioner factors. Since it uses Riemannian normal coordinates, RGD reduces to GD, which allows applying the idea of normalized gradient descent.)
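
To illustrate how the stability-related options above fit together, here is a hedged configuration sketch; the 100-step ramp-up, the structure choice ('diagonal', 'dense'), and the float32 pre-conditioner dtype are illustrative assumptions, not recommendations:

    import torch
    from torch import nn

    from singd.optim.optimizer import SINGD

    model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

    def lr_cov_schedule(step: int) -> float:
        """Ramp lr_cov linearly from 1e-4 to 1e-2 over the first 100 steps (illustrative)."""
        return min(1e-2, 1e-4 + step * (1e-2 - 1e-4) / 100)

    optimizer = SINGD(
        model,
        lr=1e-3,
        lr_cov=lr_cov_schedule,  # callable: current step -> pre-conditioner learning rate
        structures=("diagonal", "dense"),  # diagonal K, dense C
        preconditioner_dtype=(torch.float32, torch.float32),  # store K, C and momenta in float32
        T=10,  # update the pre-conditioner every 10 steps
    )

    # Alternative to the schedule: keep lr_cov constant and enable normalized
    # gradient descent, which caps the effective learning rate at lr_cov.
    # optimizer = SINGD(model, lr_cov=1e-2, normalize_lr_cov=True)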

Raises:

  • TypeError

    If DataParallel or DistributedDataParallel model wrappers are used.

  • ValueError

    If any of the learning rate and momentum parameters (lr, lr_cov, alpha1, momentum, weight_decay) is negative.

Source code in singd/optim/optimizer.py
    def __init__(
        self,
        model: Module,
        params: Union[None, Iterable[Parameter], List[Dict[str, Any]]] = None,
        lr: float = 0.001,  # β₂ in the paper
        momentum: float = 0.9,  # α₂ in the paper
        damping: float = 0.001,  # λ in the paper
        alpha1: float = 0.5,  # α₁ in the paper
        weight_decay: float = 0.0,  # γ in the paper
        T: int = 10,  # T in the paper
        loss_average: Union[None, str] = "batch",
        lr_cov: Union[float, Callable[[int], float]] = 1e-2,  # β₁ in the paper
        structures: Tuple[str, str] = ("dense", "dense"),
        kfac_approx: str = "expand",
        warn_unsupported: bool = True,
        kfac_like: bool = False,
        preconditioner_dtype: Tuple[Union[dtype, None], Union[dtype, None]] = (
            None,
            None,
        ),
        init_grad_scale: float = 1.0,
        normalize_lr_cov: bool = False,
    ):  # noqa: D301
        """Structured inverse-free natural gradient descent optimizer.

        Uses the empirical Fisher. See the [paper](http://arxiv.org/abs/2312.05705) for
        the notation.

        Args:
            model: The neural network whose parameters (or a subset thereof) will be
                trained.
            params: Used to specify the trainable parameters or parameter groups.
                If unspecified, all parameters of `model` which are supported by the
                optimizer will be trained. If a list of `Parameters` is passed,
                only these parameters will be trained. If a list of dictionaries is
                passed, these will be used as [parameter groups](\
https://pytorch.org/docs/stable/optim.html#per-parameter-options).
            lr: (\\(\\beta_2\\) in the paper) Learning rate for the parameter updates.
                Default: `0.001`.
            momentum: (\\(\\alpha_2\\) in the paper) Momentum on the parameter updates.
                Default: `0.9`.
            damping: (\\(\\lambda\\) in the paper) Damping strength used in the updates
                of the pre-conditioner momenta \\(\\mathbf{m}_\\mathbf{K}\\) and
                \\(\\mathbf{m}_\\mathbf{C}\\). Default: `0.001`.
            alpha1: (\\(\\alpha_1\\) in the paper) Momentum used in the updates
                of the pre-conditioner momenta \\(\\mathbf{m}_\\mathbf{K}\\) and
                \\(\\mathbf{m}_\\mathbf{C}\\). Default: `0.5`.
            weight_decay: (\\(\\gamma\\) in the paper) Weight decay on the parameters.
                Default: `0.0`.
            T: Pre-conditioner update frequency. Default: `10`.
            loss_average: Whether the loss function is a mean over per-sample
                losses and if yes, over which dimensions the mean is taken.
                If `"batch"`, the loss function is a mean over as many terms as
                the size of the mini-batch. If `"batch+sequence"`, the loss
                function is a mean over as many terms as the size of the
                mini-batch times the sequence length, e.g. in the case of
                language modeling. If `None`, the loss function is a sum. This
                argument is used to ensure that the preconditioner is scaled
                consistently with the loss and the gradient. Default: `"batch"`.
            lr_cov: (\\(\\beta_1\\) in the paper) Learning rate for the updates of the
                pre-conditioner momenta \\(\\mathbf{m}_\\mathbf{K}\\) and
                \\(\\mathbf{m}_\\mathbf{C}\\). Default: `1e-2`. Also accepts a callable
                which takes the current step and returns the value of `lr_cov` to use
                at that step. Using too large a value during the first few steps can
                lead to instabilities because the pre-conditioner is still warming up.
                In that case, try a schedule which gradually ramps up `lr_cov`, or use
                a constant value and turn on `normalize_lr_cov`, which caps the
                effective learning rate at `lr_cov` throughout training.
            structures: A 2-tuple of strings specifying the structure of the
                pre-conditioner matrices \\(\\mathbf{K}, \\mathbf{C}\\) and their
                momenta \\(\\mathbf{m}_\\mathbf{K}, \\mathbf{m}_\\mathbf{C}\\).
                Possible values for each entry are `'dense'`, `'diagonal'`,
                `'block30diagonal'`, `'hierarchical15_15'`, `'triltoeplitz'`, and
                `'triutoeplitz'`. Default is (`'dense'`, `'dense'`).
            kfac_approx: A string specifying the KFAC approximation that should
                be used for linear weight-sharing layers, e.g. `Conv2d` modules
                or `Linear` modules that process matrix- or higher-dimensional
                features.
                Possible values are `'expand'` and `'reduce'`.
                See [Eschenhagen et al., 2023](https://arxiv.org/abs/2311.00636)
                for an explanation of the two approximations.
            warn_unsupported: Only relevant if `params` is unspecified. Whether to
                warn if `model` contains parameters of layers that are not supported.
                These parameters will not be trained by the optimizer. Default: `True`.
            kfac_like: Whether to use the modified update rule which results in an
                update close to the KFAC optimizer (IKFAC). Default: `False`.
                Please see the theorem in the paper for more details.
            preconditioner_dtype: Data types used to store the structured
                pre-conditioner matrices \\(\\mathbf{K}, \\mathbf{C}\\) and their
                momenta \\(\\mathbf{m}_\\mathbf{K}, \\mathbf{m}_\\mathbf{C}\\).
                If `None`, will use the same data type as the parameter for both
                pre-conditioner matrices and momenta. If `(float32, None)`, will use
                `float32` for \\(\\mathbf{K}, \\mathbf{m}_\\mathbf{K}\\) and the same
                data type as the weight for \\(\\mathbf{C}, \\mathbf{m}_\\mathbf{C}\\).
                Default: `(None, None)`.
            init_grad_scale: Only relevant if using a [`torch.amp.GradScaler`](\
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler). Initial gradient
                scale of the scaler or a number of similar magnitude. If unspecified,
                the optimizer will still work correctly but the pre-conditioner
                computation in the first backpropagation might be numerically unstable.
                Default: `1.0`.
            normalize_lr_cov: Use [normalized gradient descent](\
https://arxiv.org/abs/1711.05224) to update the pre-conditioner factors. Enabling this
is a good alternative to scheduling `lr_cov`, as we found it improves
                SINGD's stability in the early phase when the pre-conditioners are
                still warming up. Default: `False`. Requires an additional matrix norm
                computation which is used to adapt `lr_cov`.
                (Details: To update the pre-conditioner, SINGD performs Riemannian
                gradient descent (RGD) on the pre-conditioner factors. Since it uses
                Riemannian normal coordinates, RGD reduces to GD, which allows applying
                the idea of normalized gradient descent.)

        Raises:
            TypeError: If `DataParallel` or `DistributedDataParallel` model wrappers
                are used.
            ValueError: If any of the learning rate and momentum parameters
                (`lr, lr_cov, alpha1, momentum, weight_decay`) is negative.
        """
        if isinstance(model, (DP, DDP)):
            raise TypeError(
                "DataParallel and DistributedDataParallel wrappers are not supported. "
                "Use the normal DDP setup without the wrapper for distributed training."
            )

        for x, name in [
            (lr, "lr"),
            (lr_cov, "lr_cov"),
            (alpha1, "alpha1"),
            (momentum, "momentum"),
            (weight_decay, "weight_decay"),
        ]:
            if isinstance(x, float) and x < 0.0:
                raise ValueError(f"{name} must be positive. Got {x}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            damping=damping,
            alpha1=alpha1,
            weight_decay=weight_decay,
            T=T,
            loss_average=loss_average,
            lr_cov=lr_cov,
            structures=structures,
            kfac_approx=kfac_approx,
            kfac_like=kfac_like,
            preconditioner_dtype=preconditioner_dtype,
            normalize_lr_cov=normalize_lr_cov,
        )
        if params is None:
            params = self._get_trainable_parameters(
                model, warn_unsupported=warn_unsupported
            )
        super().__init__(params, defaults)
        self.steps = 0

        # for mapping modules to their groups
        self.param_to_group_idx = self._check_param_groups(model)
        # layers whose parameters will be updated
        self.module_names, self.hook_handles = self._install_hooks(model)

        # NOTE We use the module names (strings) as keys as they don't change when a
        # model is loaded from a checkpoint (unlike the module objects themselves).

        # store matrices for the pre-conditioner
        self.Ks: Dict[str, StructuredMatrix] = {}
        self.Cs: Dict[str, StructuredMatrix] = {}

        # store momentum terms for the pre-conditioner matrices
        self.m_Ks: Dict[str, StructuredMatrix] = {}
        self.m_Cs: Dict[str, StructuredMatrix] = {}

        # store accumulated H_Ks and H_Cs from one/multiple backward passes
        self.H_Ks: Dict[str, StructuredMatrix] = {}
        self.H_Cs: Dict[str, StructuredMatrix] = {}

        self._initialize_buffers()

        # Book-keeping of `grad_scale`s. We need to keep track of scales of two
        # consecutive steps because we do not have access to the scale at step `t`
        # during its backpropagation, but only when updating the pre-conditioner. Our
        # solution is to un-scale the gradient with the scale from step `t-1` in the
        # backward hook computations for step `t`, then undo and use the scale from
        # step `t` in the pre-conditioner update.
        self._grad_scales: Dict[int, float] = {-1: init_grad_scale}