Get Started: Register MNIST 5s

The code in this tutorial is available as a Colab notebook.

To learn how to use ICON, we will train a model that registers pictures of the digit 5 from the MNIST dataset.

First, create and activate a virtual environment, and install PyTorch into it following the official instructions.

Next, install icon_registration:

$ pip install icon_registration

Next, create a git repository for your project and import the modules we need from icon_registration:

import icon_registration as icon
import icon_registration.data
import icon_registration.networks as networks
from icon_registration.config import device

import numpy as np
import torch
import torchvision.utils
import matplotlib.pyplot as plt

The function icon_registration.data.get_dataset_mnist() gives us a dataset of MNIST digits filtered to a single number. This is useful for registration experiments, since there is a correct semantic map between two pictures of a 5, but not between a 5 and an 8.

ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=5)

sample_batch = next(iter(ds))[0]
plt.imshow(torchvision.utils.make_grid(sample_batch[:12], nrow=4)[0])

[Figure: a grid of sample training images of the digit 5 (register_fives-3.png)]
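The idea of a single-digit dataset can be sketched with plain PyTorch. This is not how get_dataset_mnist is implemented: filter_to_digit is a hypothetical helper, and random tensors stand in for the real MNIST images so that the sketch runs offline. Like the tutorial's ds, the resulting loader yields batches from which [0] selects the images.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def filter_to_digit(images, labels, number, batch_size=128):
    """Hypothetical helper: keep only images whose label equals `number`."""
    mask = labels == number               # boolean mask over the whole dataset
    subset = TensorDataset(images[mask])  # labels are no longer needed after filtering
    return DataLoader(subset, batch_size=batch_size, shuffle=True)


# Synthetic stand-in for MNIST: 100 random 1x28x28 "images" with labels 0-9.
torch.manual_seed(0)
images = torch.rand(100, 1, 28, 28)
labels = torch.arange(100) % 10           # exactly 10 images per digit

loader = filter_to_digit(images, labels, number=5, batch_size=4)
batch = next(iter(loader))[0]             # same access pattern as next(iter(ds))[0]
print(batch.shape)                        # torch.Size([4, 1, 28, 28])
print(len(loader.dataset))                # 10
```

Because every image in a batch shows the same digit, any two images drawn from the loader form a pair that has a meaningful correspondence.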

Next, we create the network to train. It will be a tree of RegistrationModules, which are the Lego blocks for building registration models in ICON.

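# Coarsest level: a U-Net whose output vector field defines the transform
inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=2))

# Three times, run the current net at half resolution, then refine
# its result at the current resolution with a fresh U-Net
for _ in range(3):
    inner_net = icon.TwoStepRegistration(
        icon.DownsampleRegistration(inner_net, dimension=2),
        icon.FunctionFromVectorField(networks.tallUNet2(dimension=2))
    )

# Wrap the tree in the GradICON loss: LNCC similarity plus a
# gradient inverse-consistency penalty weighted by lmbda
net = icon.GradientICON(inner_net, icon.LNCC(sigma=4), lmbda=.5)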
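To see the shape of the tree that the loop builds, here is a toy model of it. Leaf, Downsample and TwoStep are hypothetical stand-ins for FunctionFromVectorField, DownsampleRegistration and TwoStepRegistration; they perform no registration and only trace which resolution each U-Net would see. The sketch assumes that each downsampling step halves every spatial dimension, and it uses a 32x32 input purely for illustration.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


# Hypothetical stand-ins for the ICON building blocks (structure only).
@dataclass
class Leaf:        # role of FunctionFromVectorField(tallUNet2(...))
    name: str


@dataclass
class Downsample:  # role of DownsampleRegistration
    inner: "Node"


@dataclass
class TwoStep:     # role of TwoStepRegistration
    first: "Node"
    second: "Node"


Node = Union[Leaf, Downsample, TwoStep]


def leaf_resolutions(node: Node, shape: Tuple[int, int]) -> List[Tuple[str, Tuple[int, int]]]:
    """List each leaf network with the image size it would see, in execution order."""
    if isinstance(node, Leaf):
        return [(node.name, shape)]
    if isinstance(node, Downsample):
        # assumption: each downsampling step halves every spatial dimension
        return leaf_resolutions(node.inner, (shape[0] // 2, shape[1] // 2))
    return leaf_resolutions(node.first, shape) + leaf_resolutions(node.second, shape)


# Mirror the construction loop from the tutorial.
tree: Node = Leaf("unet_0")
for i in range(1, 4):
    tree = TwoStep(Downsample(tree), Leaf(f"unet_{i}"))

for name, size in leaf_resolutions(tree, (32, 32)):
    print(name, size)
# unet_0 (4, 4)
# unet_1 (8, 8)
# unet_2 (16, 16)
# unet_3 (32, 32)
```

Under these assumptions, three passes of the loop give four U-Nets arranged coarse to fine, with the original innermost network running at one eighth of the input resolution.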

A fundamental concept in ICON is that each icon_registration.network_wrappers.RegistrationModule has an associated identity_map tensor that contains the coordinates of each pixel. We can set these identity maps correctly for the whole tree as follows:

net.assign_identity_map(sample_batch.shape)

Notably, as this walks the tree, children of icon_registration.network_wrappers.DownsampleRegistration receive identity_map tensors of appropriately lower resolution.
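What an identity map contains can be sketched directly in PyTorch. make_identity_map is a hypothetical helper, not ICON's function. It assumes a channels-first layout of (N, 2, H, W) with coordinates normalized to [0, 1] along each axis, and ICON's exact convention may differ. A child of a downsampling step would need the same kind of map at half the size.

```python
import torch


def make_identity_map(shape):
    """Hypothetical helper: identity map for a batch of 2D images shaped (N, C, H, W).

    Returns an (N, 2, H, W) tensor. Channel 0 holds each pixel's row coordinate
    and channel 1 its column coordinate, both scaled to [0, 1] (an assumption).
    """
    n, _, h, w = shape
    grid_r, grid_c = torch.meshgrid(
        torch.linspace(0, 1, h), torch.linspace(0, 1, w), indexing="ij"
    )
    ident = torch.stack([grid_r, grid_c])             # (2, H, W)
    return ident.unsqueeze(0).expand(n, -1, -1, -1)   # (N, 2, H, W), shared across the batch


full = make_identity_map((8, 1, 32, 32))
half = make_identity_map((8, 1, 16, 16))  # size a child of a downsampling step would need
print(full.shape, half.shape)  # torch.Size([8, 2, 32, 32]) torch.Size([8, 2, 16, 16])
print(full[0, :, 0, 31])       # tensor([0., 1.]): the top-right pixel
```

Storing coordinates rather than pixel indices lets a transform be expressed as a map from coordinates to coordinates. The same transform can then be evaluated at any resolution.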

All that remains is to train the network!

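net.train()
net.to(device)

optim = torch.optim.Adam(net.parameters(), lr=0.001)
# ds is passed twice: once as the source of image_A batches, once for image_B
curves = icon.train_datasets(net, optim, ds, ds, epochs=5)
plt.close()
# plot the first three logged loss terms over training
plt.plot(np.array(curves)[:, :3])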

[Figure: training loss curves (register_fives-6.png)]

This training is cheap to run: it takes about 1 minute on a GPU or 10 minutes on a CPU. In fact, we train the network on Read the Docs' servers to generate the following figure:

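plt.close()

def show(tensor):
    # tile the first six images of a batch into a 2x3 grid and hide the axes
    plt.imshow(torchvision.utils.make_grid(tensor[:6], nrow=3)[0].cpu().detach())
    plt.xticks([])
    plt.yticks([])

# two batches of 5s to register to each other
image_A = next(iter(ds))[0].to(device)
image_B = next(iter(ds))[0].to(device)

# a forward pass stores warped_image_A and phi_AB_vectorfield on net
net(image_A, image_B)

plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(net.warped_image_A)
# overlay contours of both coordinate channels of phi_AB to draw the deformed grid
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach())
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach())
plt.subplot(2, 2, 4)
# residual between the warped moving images and the fixed images
show(net.warped_image_A - image_B)
plt.tight_layout()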

[Figure: image_A, image_B, warped image_A overlaid with the deformation grid, and the difference warped_image_A - image_B (register_fives-7.png)]
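The two plt.contour calls draw the deformation as a warped grid. The sketch below reproduces this on a synthetic transform. It assumes that phi_AB_vectorfield holds the transformed coordinates of every pixel, in the same layout as an identity map, and the sinusoidal transform is hypothetical rather than anything a trained network produced. The level sets of each coordinate channel are curves along which that transformed coordinate is constant, so plotting both channels traces out the bent grid lines.

```python
import math

import matplotlib
matplotlib.use("Agg")  # render off-screen so the sketch runs without a display
import matplotlib.pyplot as plt
import torch

h = w = 32
grid_r, grid_c = torch.meshgrid(
    torch.linspace(0, 1, h), torch.linspace(0, 1, w), indexing="ij"
)

# Hypothetical smooth transform standing in for a learned phi_AB:
# each pixel's coordinates are nudged by a small sinusoidal offset.
phi = torch.stack([
    grid_r + 0.05 * torch.sin(2 * math.pi * grid_c),
    grid_c + 0.05 * torch.sin(2 * math.pi * grid_r),
])  # shape (2, H, W): one coordinate channel per axis

fig, ax = plt.subplots()
# Level sets of each channel are curves of constant transformed coordinate;
# drawing both channels traces out the deformed grid.
cs_r = ax.contour(phi[0].numpy(), levels=10)
cs_c = ax.contour(phi[1].numpy(), levels=10)
ax.set_xticks([])
ax.set_yticks([])
plt.close(fig)
```

With the identity transform the contours would be straight, evenly spaced lines. Any bending therefore shows where and how strongly the registration moves pixels.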