API Documentation
singd.optim.optimizer.SINGD
SINGD(model: Module, params: Union[None, Iterable[Parameter], List[Dict[str, Any]]] = None, lr: float = 0.001, momentum: float = 0.9, damping: float = 0.001, alpha1: float = 0.5, weight_decay: float = 0.0, T: int = 10, loss_average: Union[None, str] = 'batch', lr_cov: Union[float, Callable[[int], float]] = 0.01, structures: Tuple[str, str] = ('dense', 'dense'), kfac_approx: str = 'expand', warn_unsupported: bool = True, kfac_like: bool = False, preconditioner_dtype: Tuple[Union[dtype, None], Union[dtype, None]] = (None, None), init_grad_scale: float = 1.0, normalize_lr_cov: bool = False)
Bases: Optimizer
Structured inverse-free natural gradient descent.
The algorithm is introduced in this paper and extends the inverse-free KFAC algorithm from Lin et al. (ICML 2023) with structured pre-conditioner matrices.
Note
Uses the empirical Fisher.
Note
(Implementation concept) The optimizer installs a single forward hook on known modules. During a forward pass, this hook installs a tensor hook on the layer's output which computes the quantities required for the pre-conditioner. During `.step()`, these quantities are flushed to update the pre-conditioner, compute the approximate natural gradient, and update the network parameters.
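As a quick orientation, here is a minimal usage sketch (the toy model, data, and loss below are placeholder assumptions, not part of the library); the hooks described above do their work during the forward and backward passes, and `.step()` consumes the collected quantities:

```python
# Minimal, hypothetical usage sketch: model, data, and loss are placeholders.
import torch
from torch import nn

from singd.optim.optimizer import SINGD

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
loss_func = nn.CrossEntropyLoss()  # a mean over the mini-batch -> loss_average="batch"
optimizer = SINGD(model, lr=1e-3)  # trains all supported parameters of `model`

X, y = torch.rand(8, 10), torch.randint(0, 2, (8,))

for _ in range(5):
    optimizer.zero_grad()
    loss = loss_func(model(X), y)
    loss.backward()  # the installed hooks collect the pre-conditioner quantities
    optimizer.step()  # flush them, update the pre-conditioner, then the parameters
```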
Attributes:
- `SUPPORTED_STRUCTURES` (`Dict[str, Type[StructuredMatrix]]`) – A string-to-class mapping of supported structures.
- `SUPPORTED_MODULES` (`Tuple[Type[Module], ...]`) – Supported layers.
- `STATE_ATTRIBUTES` (`List[str]`) – Attributes that belong to the optimizer's state but are not stored inside the `self.state` attribute. They will be saved and restored when the optimizer is check-pointed (by calling `.state_dict()` and `.load_state_dict()`).
- `SUPPORTED_LOSS_AVERAGE` (`Tuple[Union[None, str], ...]`) – Supported loss averaging schemes.
- `_step_supports_amp_scaling` – Indicates that `step` handles gradient scaling internally if the optimizer is used together with a `torch.cuda.amp.GradScaler`. Before calling this class's `.step()`, the gradient scaler will store the current gradient scale inside `.grad_scale`, and whether `inf`s occur in the gradients in `.found_inf`. For details, see the implementation of `torch.cuda.amp.GradScaler.step`.
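Because of the `_step_supports_amp_scaling` attribute above, SINGD can be driven through `torch.cuda.amp.GradScaler.step` like the built-in optimizers. The following is a hedged sketch (the CUDA device, toy model and data, and the chosen initial scale are assumptions for illustration):

```python
# Hypothetical mixed-precision sketch; requires a CUDA device.
import torch
from torch import nn

from singd.optim.optimizer import SINGD

model = nn.Linear(10, 2).cuda()
loss_func = nn.CrossEntropyLoss()
# Tell SINGD roughly which gradient scale to expect in the first iteration.
optimizer = SINGD(model, init_grad_scale=2.0**16)
scaler = torch.cuda.amp.GradScaler(init_scale=2.0**16)

X = torch.rand(8, 10, device="cuda")
y = torch.randint(0, 2, (8,), device="cuda")

for _ in range(5):
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = loss_func(model(X), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # stores .grad_scale/.found_inf, then calls .step()
    scaler.update()
```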
Structured inverse-free natural gradient descent optimizer.
Uses the empirical Fisher. See the paper for the notation.
Parameters:
- `model` (`Module`) – The neural network whose parameters (or a subset thereof) will be trained.
- `params` (`Union[None, Iterable[Parameter], List[Dict[str, Any]]]`, default: `None`) – Used to specify the trainable parameters or parameter groups. If unspecified, all parameters of `model` which are supported by the optimizer will be trained. If a list of `Parameter`s is passed, only these parameters will be trained. If a list of dictionaries is passed, these will be used as parameter groups (a parameter-group sketch follows after this list).
- `lr` (`float`, default: `0.001`) – (\(\beta_2\) in the paper) Learning rate for the parameter updates. Default: `0.001`.
- `momentum` (`float`, default: `0.9`) – (\(\alpha_2\) in the paper) Momentum on the parameter updates. Default: `0.9`.
- `damping` (`float`, default: `0.001`) – (\(\lambda\) in the paper) Damping strength used in the updates of the pre-conditioner momenta \(\mathbf{m}_\mathbf{K}\) and \(\mathbf{m}_\mathbf{C}\). Default: `0.001`.
- `alpha1` (`float`, default: `0.5`) – (\(\alpha_1\) in the paper) Momentum used in the updates of the pre-conditioner momenta \(\mathbf{m}_\mathbf{K}\) and \(\mathbf{m}_\mathbf{C}\). Default: `0.5`.
- `weight_decay` (`float`, default: `0.0`) – (\(\gamma\) in the paper) Weight decay on the parameters. Default: `0.0`.
- `T` (`int`, default: `10`) – Pre-conditioner update frequency. Default: `10`.
- `loss_average` (`Union[None, str]`, default: `'batch'`) – Whether the loss function is a mean over per-sample losses and, if yes, over which dimensions the mean is taken. If `"batch"`, the loss function is a mean over as many terms as the size of the mini-batch. If `"batch+sequence"`, the loss function is a mean over as many terms as the size of the mini-batch times the sequence length, e.g. in the case of language modeling. If `None`, the loss function is a sum. This argument is used to ensure that the pre-conditioner is scaled consistently with the loss and the gradient. Default: `"batch"`.
- `lr_cov` (`Union[float, Callable[[int], float]]`, default: `0.01`) – (\(\beta_1\) in the paper) Learning rate for the updates of the pre-conditioner momenta \(\mathbf{m}_\mathbf{K}\) and \(\mathbf{m}_\mathbf{C}\). Default is `1e-2`. Also allows for a callable which takes the current step and returns the current value for `lr_cov`. Using too large a value during the first few steps might lead to instabilities because the pre-conditioner is still warming up. In that case, try using a schedule which gradually ramps up `lr_cov`, or use a constant value and turn on `normalize_lr_cov`, which will use at most `lr_cov` during training.
- `structures` (`Tuple[str, str]`, default: `('dense', 'dense')`) – A 2-tuple of strings specifying the structure of the pre-conditioner matrices \(\mathbf{K}, \mathbf{C}\) and their momenta \(\mathbf{m}_\mathbf{K}, \mathbf{m}_\mathbf{C}\). Possible values for each entry are `'dense'`, `'diagonal'`, `'block30diagonal'`, `'hierarchical15_15'`, `'triltoeplitz'`, and `'triutoeplitz'`. Default is `('dense', 'dense')`.
- `kfac_approx` (`str`, default: `'expand'`) – A string specifying the KFAC approximation that should be used for linear weight-sharing layers, e.g. `Conv2d` modules or `Linear` modules that process matrix- or higher-dimensional features. Possible values are `'expand'` and `'reduce'`. See Eschenhagen et al., 2023 for an explanation of the two approximations.
- `warn_unsupported` (`bool`, default: `True`) – Only relevant if `params` is unspecified. Whether to warn if `model` contains parameters of layers that are not supported. These parameters will not be trained by the optimizer. Default: `True`.
- `kfac_like` (`bool`, default: `False`) – Whether to use the modified update rule which results in an update close to the KFAC optimizer (IKFAC). Default: `False`. Please see the theorem in the paper for more details.
- `preconditioner_dtype` (`Tuple[Union[dtype, None], Union[dtype, None]]`, default: `(None, None)`) – Data types used to store the structured pre-conditioner matrices \(\mathbf{K}, \mathbf{C}\) and their momenta \(\mathbf{m}_\mathbf{K}, \mathbf{m}_\mathbf{C}\). If `None`, will use the same data type as the parameter for both pre-conditioner matrices and momenta. If `(float32, None)`, will use `float32` for \(\mathbf{K}, \mathbf{m}_\mathbf{K}\) and the same data type as the weight for \(\mathbf{C}, \mathbf{m}_\mathbf{C}\). Default: `(None, None)`.
- `init_grad_scale` (`float`, default: `1.0`) – Only relevant if using a `torch.amp.GradScaler`. Initial gradient scale of the scaler or a number of similar magnitude. If unspecified, the optimizer will still work correctly, but the pre-conditioner computation in the first backpropagation might be numerically unstable. Default: `1.0`.
- `normalize_lr_cov` (`bool`, default: `False`) – Use normalized gradient descent to update the pre-conditioner factors. Enabling this is a good alternative to scheduling `lr_cov`, as we found it to improve SINGD's stability in the early phase where the pre-conditioners are still warming up. Default: `False`. Requires an additional matrix norm computation which will be used to adapt `lr_cov`. (Details: To update the pre-conditioner, SINGD performs Riemannian gradient descent (RGD) on the pre-conditioner factors. Since it uses Riemannian normal coordinates, RGD reduces to GD. This allows applying the idea of normalized gradient descent.)
Raises:
- `TypeError` – If `DataParallel` or `DistributedDataParallel` model wrappers are used.
- `ValueError` – If any of the learning rate and momentum parameters (`lr`, `lr_cov`, `alpha1`, `momentum`, `weight_decay`) are non-positive.
Source code in singd/optim/optimizer.py