
Basic usage.

This example demonstrates the simplest usage of SINGD. The algorithm works much like any other torch.optim.Optimizer, but there are a few additional aspects worth knowing.

First, the imports.

```python
from torch import cuda, device, manual_seed
from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, ReLU, Sequential
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from singd.optim.optimizer import SINGD

manual_seed(0)  # make deterministic
MAX_STEPS = 200  # quit training after this many steps (or one epoch)
DEV = device("cuda" if cuda.is_available() else "cpu")
```

Problem Setup

We will train a simple neural network on MNIST using cross-entropy loss:

```python
BATCH_SIZE = 32
train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

model = Sequential(
    Conv2d(1, 3, kernel_size=5, stride=2),
    ReLU(),
    Flatten(),
    Linear(432, 50),
    ReLU(),
    Linear(50, 10),
).to(DEV)
loss_func = CrossEntropyLoss().to(DEV)
```

Out:

```
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
```
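The input size 432 of the first Linear layer is what the convolution produces for a single 28x28 MNIST image. The pure-Python sketch below checks this with the output-size formula from the torch.nn.Conv2d documentation, assuming no padding and dilation 1 (the Conv2d defaults). The helper `conv2d_output_size` is hypothetical and not part of PyTorch.

```python
def conv2d_output_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Spatial output size of a convolution along one dimension (torch.nn.Conv2d formula)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


side = conv2d_output_size(28, kernel_size=5, stride=2)  # MNIST images are 28x28
flattened = 3 * side * side  # 3 output channels on a square feature map, then Flatten()
print(side, flattened)
```

Changing the kernel size, stride, or number of channels in the model therefore requires updating the `in_features` of the first Linear layer accordingly.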

Optimizer Setup

One difference from many built-in PyTorch optimizers is that SINGD requires access to the neural network itself (a torch.nn.Module), not just its parameters:

```python
singd = SINGD(model)
```

This is because SINGD installs hooks on some of the neural network's layers to carry out the additional computations for its preconditioner.

Of course, you can also tweak SINGD's other arguments, such as learning rates and momenta. See the SINGD API documentation for a complete overview.
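SINGD's own hook logic lives inside the library and is not reproduced here. The sketch below only illustrates, in plain PyTorch, the general hook mechanism that makes access to the model necessary: a forward hook records the input of each Conv2d and Linear layer, and a tensor hook on the layer's output records the gradient flowing back into it. These per-layer quantities are the kind that Kronecker-factored preconditioners are typically built from. The `captured` dictionary and the `make_forward_hook` helper are hypothetical names chosen for this illustration.

```python
from torch import manual_seed, randint, randn
from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, ReLU, Sequential

manual_seed(0)
model = Sequential(
    Conv2d(1, 3, kernel_size=5, stride=2),
    ReLU(),
    Flatten(),
    Linear(432, 50),
    ReLU(),
    Linear(50, 10),
)

captured = {}  # hypothetical storage: layer name -> {"input": ..., "grad_output": ...}


def make_forward_hook(name):
    """Return a forward hook that records a layer's input and its output gradient."""

    def forward_hook(module, args, output):
        store = captured.setdefault(name, {})
        store["input"] = args[0].detach()  # layer input, available during the forward pass

        def grad_hook(grad):
            store["grad_output"] = grad.detach()  # dLoss/dOutput, available during backward

        output.register_hook(grad_hook)

    return forward_hook


# Install hooks only on the layers with trainable weights (here Conv2d and Linear)
handles = [
    module.register_forward_hook(make_forward_hook(name))
    for name, module in model.named_modules()
    if isinstance(module, (Conv2d, Linear))
]

# A regular forward-backward pass on random MNIST-shaped data triggers the hooks
inputs, target = randn(8, 1, 28, 28), randint(0, 10, (8,))
CrossEntropyLoss()(model(inputs), target).backward()

for handle in handles:
    handle.remove()  # detach the hooks once they are no longer needed

for name, store in captured.items():
    print(name, tuple(store["input"].shape), tuple(store["grad_output"].shape))
```

Because the hooks must be attached to specific layers, an optimizer that relies on them has to receive the module rather than a flat iterable of parameters.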

Training

When it comes to training, SINGD can be used in exactly the same way as other optimizers (see the torch.optim documentation for an introduction). Let's train for a limited number of steps and print the loss periodically.

```python
PRINT_LOSS_EVERY = 25  # logging interval

for step, (inputs, target) in enumerate(train_loader):
    singd.zero_grad()  # clear gradients from previous iterations

    # regular forward-backward pass
    loss = loss_func(model(inputs.to(DEV)), target.to(DEV))
    loss.backward()
    if step % PRINT_LOSS_EVERY == 0:
        print(f"Step {step}, Loss {loss.item():.3f}")

    singd.step()  # update neural network parameters

    if step >= MAX_STEPS:  # don't train a full epoch to keep the example light-weight
        break
```

Out:

```
Step 0, Loss 2.317
Step 25, Loss 2.322
Step 50, Loss 2.312
Step 75, Loss 2.298
Step 100, Loss 2.279
Step 125, Loss 2.237
Step 150, Loss 2.268
Step 175, Loss 2.186
Step 200, Loss 2.192
```
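The loop breaks only after the update at `step == MAX_STEPS`, so it performs MAX_STEPS + 1 = 201 updates and logs at steps 0, 25, ..., 200, which matches the nine lines of output above. The pure-Python sketch below replays just that control flow without any data or model, assuming the 60,000-image MNIST training split and the DataLoader default `drop_last=False`. With BATCH_SIZE = 32, one epoch has 1,875 batches, so the example stops well before a full epoch.

```python
import math

MNIST_TRAIN_SIZE = 60_000  # number of images in the MNIST training split
BATCH_SIZE = 32
MAX_STEPS = 200
PRINT_LOSS_EVERY = 25

# With drop_last=False, a partial final batch would still count as one batch
batches_per_epoch = math.ceil(MNIST_TRAIN_SIZE / BATCH_SIZE)

logged_steps = []
num_updates = 0
for step in range(batches_per_epoch):  # stands in for enumerate(train_loader)
    if step % PRINT_LOSS_EVERY == 0:
        logged_steps.append(step)  # corresponds to the print(...) call
    num_updates += 1  # corresponds to singd.step()
    if step >= MAX_STEPS:
        break

print(batches_per_epoch, num_updates, logged_steps)
```

To perform exactly MAX_STEPS updates instead, the break condition could be changed to `step + 1 >= MAX_STEPS`.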

Conclusion

You now know the most basic way to train a neural network with SINGD. From here, you might be interested in

Total running time of the script: (0 minutes 5.606 seconds)

Download Python source code: example_01_basic.py

Download Jupyter notebook: example_01_basic.ipynb

Gallery generated by mkdocs-gallery