Note
Go to the end to download the full example code
Training a Poincare ResNet¶
This is an implementation based on the Poincare ResNet paper, which can be found at https://arxiv.org/abs/2303.14027.
Due to the complexity of hyperbolic operations, we strongly advise running this tutorial only on a GPU.
We will perform the following steps in order:
Define a hyperbolic manifold
Load and normalize the CIFAR10 training and test datasets using torchvision
Define a Poincare ResNet
Define a loss function and optimizer
Train the network on the training data
Test the network on the test data
0. Grab the available device¶
import torch

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
1. Define the Poincare ball¶
from hypll.manifolds.poincare_ball import Curvature, PoincareBall
# A learnable curvature is usually suboptimal, but it can make training smoother.
# Starting from a curvature of 0.1 has also been shown to help during training.
curvature = Curvature(value=0.1, requires_grad=True)
manifold = PoincareBall(c=curvature)
2. Load and normalize CIFAR10¶
import torchvision
import torchvision.transforms as transforms
Note
If running on Windows and you get a BrokenPipeError, try setting the num_workers argument of torch.utils.data.DataLoader() to 0.
# Convert every image to a tensor, then normalize each of the three channels
# with a mean of 0.5 and a standard deviation of 0.5.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
)

batch_size = 128

# Settings shared by both splits; only `train` and `shuffle` differ below.
dataset_options = {"root": "./data", "download": True, "transform": transform}
loader_options = {"batch_size": batch_size, "num_workers": 2}

# Training split: shuffled batches.
trainset = torchvision.datasets.CIFAR10(train=True, **dataset_options)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_options)

# Test split: batches in a fixed order.
testset = torchvision.datasets.CIFAR10(train=False, **dataset_options)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_options)
3. Define a Poincare ResNet¶
This implementation is based on the Poincare ResNet paper, which can be found at https://arxiv.org/abs/2303.14027 and which, in turn, is based on the original Euclidean implementation described in the paper Deep Residual Learning for Image Recognition by He et al. from 2015: https://arxiv.org/abs/1512.03385.
from typing import Optional
from torch import nn
from hypll import nn as hnn
from hypll.tensors import ManifoldTensor
class PoincareResidualBlock(nn.Module):
    """Residual block whose layers operate on points of a Poincare ball.

    The main branch is conv -> batch norm -> ReLU -> conv -> batch norm. Its output
    is combined with the (optionally downsampled) block input through a Mobius
    addition, followed by a final ReLU.

    Args:
        in_channels: Number of input channels of the first convolution.
        out_channels: Number of output channels of both convolutions.
        manifold: Poincare ball supplied to every hyperbolic layer in the block.
        stride: Stride of the first convolution. The second convolution is built
            without an explicit stride.
        downsample: Optional module applied to the block input before the skip
            connection, e.g. the strided 1x1 convolution created by
            ``PoincareResNet._make_group``. ``None`` leaves the input untouched.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        manifold: PoincareBall,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ):
        # We can replace each operation in the usual ResidualBlock by a manifold-agnostic
        # operation and supply the PoincareBall object to these operations.
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.manifold = manifold
        self.stride = stride
        self.downsample = downsample

        self.conv1 = hnn.HConvolution2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            manifold=manifold,
            stride=stride,
            padding=1,
        )
        self.bn1 = hnn.HBatchNorm2d(features=out_channels, manifold=manifold)
        # The same ReLU module is reused after bn1 and after the skip connection.
        self.relu = hnn.HReLU(manifold=self.manifold)
        self.conv2 = hnn.HConvolution2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            manifold=manifold,
            padding=1,
        )
        self.bn2 = hnn.HBatchNorm2d(features=out_channels, manifold=manifold)

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        """Run the block on ``x`` and return the activated sum of both branches."""
        residual = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)

        if self.downsample is not None:
            residual = self.downsample(residual)

        # We replace the addition operation inside the skip connection by a Mobius addition.
        x = self.manifold.mobius_add(x, residual)
        x = self.relu(x)

        return x
class PoincareResNet(nn.Module):
    """Thin ResNet assembled from hyperbolic layers on a Poincare ball.

    The network consists of a stem (conv -> batch norm -> ReLU), three groups of
    ``PoincareResidualBlock`` modules, an average pooling layer and a linear
    classifier. The second and third group are built with a stride of 2.

    Args:
        channel_sizes: Output channels of the three groups; entries 0, 1 and 2
            are read.
        group_depths: Number of residual blocks in each of the three groups;
            entries 0, 1 and 2 are read.
        manifold: Poincare ball supplied to every hyperbolic layer.
        num_classes: Number of output features of the final linear layer.
            Defaults to 10, the value that was previously hard-coded.
    """

    def __init__(
        self,
        channel_sizes: list[int],
        group_depths: list[int],
        manifold: PoincareBall,
        num_classes: int = 10,
    ):
        # For the Poincare ResNet itself we again replace each layer by a manifold-agnostic one
        # and supply the PoincareBall to each of these. We also replace the ResidualBlocks by
        # the manifold-agnostic one defined above.
        super().__init__()

        self.channel_sizes = channel_sizes
        self.group_depths = group_depths
        self.manifold = manifold
        self.num_classes = num_classes

        self.conv = hnn.HConvolution2d(
            in_channels=3,
            out_channels=channel_sizes[0],
            kernel_size=3,
            manifold=manifold,
            padding=1,
        )
        self.bn = hnn.HBatchNorm2d(features=channel_sizes[0], manifold=manifold)
        self.relu = hnn.HReLU(manifold=manifold)

        self.group1 = self._make_group(
            in_channels=channel_sizes[0],
            out_channels=channel_sizes[0],
            depth=group_depths[0],
        )
        self.group2 = self._make_group(
            in_channels=channel_sizes[0],
            out_channels=channel_sizes[1],
            depth=group_depths[1],
            stride=2,
        )
        self.group3 = self._make_group(
            in_channels=channel_sizes[1],
            out_channels=channel_sizes[2],
            depth=group_depths[2],
            stride=2,
        )

        # NOTE(review): kernel_size=8 presumably matches the 8x8 feature map that
        # 32x32 inputs have after the two stride-2 groups -- confirm before feeding
        # images of another size.
        self.avg_pool = hnn.HAvgPool2d(kernel_size=8, manifold=manifold)
        self.fc = hnn.HLinear(
            in_features=channel_sizes[2],
            out_features=num_classes,
            manifold=manifold,
        )

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        """Run the network on ``x`` and return the output of the final linear layer."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.group1(x)
        x = self.group2(x)
        x = self.group3(x)
        x = self.avg_pool(x)
        # NOTE(review): squeeze() is called without a dimension, so a batch holding a
        # single sample would presumably lose its batch dimension as well -- confirm
        # before running with a batch size of 1.
        x = self.fc(x.squeeze())
        return x

    def _make_group(
        self,
        in_channels: int,
        out_channels: int,
        depth: int,
        stride: int = 1,
    ) -> nn.Sequential:
        """Build one group of ``depth`` residual blocks.

        Only the first block receives ``stride`` and maps ``in_channels`` to
        ``out_channels``; the remaining blocks keep ``out_channels`` and use the
        default stride. When ``stride`` is not 1, the first block also gets a 1x1
        convolution with the same stride as the ``downsample`` module for its skip
        connection.
        """
        if stride == 1:
            downsample = None
        else:
            downsample = hnn.HConvolution2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                manifold=self.manifold,
                stride=stride,
            )

        layers = [
            PoincareResidualBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                manifold=self.manifold,
                stride=stride,
                downsample=downsample,
            )
        ]

        for _ in range(1, depth):
            layers.append(
                PoincareResidualBlock(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    manifold=self.manifold,
                )
            )

        return nn.Sequential(*layers)
