Extending ICON

Writing a similarity measure

Similarity measures in icon_registration are python functions that take two pytorch tensors representing batches of images. These pytorch tensors have two channels: the first channel is image intensity.

The second channel is 1 if the intensity at that voxel is interpolated, and 0 otherwise, that is, where the warped image had to be extrapolated. For some types of images, it is useful to disregard image similarity in these extrapolated regions of a warped image. For images with a black background, such as skull-stripped brains, this is not necessary.

We will implement a simple absolute difference similarity metric:

import torch

def absolute_similarity(image_A, image_B):
    # Since we are not using the interpolation information, we strip it off before computing similarity.
    image_A, image_B = image_A[:, 0], image_B[:, 0]

    return torch.mean(torch.abs(image_A - image_B))
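For images where extrapolated regions should be ignored, the second channel can serve as a mask. The following is a minimal sketch rather than part of icon_registration: the name masked_absolute_similarity and the eps guard are assumptions made for illustration, and it assumes both inputs carry the interpolation flag in channel 1 as described above.

```python
import torch

def masked_absolute_similarity(image_A, image_B, eps=1e-8):
    # Channel 0 holds intensities; channel 1 is 1 where the value was
    # interpolated and 0 where it was extrapolated.
    intensity_A, valid_A = image_A[:, 0], image_A[:, 1]
    intensity_B, valid_B = image_B[:, 0], image_B[:, 1]

    # Only compare voxels that are valid in both images.
    mask = valid_A * valid_B
    difference = torch.abs(intensity_A - intensity_B) * mask

    # Average over valid voxels only; eps (an assumed guard) avoids
    # division by zero when the mask is empty.
    return difference.sum() / (mask.sum() + eps)
```

Because it has the same two-tensor signature as absolute_similarity, this variant could be passed wherever a similarity measure is expected.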

Writing a RegistrationModule

Representing a Transform

In order to write a registration method, we first have to understand how icon_registration represents a transform. In icon_registration, a transform is any python function that takes as input a tensor of coordinates, and returns those coordinates transformed.

This is intended to closely conform to the mathematical notion of a transform: a function \(\phi: \mathbb{R}^N \rightarrow \mathbb{R}^N\), where \(N\) is the dimension of the images.

This is an extremely flexible definition, since the transform controls how its internal parameters are used to warp coordinates. On the flip side, it means that transforms are responsible for knowing how to interpret their parameters.

The convention in icon_registration is to store the parameters of the transform in the closure of the function; however, an object-oriented approach can also work. We will cover both options in this tutorial.
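To make the two styles concrete, here is a small sketch of the same translation written both ways. The names make_translation and Translation are hypothetical, and the coordinate layout (batch, dimension, number of points) is an assumption chosen for this example, not a statement about the library's internal layout.

```python
import torch

def make_translation(offset):
    # Closure style: `offset` is captured by the inner function.
    def transform(coordinate_tensor):
        return coordinate_tensor + offset
    return transform

class Translation:
    # Object-oriented style: `offset` is stored on the instance.
    def __init__(self, offset):
        self.offset = offset

    def __call__(self, coordinate_tensor):
        return coordinate_tensor + self.offset

# Assumed layout for this sketch: (batch, dimension, number of points).
coordinates = torch.zeros(1, 2, 3)
offset = torch.tensor([0.1, -0.2]).reshape(1, 2, 1)

closure_result = make_translation(offset)(coordinates)
object_result = Translation(offset)(coordinates)
```

Both versions are callables from coordinates to coordinates, so code that consumes a transform does not need to know which style produced it.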

Subclassing RegistrationModule

Now that we know how a transform is represented, we can write a new registration technique by subclassing RegistrationModule: a registration method is simply any such subclass that takes two images as input to its forward method and returns a transform.

The registration methods available in icon_registration are defined in network_wrappers.

They are named like FunctionFromMatrix or FunctionFromVectorField because they wrap an internal pytorch module that produces, for example, a matrix from two images, and turn that matrix into a transform, i.e. a function.
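The wrapping pattern can be sketched without the library. In the example below, FunctionFromTranslation and ConstantOffsetNet are hypothetical names, the wrapper is a plain torch.nn.Module stand-in rather than a real RegistrationModule, and the shapes (a (batch, dimension) offset and (batch, dimension, ...) coordinates) are assumptions for illustration.

```python
import torch

class FunctionFromTranslation(torch.nn.Module):
    # Hypothetical wrapper: turns a per-image-pair offset vector into a transform.
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, image_A, image_B):
        # net is assumed to return a tensor of shape (batch, dimension).
        offset = self.net(image_A, image_B)

        def transform(coordinate_tensor):
            # Reshape the offset to (batch, dimension, 1, ..., 1) so that it
            # broadcasts over every coordinate.
            extra_dims = coordinate_tensor.dim() - offset.dim()
            return coordinate_tensor + offset.reshape(*offset.shape, *([1] * extra_dims))

        return transform

class ConstantOffsetNet(torch.nn.Module):
    # Stand-in for a real network: ignores image content and returns a fixed offset.
    def forward(self, image_A, image_B):
        return torch.tensor([[0.25, -0.5]]).expand(image_A.shape[0], 2)

images = torch.zeros(3, 1, 8, 8)
transform = FunctionFromTranslation(ConstantOffsetNet())(images, images)
warped_coordinates = transform(torch.zeros(3, 2, 8, 8))
```

The wrapper decides how the network output is interpreted, so swapping in a different network only requires that it produce a tensor of the expected shape.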

So! Let us write a toy stationary velocity field (SVF) registration based on Euler integration.

Registration methods in icon_registration are subclasses of RegistrationModule with the following API: the forward method of a registration method takes in two tensors representing batches of images, image_A and image_B, and returns a python function.

This corresponds to a mathematical notion of a registration method as an operator from a pair of images to a function:

\[\Phi: (( \mathbb{R}^N \rightarrow \mathbb{R}) \times (\mathbb{R}^N \rightarrow \mathbb{R})) \rightarrow (\mathbb{R}^N \rightarrow \mathbb{R}^N)\]

When we construct our FunctionFromStationaryVelocityField object, we will pass in a pytorch module net that is responsible for generating the velocity field from the input images. This way, it is easy to experiment with different architectures.

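Let’s make a simple implementation of a stationary velocity field by starting at the identity map and then iterating the following update \(N\) times (here \(N\) counts integration steps, n_steps in the code below, not the image dimension):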

\[x := x + \frac{1}{N} v(x)\]

Here the utility function icon_registration.network_wrappers.RegistrationModule.as_function() is useful for reinterpreting a tensor (in this case a velocity field) as a function \(v\) with the domain of coordinates.
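The update rule can be tried on a scalar toy problem before moving to tensors. This sketch is independent of icon_registration: euler_integrate is a hypothetical helper, and the velocity field v(x) = x is chosen only because its exact flow at time 1 is known to map x to e * x.

```python
import math

def euler_integrate(velocity, x, n_steps=16):
    # Apply x := x + (1 / N) * v(x) a total of N times.
    for _ in range(n_steps):
        x = x + velocity(x) / n_steps
    return x

# Toy velocity field v(x) = x, whose exact flow at time 1 maps x to e * x.
for n_steps in (4, 16, 256):
    approximation = euler_integrate(lambda x: x, 1.0, n_steps)
    print(n_steps, approximation, abs(approximation - math.e))
```

In this toy case the error shrinks as the number of steps grows, which is the trade-off between accuracy and compute that the step count controls.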

Stationary velocity field

import icon_registration as icon

class FunctionFromStationaryVelocityField(icon.RegistrationModule):
    def __init__(self, net, n_steps=16):
        super().__init__()
        self.net = net
        self.n_steps = n_steps

    def forward(self, image_A, image_B):
        # Pre-divide the velocity field by the number of steps, so that each Euler step adds v(x) / N.
        velocityfield_delta = self.net(image_A, image_B) / self.n_steps
        # Reinterpret the tensor as a function that can be evaluated at arbitrary coordinates.
        delta_function = self.as_function(velocityfield_delta)

        def transform(coordinate_tensor):
            for _ in range(self.n_steps):
                coordinate_tensor = coordinate_tensor + delta_function(coordinate_tensor)
            return coordinate_tensor

        return transform

Building a registration network

These components can now be mixed and matched with existing icon_registration components. For example, suppose we want to perform a two-step registration, with coarse affine registration followed by fine registration using our custom stationary velocity field, and we want to use our custom absolute difference similarity measure. We write:

from icon_registration import networks

registration_network = icon.GradientICON(
    icon.TwoStepRegistration(
        icon.FunctionFromMatrix(networks.ConvolutionalMatrixNet(dimension=2)),
        FunctionFromStationaryVelocityField(networks.tallUNet2()),
    ),
    absolute_similarity,
    lmbda=0.4,
)