Training a Hyperbolic Classifier

This is an adaptation of torchvision’s tutorial “Training a Classifier” to hyperbolic space. The original tutorial can be found here:

Training a Hyperbolic Image Classifier

We will do the following steps in order:

  1. Define a hyperbolic manifold

  2. Load and normalize the CIFAR10 training and test datasets using torchvision

  3. Define a hyperbolic Convolutional Neural Network

  4. Define a loss function and optimizer

  5. Train the network on the training data

  6. Test the network on the test data

1. Define a hyperbolic manifold

We use the Poincaré ball model for the purposes of this tutorial.

from hypll.manifolds.poincare_ball import Curvature, PoincareBall

# Making the curvature a learnable parameter is usually suboptimal but can
# make training smoother.
manifold = PoincareBall(c=Curvature(requires_grad=True))

2. Load and normalize CIFAR10

import torch
import torchvision
import torchvision.transforms as transforms

Note

If running on Windows and you get a BrokenPipeError, try setting the num_worker of torch.utils.data.DataLoader() to 0.

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


batch_size = 4

trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2
)

classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

3. Define a hyperbolic Convolutional Neural Network

Let’s rebuild the convolutional neural network from torchvision’s tutorial using hyperbolic modules.

from torch import nn

from hypll import nn as hnn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = hnn.HConvolution2d(
            in_channels=3, out_channels=6, kernel_size=5, manifold=manifold
        )
        self.pool = hnn.HMaxPool2d(kernel_size=2, manifold=manifold, stride=2)
        self.conv2 = hnn.HConvolution2d(
            in_channels=6, out_channels=16, kernel_size=5, manifold=manifold
        )
        self.fc1 = hnn.HLinear(in_features=16 * 5 * 5, out_features=120, manifold=manifold)
        self.fc2 = hnn.HLinear(in_features=120, out_features=84, manifold=manifold)
        self.fc3 = hnn.HLinear(in_features=84, out_features=10, manifold=manifold)
        self.relu = hnn.HReLU(manifold=manifold)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.flatten(start_dim=1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

4. Define a Loss function and optimizer

Let’s use a Classification Cross-Entropy loss and RiemannianAdam optimizer. Adam is preferred because hyperbolic linear layers can sometimes have training difficulties early on due to poor initialization.

criterion = nn.CrossEntropyLoss()
# net.parameters() includes the learnable curvature "c" of the manifold.
from hypll.optim import RiemannianAdam

optimizer = RiemannianAdam(net.parameters(), lr=0.001)

5. Train the network

This is when things start to get interesting. We simply have to loop over our data iterator, project the inputs onto the manifold, and feed them to the network and optimize.

from hypll.tensors import TangentTensor

for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # move the inputs to the manifold
        tangents = TangentTensor(data=inputs, man_dim=1, manifold=manifold)
        manifold_inputs = manifold.expmap(tangents)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(manifold_inputs)
        loss = criterion(outputs.tensor, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
            running_loss = 0.0

print("Finished Training")

Let’s quickly save our trained model:

PATH = "./cifar_net.pth"
torch.save(net.state_dict(), PATH)

Next, let’s load back in our saved model (note: saving and re-loading the model wasn’t necessary here, we only did it to illustrate how to do so):

net = Net()
net.load_state_dict(torch.load(PATH))

6. Test the network on the test data

Let us look at how the network performs on the whole dataset.

correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        images, labels = data

        # move the images to the manifold
        tangents = TangentTensor(data=images, man_dim=1, manifold=manifold)
        manifold_images = manifold.expmap(tangents)

        # calculate outputs by running images through the network
        outputs = net(manifold_images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.tensor, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")

That looks way better than chance, which is 10% accuracy (randomly picking a class out of 10 classes). Seems like the network learnt something.

Hmmm, what are the classes that performed well, and the classes that did not perform well:

# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in testloader:
        images, labels = data

        # move the images to the manifold
        tangents = TangentTensor(data=images, man_dim=1, manifold=manifold)
        manifold_images = manifold.expmap(tangents)

        outputs = net(manifold_images)
        _, predictions = torch.max(outputs.tensor, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1

# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %")

Training on GPU

Just like how you transfer a Tensor onto the GPU, you transfer the neural net onto the GPU.

Let’s first define our device as the first visible cuda device if we have CUDA available:

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

Assuming that we are on a CUDA machine, this should print a CUDA device:

print(device)

The rest of this section assumes that device is a CUDA device.

Then these methods will recursively go over all modules and convert their parameters and buffers to CUDA tensors:

net.to(device)

Remember that you will have to send the inputs and targets at every step to the GPU too:

inputs, labels = data[0].to(device), data[1].to(device)

Goals achieved:

  • Train a small hyperbolic neural network to classify images.

Gallery generated by Sphinx-Gallery