# Build a thin Poincare ResNet: channel sizes [4, 8, 16] and three residual blocks
# per group, which gives a depth of 20 layers.
net = PoincareResNet(
    channel_sizes=[4, 8, 16],
    group_depths=[3, 3, 3],
    manifold=manifold,
)
# Move the freshly built network to the selected device.
net = net.to(device)
4. Define a Loss function and optimizer¶
Let’s use a classification cross-entropy loss and the RiemannianAdam optimizer.
from hypll.optim import RiemannianAdam

# Cross-entropy loss for the classification task.
criterion = nn.CrossEntropyLoss()

# net.parameters() includes the learnable curvature "c" of the manifold, so the
# optimizer updates it together with the network weights.
learning_rate = 0.001
optimizer = RiemannianAdam(net.parameters(), lr=learning_rate)
5. Train the network¶
We simply have to loop over our data iterator, project the inputs onto the manifold, feed them to the network, and optimize. We will train for only a limited number of epochs here due to the long training time of this model.
from hypll.tensors import TangentTensor
for epoch in range(2):  # Increase this number to at least 100 for good results
    # Each item yielded by the loader is an [inputs, labels] pair.
    for step, (inputs, labels) in enumerate(trainloader, start=1):
        inputs, labels = inputs.to(device), labels.to(device)

        # move the inputs to the manifold: wrap them as tangent vectors (with
        # dimension 1 as the manifold dimension) and apply the exponential map
        tangents = TangentTensor(data=inputs, man_dim=1, manifold=manifold)
        manifold_inputs = manifold.expmap(tangents)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(manifold_inputs)
        loss = criterion(outputs.tensor, labels)
        loss.backward()
        optimizer.step()

        # print statistics: the loss of the current mini-batch
        print(f"[{epoch + 1}, {step:5d}] loss: {loss.item():.3f}")

print("Finished Training")
6. Test the network on the test data¶
Let us look at how the network performs on the whole test dataset.
correct = 0
total = 0

# Switch the network to evaluation mode so that any layer whose behavior depends
# on the training flag acts as intended at inference time.
net.eval()

# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)

        # move the images to the manifold
        tangents = TangentTensor(data=images, man_dim=1, manifold=manifold)
        manifold_images = manifold.expmap(tangents)

        # calculate outputs by running images through the network
        outputs = net(manifold_images)

        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.tensor, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")