SINGD: Structured Inverse-Free Natural Gradient Descent
This package contains the official PyTorch implementation of our memory-efficient and numerically stable KFAC variant, termed SINGD (paper).
The main feature is a torch.optim.Optimizer that works like most PyTorch optimizers and is compatible with:
- Per-parameter options (param_groups)
- Using a learning rate scheduler
- Gradient accumulation (multiple forward-backwards, then take a step)
- Distributed data-parallel (DDP) training¹
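Gradient accumulation works because gradients from several backward passes add up before the optimizer takes a step. Below is a minimal, framework-free sketch of that idea. It is plain Python with a hypothetical one-parameter least-squares model and uses no SINGD or PyTorch API. Each micro-batch gradient is scaled by 1/num_micro_batches, and summing these scaled gradients reproduces the full-batch mean gradient.

```python
def grad_mse(w, xs, ys):
    """Gradient of 0.5 * mean((w*x - y)^2) with respect to the scalar weight w."""
    return sum((w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, num_micro_batches):
    """Sum micro-batch gradients, each scaled by 1 / num_micro_batches.

    Assumes equal-size micro-batches (len(xs) divisible by num_micro_batches);
    unequal sizes would need per-batch weights instead of a uniform scale.
    """
    if len(xs) % num_micro_batches:
        raise ValueError("len(xs) must be divisible by num_micro_batches")
    size = len(xs) // num_micro_batches
    total = 0.0  # plays the role of the accumulated gradient buffer
    for i in range(num_micro_batches):
        batch_x = xs[i * size:(i + 1) * size]
        batch_y = ys[i * size:(i + 1) * size]
        total += grad_mse(w, batch_x, batch_y) / num_micro_batches
    return total


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
w = 0.5

full = grad_mse(w, xs, ys)
accum = accumulated_grad(w, xs, ys, num_micro_batches=4)
print(full, accum)  # equal up to floating-point rounding

# A single gradient-descent step taken only after all micro-batches were processed.
lr = 0.01
w_new = w - lr * accum
```

In a PyTorch training loop, the same pattern amounts to calling backward() on several scaled micro-batch losses and then calling step() once.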
The preconditioner matrices support different structures that reduce memory and computational cost (overview).
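The sketch below gives a sense of scale by counting stored entries for the preconditioner factors of a single linear layer with weight shape (d_out, d_in). It assumes one d_in × d_in factor and one d_out × d_out factor, as in KFAC. The structure names and the block size are illustrative assumptions, not SINGD's configuration options; the overview lists the structures the package actually provides. Symmetry and other storage tricks are ignored.

```python
def factor_entries(dim, structure, block_size=None):
    """Number of stored entries of a dim x dim matrix under a given structure."""
    if structure == "dense":
        return dim * dim
    if structure == "diagonal":
        return dim
    if structure == "block_diagonal":
        # Assumes dim // block_size square blocks of size block_size x block_size.
        if block_size is None or dim % block_size:
            raise ValueError("block_diagonal needs a block_size that divides dim")
        return (dim // block_size) * block_size * block_size
    raise ValueError(f"unknown structure: {structure}")


def layer_preconditioner_entries(d_in, d_out, structure, block_size=None):
    """Entries of both factors (d_in x d_in and d_out x d_out) of a linear layer."""
    return (factor_entries(d_in, structure, block_size)
            + factor_entries(d_out, structure, block_size))


d_in, d_out = 512, 256
print("weights:", d_in * d_out)
for structure in ("dense", "block_diagonal", "diagonal"):
    # block_size is only used by the block-diagonal structure.
    n = layer_preconditioner_entries(d_in, d_out, structure, block_size=32)
    print(f"{structure}: {n}")
```

Dense factors grow quadratically with layer width. Diagonal factors and block-diagonal factors with a fixed block size grow only linearly. For this layer, the counts are 327680 entries for dense, 24576 for block-diagonal, and 768 for diagonal.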
Installation
- Stable (recommended):
- Latest version from GitHub main branch:
Usage
Limitations
- SINGD does not support graph neural networks (GNN).
- SINGD currently does not support gradient clipping.
- The code has stabilized only recently. Expect things to break, and help us improve by filing issues.
Citation
If you find this code useful for your research, consider citing the paper:
@inproceedings{lin2024structured,
title = {Structured Inverse-Free Natural Gradient Descent:
Memory-Efficient \& Numerically-Stable {KFAC}},
author = {Wu Lin and Felix Dangel and Runa Eschenhagen and Kirill
Neklyudov and Agustinus Kristiadi and Richard E. Turner and
Alireza Makhzani},
booktitle = {International Conference on Machine Learning (ICML)},
year = 2024,
}
¹ We do support standard DDP with one crucial difference: the model should not be wrapped with the DDP wrapper. Everything else, e.g. launching with the torchrun command, stays the same.