.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/cifar10_resnet_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_cifar10_resnet_tutorial.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_cifar10_resnet_tutorial.py:


Training a Poincare ResNet
==========================

This is an implementation based on the Poincare ResNet paper, which can be found at:

- https://arxiv.org/abs/2303.14027

Because hyperbolic operations are computationally expensive, we strongly advise
running this tutorial on a GPU.

We will perform the following steps in order:

1. Define a hyperbolic manifold
2. Load and normalize the CIFAR10 training and test datasets using ``torchvision``
3. Define a Poincare ResNet
4. Define a loss function and optimizer
5. Train the network on the training data
6. Test the network on the test data

.. GENERATED FROM PYTHON SOURCE LINES 25-27

0. Grab the available device
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 27-32

.. code-block:: Python


    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"


.. GENERATED FROM PYTHON SOURCE LINES 33-35

1. Define the Poincare ball
^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 35-44

.. code-block:: Python


    from hypll.manifolds.poincare_ball import Curvature, PoincareBall

    # Making the curvature a learnable parameter is usually suboptimal but can
    # make training smoother. An initial curvature of 0.1 has also been shown
    # to help during training.
    manifold = PoincareBall(c=Curvature(value=0.1, requires_grad=True))


.. GENERATED FROM PYTHON SOURCE LINES 45-47

2. Load and normalize CIFAR10
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 47-51

.. code-block:: Python


    import torchvision
    import torchvision.transforms as transforms


.. GENERATED FROM PYTHON SOURCE LINES 52-55

.. note::
    If you are running on Windows and get a ``BrokenPipeError``, try setting the
    ``num_workers`` argument of ``torch.utils.data.DataLoader()`` to 0.

.. GENERATED FROM PYTHON SOURCE LINES 55-80

.. code-block:: Python


    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    batch_size = 128

    trainset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=2
    )

    testset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=2
    )


.. GENERATED FROM PYTHON SOURCE LINES 81-88

3. Define a Poincare ResNet
^^^^^^^^^^^^^^^^^^^^^^^^^^^
This implementation is based on the Poincare ResNet paper
(https://arxiv.org/abs/2303.14027), which in turn builds on the original
Euclidean architecture described in the 2015 paper Deep Residual Learning for
Image Recognition by He et al.: https://arxiv.org/abs/1512.03385.

.. GENERATED FROM PYTHON SOURCE LINES 88-258

.. code-block:: Python


    from typing import Optional

    from torch import nn

    from hypll import nn as hnn
    from hypll.tensors import ManifoldTensor


    class PoincareResidualBlock(nn.Module):
        def __init__(
            self,
            in_channels: int,
            out_channels: int,
            manifold: PoincareBall,
            stride: int = 1,
            downsample: Optional[nn.Sequential] = None,
        ):
            # We can replace each operation in the usual ResidualBlock by a manifold-agnostic
            # operation and supply the PoincareBall object to these operations.
            super().__init__()
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.manifold = manifold
            self.stride = stride
            self.downsample = downsample

            self.conv1 = hnn.HConvolution2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                manifold=manifold,
                stride=stride,
                padding=1,
            )
            self.bn1 = hnn.HBatchNorm2d(features=out_channels, manifold=manifold)
            self.relu = hnn.HReLU(manifold=self.manifold)
            self.conv2 = hnn.HConvolution2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                manifold=manifold,
                padding=1,
            )
            self.bn2 = hnn.HBatchNorm2d(features=out_channels, manifold=manifold)

        def forward(self, x: ManifoldTensor) -> ManifoldTensor:
            residual = x
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.conv2(x)
            x = self.bn2(x)

            if self.downsample is not None:
                residual = self.downsample(residual)

            # We replace the addition operation inside the skip connection by a Mobius addition.
            x = self.manifold.mobius_add(x, residual)
            x = self.relu(x)

            return x


    class PoincareResNet(nn.Module):
        def __init__(
            self,
            channel_sizes: list[int],
            group_depths: list[int],
            manifold: PoincareBall,
        ):
            # For the Poincare ResNet itself we again replace each layer by a manifold-agnostic one
            # and supply the PoincareBall to each of these. We also replace the ResidualBlocks by
            # the manifold-agnostic one defined above.
            super().__init__()
            self.channel_sizes = channel_sizes
            self.group_depths = group_depths
            self.manifold = manifold

            self.conv = hnn.HConvolution2d(
                in_channels=3,
                out_channels=channel_sizes[0],
                kernel_size=3,
                manifold=manifold,
                padding=1,
            )
            self.bn = hnn.HBatchNorm2d(features=channel_sizes[0], manifold=manifold)
            self.relu = hnn.HReLU(manifold=manifold)
            self.group1 = self._make_group(
                in_channels=channel_sizes[0],
                out_channels=channel_sizes[0],
                depth=group_depths[0],
            )
            self.group2 = self._make_group(
                in_channels=channel_sizes[0],
                out_channels=channel_sizes[1],
                depth=group_depths[1],
                stride=2,
            )
            self.group3 = self._make_group(
                in_channels=channel_sizes[1],
                out_channels=channel_sizes[2],
                depth=group_depths[2],
                stride=2,
            )

            self.avg_pool = hnn.HAvgPool2d(kernel_size=8, manifold=manifold)
            self.fc = hnn.HLinear(in_features=channel_sizes[2], out_features=10, manifold=manifold)

        def forward(self, x: ManifoldTensor) -> ManifoldTensor:
            x = self.conv(x)
            x = self.bn(x)
            x = self.relu(x)
            x = self.group1(x)
            x = self.group2(x)
            x = self.group3(x)
            x = self.avg_pool(x)
            x = self.fc(x.squeeze())
            return x

        def _make_group(
            self,
            in_channels: int,
            out_channels: int,
            depth: int,
            stride: int = 1,
        ) -> nn.Sequential:
            # Only groups that change resolution (stride != 1) need a 1x1 convolution
            # on the skip connection so that the residual matches the new shape.
            if stride == 1:
                downsample = None
            else:
                downsample = hnn.HConvolution2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=1,
                    manifold=self.manifold,
                    stride=stride,
                )

            layers = [
                PoincareResidualBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    manifold=self.manifold,
                    stride=stride,
                    downsample=downsample,
                )
            ]

            for _ in range(1, depth):
                layers.append(
                    PoincareResidualBlock(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        manifold=self.manifold,
                    )
                )

            return nn.Sequential(*layers)


    # Now, let's create a thin Poincare ResNet with channel sizes [4, 8, 16] and with a depth of 20
    # layers.
    net = PoincareResNet(
        channel_sizes=[4, 8, 16],
        group_depths=[3, 3, 3],
        manifold=manifold,
    ).to(device)


.. GENERATED FROM PYTHON SOURCE LINES 259-262

4. Define a loss function and optimizer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's use a cross-entropy classification loss and the RiemannianAdam optimizer.

.. GENERATED FROM PYTHON SOURCE LINES 262-270

.. code-block:: Python


    criterion = nn.CrossEntropyLoss()
    # net.parameters() includes the learnable curvature "c" of the manifold.
    from hypll.optim import RiemannianAdam

    optimizer = RiemannianAdam(net.parameters(), lr=0.001)


.. GENERATED FROM PYTHON SOURCE LINES 271-276

5. Train the network
^^^^^^^^^^^^^^^^^^^^
We loop over the data iterator, map the inputs onto the manifold, feed them to
the network, and optimize. We train for only a few epochs here because this
model takes a long time to train.

.. GENERATED FROM PYTHON SOURCE LINES 276-306

.. code-block:: Python


    from hypll.tensors import TangentTensor

    for epoch in range(2):  # Increase this number to at least 100 for good results
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(device), data[1].to(device)

            # move the inputs to the manifold
            tangents = TangentTensor(data=inputs, man_dim=1, manifold=manifold)
            manifold_inputs = manifold.expmap(tangents)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(manifold_inputs)
            loss = criterion(outputs.tensor, labels)
            loss.backward()
            optimizer.step()

            # print the loss of the current batch
            print(f"[{epoch + 1}, {i + 1:5d}] loss: {loss.item():.3f}")

    print("Finished Training")


.. GENERATED FROM PYTHON SOURCE LINES 307-311

6. Test the network on the test data
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Let us look at how the network performs on the whole test dataset.

.. GENERATED FROM PYTHON SOURCE LINES 311-331

.. code-block:: Python


    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)

            # move the images to the manifold
            tangents = TangentTensor(data=images, man_dim=1, manifold=manifold)
            manifold_images = manifold.expmap(tangents)

            # calculate outputs by running images through the network
            outputs = net(manifold_images)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.tensor, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")


.. _sphx_glr_download_tutorials_cifar10_resnet_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: cifar10_resnet_tutorial.ipynb <cifar10_resnet_tutorial.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: cifar10_resnet_tutorial.py <cifar10_resnet_tutorial.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
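
The tutorial relies on two hyperbolic operations: ``manifold.expmap``, which
moves the input images onto the Poincare ball, and ``manifold.mobius_add``,
which replaces the ordinary addition in the skip connection. The sketch below
writes out the standard textbook formulas for both in plain Python, so that the
arithmetic is visible. It is an illustration only and makes several
assumptions: the helper names ``expmap0`` and ``mobius_add`` are hypothetical,
the exponential map is taken at the origin, ``c`` is treated as a fixed positive
number (curvature ``-c``), and the code is not taken from the library, which
operates on tensors and may differ in numerical details.

.. code-block:: Python

    import math


    def expmap0(v, c):
        """Exponential map at the origin of a Poincare ball with curvature -c (assumed form)."""
        norm = math.sqrt(sum(t * t for t in v))
        if norm == 0.0:
            return list(v)
        # tanh squashes the tangent vector so the result stays inside the ball
        scale = math.tanh(math.sqrt(c) * norm) / (math.sqrt(c) * norm)
        return [scale * t for t in v]


    def mobius_add(x, y, c):
        """Mobius addition of two points on the ball (standard closed form)."""
        xy = sum(a * b for a, b in zip(x, y))
        x2 = sum(a * a for a in x)
        y2 = sum(b * b for b in y)
        denom = 1 + 2 * c * xy + c * c * x2 * y2
        return [
            ((1 + 2 * c * xy + c * y2) * a + (1 - c * x2) * b) / denom
            for a, b in zip(x, y)
        ]


    c = 0.1  # same initial value as in the tutorial, kept fixed here
    radius = 1 / math.sqrt(c)  # the ball contains all points with norm < radius

    x = expmap0([3.0, -4.0], c)
    y = expmap0([0.5, 0.25], c)
    z = mobius_add(x, y, c)

    print("norm of x:", math.hypot(*x), "ball radius:", radius)
    print("x (+) y:", z)

Two properties make Mobius addition a natural stand-in for the Euclidean
residual sum: adding the origin leaves a point unchanged, and adding a point to
its negation returns the origin. Both can be checked with the helpers above,
for example ``mobius_add(x, [0.0, 0.0], c)`` and
``mobius_add([-t for t in x], x, c)``.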