.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/cifar10_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_cifar10_tutorial.py:


Training a Hyperbolic Classifier
================================

This is an adaptation of PyTorch's tutorial "Training a Classifier" to
hyperbolic space. The original tutorial can be found here:

- https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

Training a Hyperbolic Image Classifier
--------------------------------------

We will do the following steps in order:

1. Define a hyperbolic manifold
2. Load and normalize the CIFAR10 training and test datasets using
   ``torchvision``
3. Define a hyperbolic Convolutional Neural Network
4. Define a loss function and optimizer
5. Train the network on the training data
6. Test the network on the test data

.. GENERATED FROM PYTHON SOURCE LINES 26-29

1. Define a hyperbolic manifold
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
We use the Poincaré ball model for the purposes of this tutorial.

.. GENERATED FROM PYTHON SOURCE LINES 29-38

.. code-block:: Python

    from hypll.manifolds.poincare_ball import Curvature, PoincareBall

    # Making the curvature a learnable parameter is usually suboptimal but can
    # make training smoother.
    manifold = PoincareBall(c=Curvature(requires_grad=True))

.. GENERATED FROM PYTHON SOURCE LINES 39-41

2. Load and normalize CIFAR10
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 41-46

.. code-block:: Python

    import torch
    import torchvision
    import torchvision.transforms as transforms

.. GENERATED FROM PYTHON SOURCE LINES 47-50

.. note::
    If running on Windows and you get a ``BrokenPipeError``, try setting
    the ``num_workers`` argument of ``torch.utils.data.DataLoader()`` to 0.

.. GENERATED FROM PYTHON SOURCE LINES 50-78

.. code-block:: Python

    # ToTensor scales pixels to [0, 1]; Normalize then maps each channel
    # to [-1, 1] via (x - 0.5) / 0.5.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    batch_size = 4

    trainset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=2
    )

    testset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=2
    )

    classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

.. GENERATED FROM PYTHON SOURCE LINES 79-83

3. Define a hyperbolic Convolutional Neural Network
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's rebuild the convolutional neural network from the original tutorial
using hyperbolic modules.

.. GENERATED FROM PYTHON SOURCE LINES 83-116

.. code-block:: Python

    from torch import nn

    from hypll import nn as hnn


    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = hnn.HConvolution2d(
                in_channels=3, out_channels=6, kernel_size=5, manifold=manifold
            )
            self.pool = hnn.HMaxPool2d(kernel_size=2, manifold=manifold, stride=2)
            self.conv2 = hnn.HConvolution2d(
                in_channels=6, out_channels=16, kernel_size=5, manifold=manifold
            )
            self.fc1 = hnn.HLinear(in_features=16 * 5 * 5, out_features=120, manifold=manifold)
            self.fc2 = hnn.HLinear(in_features=120, out_features=84, manifold=manifold)
            self.fc3 = hnn.HLinear(in_features=84, out_features=10, manifold=manifold)
            self.relu = hnn.HReLU(manifold=manifold)

        def forward(self, x):
            # 32x32 input -> 5x5 conv -> 28x28 -> 2x2 pool -> 14x14
            x = self.pool(self.relu(self.conv1(x)))
            # 14x14 -> 5x5 conv -> 10x10 -> 2x2 pool -> 5x5, hence 16 * 5 * 5 features
            x = self.pool(self.relu(self.conv2(x)))
            x = x.flatten(start_dim=1)
            x = self.relu(self.fc1(x))
            x = self.relu(self.fc2(x))
            x = self.fc3(x)
            return x


    net = Net()

.. GENERATED FROM PYTHON SOURCE LINES 117-122

4. Define a loss function and optimizer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's use a classification cross-entropy loss and the ``RiemannianAdam``
optimizer. Adam is preferred over the SGD with momentum used in the original
tutorial because hyperbolic linear layers can sometimes be difficult to train
early on due to poor initialization.

.. GENERATED FROM PYTHON SOURCE LINES 122-130

.. code-block:: Python

    from hypll.optim import RiemannianAdam

    criterion = nn.CrossEntropyLoss()

    # net.parameters() includes the learnable curvature "c" of the manifold.
    optimizer = RiemannianAdam(net.parameters(), lr=0.001)

.. GENERATED FROM PYTHON SOURCE LINES 131-136

5. Train the network
^^^^^^^^^^^^^^^^^^^^
This is when things start to get interesting. We simply have to loop over our
data iterator, map the inputs onto the manifold, feed them to the network, and
optimize. Because CIFAR10 images are ordinary Euclidean tensors, each batch is
first wrapped in a ``TangentTensor`` and then mapped onto the Poincaré ball
with the exponential map ``manifold.expmap``.

.. GENERATED FROM PYTHON SOURCE LINES 136-167

.. code-block:: Python

    from hypll.tensors import TangentTensor

    for epoch in range(2):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data

            # move the inputs to the manifold; man_dim=1 marks the channel
            # dimension of the (batch, channel, height, width) input
            tangents = TangentTensor(data=inputs, man_dim=1, manifold=manifold)
            manifold_inputs = manifold.expmap(tangents)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize; outputs.tensor is the underlying
            # torch.Tensor of logits that CrossEntropyLoss expects
            outputs = net(manifold_inputs)
            loss = criterion(outputs.tensor, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                running_loss = 0.0

    print("Finished Training")

.. GENERATED FROM PYTHON SOURCE LINES 168-169

Let's quickly save our trained model:

.. GENERATED FROM PYTHON SOURCE LINES 169-174

.. code-block:: Python

    PATH = "./cifar_net.pth"
    torch.save(net.state_dict(), PATH)

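
The training loop above wraps each batch in a ``TangentTensor`` and maps it onto the Poincaré ball with ``manifold.expmap``. The pure-Python sketch below illustrates the standard closed-form exponential map at the origin of a Poincaré ball with curvature :math:`-c`, namely :math:`\exp_0(v) = \tanh(\sqrt{c}\,\lVert v\rVert)\, v / (\sqrt{c}\,\lVert v\rVert)`. It is not ``hypll``'s implementation: it assumes that ``TangentTensor`` places the tangent vectors at the origin when no base point is given, and that with ``man_dim=1`` each pixel's 3-channel vector is mapped independently. ``expmap0`` and ``euclidean_norm`` are hypothetical helper names.

.. code-block:: Python

    import math


    def euclidean_norm(v):
        """Euclidean norm of a vector given as a list of floats."""
        return math.sqrt(sum(x * x for x in v))


    def expmap0(v, c=1.0):
        """Hypothetical helper (not part of hypll): exponential map at the
        origin of the Poincaré ball with curvature -c (c > 0), applied to a
        single tangent vector ``v``.

        The result always lies strictly inside the ball of radius 1 / sqrt(c).
        """
        sqrt_c = math.sqrt(c)
        v_norm = euclidean_norm(v)
        if v_norm == 0.0:
            # The zero tangent vector maps to the origin itself.
            return [0.0 for _ in v]
        # Rescale v (keeping its direction) so that its norm becomes
        # tanh(sqrt(c) * |v|) / sqrt(c), which is always below 1 / sqrt(c).
        scale = math.tanh(sqrt_c * v_norm) / (sqrt_c * v_norm)
        return [scale * x for x in v]


    # One normalized RGB pixel: each channel lies in [-1, 1] after Normalize.
    pixel = [1.0, -1.0, 0.5]
    point = expmap0(pixel, c=1.0)
    print(point, euclidean_norm(point))

With ``c = 1`` the example pixel has Euclidean norm 1.5, so its image has norm ``tanh(1.5)``, roughly 0.905, and stays inside the unit ball. Increasing ``c`` shrinks the ball to radius ``1 / sqrt(c)``, while for small tangent vectors ``tanh(x)`` is close to ``x``, so the map is nearly the identity near the origin.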

.. GENERATED FROM PYTHON SOURCE LINES 175-177

Next, let's load our saved model back in (note: saving and re-loading the model
wasn't necessary here; we only did it to illustrate how to do so):

.. GENERATED FROM PYTHON SOURCE LINES 177-181

.. code-block:: Python

    # Net() reuses the global ``manifold``; because its curvature is one of the
    # network's parameters, load_state_dict also restores the trained curvature.
    net = Net()
    net.load_state_dict(torch.load(PATH))

.. GENERATED FROM PYTHON SOURCE LINES 182-186

6. Test the network on the test data
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Let us look at how the network performs on the whole test set.

.. GENERATED FROM PYTHON SOURCE LINES 186-208

.. code-block:: Python

    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            # move the images to the manifold
            tangents = TangentTensor(data=images, man_dim=1, manifold=manifold)
            manifold_images = manifold.expmap(tangents)
            # calculate outputs by running images through the network
            outputs = net(manifold_images)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.tensor, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")

.. GENERATED FROM PYTHON SOURCE LINES 209-215

If training went well, this should be well above chance, which is 10% accuracy
(randomly picking a class out of 10 classes), meaning the network has learned
something.

Hmm, which classes performed well, and which did not?

.. GENERATED FROM PYTHON SOURCE LINES 215-243

.. code-block:: Python

    # prepare to count predictions for each class
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}

    # again no gradients needed
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            # move the images to the manifold
            tangents = TangentTensor(data=images, man_dim=1, manifold=manifold)
            manifold_images = manifold.expmap(tangents)
            outputs = net(manifold_images)
            _, predictions = torch.max(outputs.tensor, 1)
            # collect the correct predictions for each class
            for label, prediction in zip(labels, predictions):
                if label == prediction:
                    correct_pred[classes[label]] += 1
                total_pred[classes[label]] += 1

    # print accuracy for each class
    for classname, correct_count in correct_pred.items():
        accuracy = 100 * float(correct_count) / total_pred[classname]
        print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %")

.. GENERATED FROM PYTHON SOURCE LINES 244-286

Training on GPU
---------------
Just like how you transfer a Tensor onto the GPU, you transfer the neural
net onto the GPU.

Let's first define our device as the first visible CUDA device if we have
CUDA available:

.. code:: python

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Assuming that we are on a CUDA machine, this should print a CUDA device:

.. code:: python

    print(device)

The rest of this section assumes that ``device`` is a CUDA device.

``net.to(device)`` then recursively goes over all modules and converts their
parameters and buffers to CUDA tensors:

.. code:: python

    net.to(device)

Remember that you will have to send the inputs and targets at every step to
the GPU too. Move them before wrapping them in a ``TangentTensor`` so that the
exponential map runs on the same device as the network's parameters, including
the curvature:

.. code:: python

    inputs, labels = data[0].to(device), data[1].to(device)
    tangents = TangentTensor(data=inputs, man_dim=1, manifold=manifold)
    manifold_inputs = manifold.expmap(tangents)

**Goals achieved**:

- Train a small hyperbolic neural network to classify images.