Source code for intensity_normalization.util.coregister

"""Co-register images with ANTsPy
Author: Jacob Reinhold <jcreinhold@gmail.com>
Created on: 03 Jun 2021
"""

from __future__ import annotations

__all__ = ["register", "Registrator"]

import argparse
import collections.abc
import logging
import typing

import nibabel as nib
import numpy as np

import intensity_normalization as intnorm
import intensity_normalization.base_cli as intnormcli
import intensity_normalization.typing as intnormt

# Module-level logger; register() and Registrator emit debug/info messages on it.
logger = logging.getLogger(__name__)

# ANTsPy is required by everything in this module, so fail at import time with
# an actionable message (the package to install is ``antspyx``).
# NOTE(review): this raises RuntimeError, not ImportError, so a caller wrapping
# the import of this module in ``except ImportError`` will not catch it.
try:
    import ants
except ImportError as ants_imp_exn:
    msg = "ANTsPy not installed. Install antspyx to use co-registration."
    raise RuntimeError(msg) from ants_imp_exn

# Image inputs accepted by to_ants()/register(): nibabel NIfTI-1 images, ANTsPy
# images, or intnormt.ImageLike (declared elsewhere). Note that to_ants() only
# actually converts ImageLike values that are np.ndarray instances.
ValidImage = typing.Union[nib.nifti1.Nifti1Image, ants.ANTsImage, intnormt.ImageLike]


def to_ants(image: ValidImage, /) -> ants.ANTsImage:
    """Return ``image`` as an ``ants.ANTsImage``.

    ANTsImages are returned unchanged (same object). nibabel Nifti1Images go
    through ``ants.from_nibabel``. numpy arrays (including subclasses) go
    through ``ants.from_numpy``; no spacing/origin metadata is supplied.

    Raises:
        ValueError: if ``image`` is none of the supported types.
    """
    if isinstance(image, ants.ANTsImage):
        return image
    if isinstance(image, nib.nifti1.Nifti1Image):
        return ants.from_nibabel(image)
    if isinstance(image, np.ndarray):
        return ants.from_numpy(image)
    raise ValueError(
        "Provided image must be an ANTsImage, Nifti1Image,"
        f" or (a subclass of) np.ndarray. Got '{type(image)}'."
    )


def register(
    image: ValidImage,
    /,
    template: typing.Optional[ValidImage] = None,
    *,
    type_of_transform: str = "Affine",
    interpolator: str = "bSpline",
    metric: str = "mattes",
    initial_rigid: bool = True,
    template_mask: typing.Optional[ValidImage] = None,
) -> nib.nifti1.Nifti1Image | ants.ANTsImage:
    """Register ``image`` to ``template`` with ANTsPy and resample it onto it.

    Args:
        image: Moving image (Nifti1Image, ANTsImage, or numpy array).
        template: Fixed image. If ``None``, the ANTsPy-bundled ``"mni"`` image
            (``ants.get_ants_data("mni")``) is read and used as-is, with no
            reorientation.
        type_of_transform: ANTsPy ``type_of_transform`` for the main
            registration (e.g. ``"Affine"``).
        interpolator: Interpolator passed to ``ants.apply_transforms``.
        metric: Used as both ``aff_metric`` and ``syn_metric``.
        initial_rigid: If True, first run a ``"Rigid"`` registration and use
            its first forward transform as ``initial_transform`` for the main
            registration.
        template_mask: Optional mask passed as ``mask`` to the main
            registration. It is converted with ``to_ants``, like ``template``.

    Returns:
        The transformed image on the template's grid. This is a Nifti1Image if
        ``image`` was a Nifti1Image, otherwise an ANTsImage.

    Raises:
        ValueError: if ``image``, ``template`` or ``template_mask`` has an
            unsupported type (raised by ``to_ants``).
    """
    if template is None:
        standard_mni = ants.get_ants_data("mni")
        template = ants.image_read(standard_mni)
    else:
        template = to_ants(template)
    if template_mask is not None:
        # ants.registration operates on ANTsImages; a raw Nifti1Image/ndarray
        # mask (both allowed by the annotation) must be converted first.
        template_mask = to_ants(template_mask)

    # Remember the input type so the output can be returned in the same form.
    is_nibabel = isinstance(image, nib.nifti1.Nifti1Image)
    image = to_ants(image)

    if initial_rigid:
        logger.debug("Doing initial rigid registration.")
        transforms = ants.registration(
            fixed=template,
            moving=image,
            type_of_transform="Rigid",
            aff_metric=metric,
            syn_metric=metric,
        )
        # Seed the main registration with the rigid result.
        rigid_transform = transforms["fwdtransforms"][0]
    else:
        rigid_transform = None

    logger.debug("Doing %s registration.", type_of_transform)
    transform = ants.registration(
        fixed=template,
        moving=image,
        initial_transform=rigid_transform,
        type_of_transform=type_of_transform,
        mask=template_mask,
        aff_metric=metric,
        syn_metric=metric,
    )["fwdtransforms"]

    logger.debug("Applying transformations.")
    registered = ants.apply_transforms(
        template,
        image,
        transform,
        interpolator=interpolator,
    )
    return registered.to_nibabel() if is_nibabel else registered
