Basic usage
This example demonstrates the simplest usage of SINGD. The algorithm works
pretty much like any other torch.optim.Optimizer, but there are some
additional aspects that are good to know.
First, the imports.
from torch import cuda, device, manual_seed
from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, ReLU, Sequential
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from singd.optim.optimizer import SINGD
manual_seed(0) # make deterministic
MAX_STEPS = 200 # quit training after this many steps (or one epoch)
DEV = device("cuda" if cuda.is_available() else "cpu")
Problem Setup
We will train a simple neural network on MNIST using cross-entropy loss:
BATCH_SIZE = 32
train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
model = Sequential(
Conv2d(1, 3, kernel_size=5, stride=2),
ReLU(),
Flatten(),
Linear(432, 50),
ReLU(),
Linear(50, 10),
).to(DEV)
loss_func = CrossEntropyLoss().to(DEV)
Out:
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
0%| | 0/9912422 [00:00<?, ?it/s]
79%|#######9 | 7864320/9912422 [00:00<00:00, 73575456.89it/s]
100%|##########| 9912422/9912422 [00:00<00:00, 83250823.97it/s]
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
0%| | 0/28881 [00:00<?, ?it/s]
100%|##########| 28881/28881 [00:00<00:00, 133851595.39it/s]
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
0%| | 0/1648877 [00:00<?, ?it/s]
100%|##########| 1648877/1648877 [00:00<00:00, 32185050.18it/s]
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
0%| | 0/4542 [00:00<?, ?it/s]
100%|##########| 4542/4542 [00:00<00:00, 28518755.64it/s]
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
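Side note: the value 432 in the first Linear layer is the number of features
produced by the Flatten layer. Here is a quick sanity check, a sketch that
assumes the standard 1x28x28 MNIST input shape:
from torch import rand

# Conv2d(1, 3, kernel_size=5, stride=2) maps a 1x28x28 image to 3x12x12,
# so Flatten yields 3 * 12 * 12 = 432 features per example
conv_part = Sequential(*list(model.children())[:3])  # Conv2d, ReLU, Flatten
print(conv_part(rand(2, 1, 28, 28).to(DEV)).shape)  # torch.Size([2, 432])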
Optimizer Setup
One difference from many built-in PyTorch optimizers is that SINGD requires
access to the neural network (a torch.nn.Module):
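A minimal setup, passing the network directly to the constructor (the
training loop below refers to the optimizer as singd):
singd = SINGD(model)  # takes the network itself, not just its parameters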
This is because SINGD needs to install hooks onto some of the neural
network's layers to carry out the additional computations for its
pre-conditioner.
Of course, you can also tweak SINGD's other arguments, such as learning
rates and momenta. See here for a complete overview.
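For example, a hedged sketch (the keyword names lr and momentum are
assumptions based on the mention of learning rates and momenta above; the
values are illustrative, not recommendations):
# illustrative values only; the lr/momentum keyword names are assumed
singd_tuned = SINGD(model, lr=1e-3, momentum=0.9)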
Training
When it comes to training, SINGD can be used in exactly the same way as
other optimizers (see here for an introduction). Let's train for a couple
of steps and print the loss.
PRINT_LOSS_EVERY = 25  # logging interval

for step, (inputs, target) in enumerate(train_loader):
    singd.zero_grad()  # clear gradients from previous iterations

    # regular forward-backward pass
    loss = loss_func(model(inputs.to(DEV)), target.to(DEV))
    loss.backward()

    if step % PRINT_LOSS_EVERY == 0:
        print(f"Step {step}, Loss {loss.item():.3f}")

    singd.step()  # update neural network parameters

    if step >= MAX_STEPS:  # don't train a full epoch to keep the example light-weight
        break
Out:
Step 0, Loss 2.317
Step 25, Loss 2.322
Step 50, Loss 2.312
Step 75, Loss 2.298
Step 100, Loss 2.279
Step 125, Loss 2.237
Step 150, Loss 2.268
Step 175, Loss 2.186
Step 200, Loss 2.192
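As with any other optimizer, you may want to checkpoint training state. A
hedged sketch, assuming SINGD implements the standard torch.optim.Optimizer
state_dict()/load_state_dict() interface (the file name checkpoint.pt is
arbitrary):
from torch import load, save

# assumes SINGD supports the standard Optimizer checkpointing interface
save({"model": model.state_dict(), "singd": singd.state_dict()}, "checkpoint.pt")

checkpoint = load("checkpoint.pt")
model.load_state_dict(checkpoint["model"])
singd.load_state_dict(checkpoint["singd"])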
Conclusion
You now know the most basic way to train a neural network with SINGD. From
here, you might be interested in
- checking out the more advanced examples
- taking a closer look at SINGD's hyper-parameters.