class Registrator(intnormcli.SingleImageCLI):
    """Register images to a fixed template with fixed ANTsPy settings.

    Wraps :func:`register` and exposes it as a single-image CLI.
    """

    def __init__(
        self,
        template: nib.nifti1.Nifti1Image | ants.ANTsImage | None = None,
        *,
        type_of_transform: str = "Affine",
        interpolator: str = "bSpline",
        metric: str = "mattes",
        initial_rigid: bool = True,
    ):
        """Store the template and registration settings.

        Args:
            template: Fixed image (Nifti1Image, ANTsImage, or numpy array). If
                ``None``, the ANTsPy-bundled ``"mni"`` image reoriented to RAS
                is used.
            type_of_transform: ANTsPy transform type for the main registration.
            interpolator: Interpolator used when applying the transforms.
            metric: Registration metric (affine and SyN).
            initial_rigid: Run a rigid registration first to initialize.

        Raises:
            ValueError: if ``template`` has an unsupported type.
        """
        super().__init__()
        if template is None:
            logger.info("Using MNI (in RAS orientation) as template.")
            standard_mni = ants.get_ants_data("mni")
            self.template = ants.image_read(standard_mni).reorient_image2("RAS")
        else:
            logger.debug("Loading template.")
            # to_ants passes ANTsImages through unchanged (from_argparse_args
            # supplies one) and converts Nifti1Images/arrays;
            # ants.from_nibabel is meant only for nibabel images.
            self.template = to_ants(template)
        self.type_of_transform = type_of_transform
        self.interpolator = interpolator
        self.metric = metric
        self.initial_rigid = initial_rigid

    def __call__(
        self,
        image: nib.nifti1.Nifti1Image | ants.ANTsImage,
        /,
        *args: typing.Any,
        **kwargs: typing.Any,
    ) -> nib.nifti1.Nifti1Image | ants.ANTsImage:
        """Register ``image`` to ``self.template`` with the stored settings.

        Extra positional/keyword arguments are accepted for interface
        compatibility and ignored.
        """
        return register(
            image,
            template=self.template,
            type_of_transform=self.type_of_transform,
            interpolator=self.interpolator,
            metric=self.metric,
            initial_rigid=self.initial_rigid,
        )

    def register_images(
        self,
        images: collections.abc.Sequence[nib.nifti1.Nifti1Image | ants.ANTsImage],
        /,
    ) -> collections.abc.Sequence[nib.nifti1.Nifti1Image | ants.ANTsImage]:
        """Register every image in ``images`` to ``self.template``; return a list."""
        return [self(image) for image in images]

    def register_images_to_templates(
        self,
        images: collections.abc.Sequence[nib.nifti1.Nifti1Image | ants.ANTsImage],
        /,
        *,
        templates: collections.abc.Sequence[nib.nifti1.Nifti1Image | ants.ANTsImage],
    ) -> collections.abc.Sequence[nib.nifti1.Nifti1Image | ants.ANTsImage]:
        """Register ``images[i]`` to ``templates[i]`` for each ``i``.

        ``self.template`` is restored afterwards, even if a registration fails.

        Raises:
            ValueError: if ``images`` and ``templates`` differ in length.
        """
        if len(images) != len(templates):
            msg = (
                f"Number of images ({len(images)}) must equal "
                f"number of templates ({len(templates)})."
            )
            raise ValueError(msg)
        registered = []
        original_template = self.template
        try:
            for image, template in zip(images, templates):
                # __call__ reads self.template, so swap in the paired template.
                self.template = template
                registered.append(self(image))
        finally:
            # Never leave a per-pair template installed after an exception.
            self.template = original_template
        return registered

    @staticmethod
    def name() -> str:
        """Short name of this operation."""
        return "registered"

    @staticmethod
    def fullname() -> str:
        """Full name of this operation (same as :meth:`name`)."""
        return Registrator.name()

    @staticmethod
    def description() -> str:
        """One-line description used for the CLI."""
        return "Co-register an image to MNI or another image."

    @classmethod
    def get_parent_parser(
        cls,
        desc: str,
        valid_modalities: frozenset[str] = intnorm.VALID_MODALITIES,
        **kwargs: typing.Any,
    ) -> argparse.ArgumentParser:
        """Build the argument parser for the registration CLI.

        ``valid_modalities`` and ``kwargs`` are accepted for signature
        compatibility but are not used here.
        """
        parser = argparse.ArgumentParser(
            description=desc,
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "image",
            type=intnormt.file_path(),
            help="Path of image to register.",
        )
        parser.add_argument(
            "-t",
            "--template",
            type=intnormt.file_path(),
            default=None,
            help="Path of target for registration.",
        )
        parser.add_argument(
            "-o",
            "--output",
            type=intnormt.save_file_path(),
            default=None,
            help="Path to save registered image.",
        )
        parser.add_argument(
            "-tot",
            "--type-of-transform",
            type=str,
            default="Affine",
            choices=intnormt.allowed_transforms,
            help="Type of registration transform to perform.",
            metavar="",  # avoid printing massive list of choices
        )
        parser.add_argument(
            "-i",
            "--interpolator",
            type=str,
            default="bSpline",
            choices=intnormt.allowed_interpolators,
            help="Type of interpolator to use.",
            metavar="",
        )
        parser.add_argument(
            "-mc",
            "--metric",
            type=str,
            default="mattes",
            choices=intnormt.allowed_metrics,
            help="Metric to use for registration loss function.",
            metavar="",
        )
        # NOTE: store_true => False unless given, whereas __init__ defaults
        # initial_rigid to True.
        parser.add_argument(
            "-ir",
            "--initial-rigid",
            action="store_true",
            help="Do a rigid registration before doing "
            "the `type_of_transform` registration.",
        )
        parser.add_argument(
            "-v",
            "--verbosity",
            action="count",
            default=0,
            help="Increase output verbosity (e.g., -vv is more than -v).",
        )
        parser.add_argument(
            "--version",
            action="store_true",
            help="Print the version of intensity-normalization.",
        )
        return parser

    @classmethod
    def from_argparse_args(cls, args: argparse.Namespace) -> Registrator:
        """Construct a Registrator from parsed CLI arguments.

        If a template path was given, it is read with ``ants.image_read``, and
        ``args.template`` is replaced with the loaded image.
        """
        if args.template is not None:
            args.template = ants.image_read(args.template)
        return cls(
            template=args.template,
            type_of_transform=args.type_of_transform,
            interpolator=args.interpolator,
            metric=args.metric,
            initial_rigid=args.initial_rigid,
        )

    @staticmethod
    def load_image(image_path: intnormt.PathLike) -> ants.ANTsImage:
        """Read an image from ``image_path`` with ``ants.image_read``."""
        return ants.image_read(image_path)