Source code for intensity_normalization.normalize.lsq

"""Least-squares fit tissue means of a set of images."""

from __future__ import annotations

__all__ = ["LeastSquaresNormalize"]

import argparse
import collections.abc
import logging
import pathlib
import typing

import numpy as np
import numpy.typing as npt
import pymedio.image as mioi

import intensity_normalization as intnorm
import intensity_normalization.errors as intnorme
import intensity_normalization.normalize.base as intnormb
import intensity_normalization.typing as intnormt
import intensity_normalization.util.io as intnormio
import intensity_normalization.util.tissue_membership as intnormtm

logger = logging.getLogger(__name__)

S = typing.TypeVar("S", bound=intnormt.ImageLike)


class LeastSquaresNormalize(
    intnormb.LocationScaleCLIMixin,
    intnormb.DirectoryNormalizeCLI,
):
    """Normalize images by least-squares fitting of their tissue means."""

    def __init__(self, *, norm_value: float = 1.0, **kwargs: typing.Any):
        """Set up a least-squares tissue-mean normalizer.

        Args:
            norm_value: Reference value forwarded to the parent initializer;
                it is also used when the standard tissue means are fit.
            **kwargs: Forwarded unchanged to the parent initializer.
        """
        super().__init__(norm_value=norm_value, **kwargs)
        # Populated by a fit, or by loading previously saved values.
        self.standard_tissue_means: npt.NDArray | None = None
        # Tissue memberships computed while normalizing, kept for optional saving.
        self.tissue_memberships: list[mioi.Image] = []
[docs] def calculate_location( self, image: intnormt.ImageLike, /, mask: intnormt.ImageLike | None = None, *, modality: intnormt.Modality = intnormt.Modality.T1, ) -> float: return 0.0
[docs] def calculate_scale( self, image: intnormt.ImageLike, /, mask: intnormt.ImageLike | None = None, *, modality: intnormt.Modality = intnormt.Modality.T1, ) -> float: tissue_membership: intnormt.ImageLike if modality == intnormt.Modality.T1: tissue_membership = intnormtm.find_tissue_memberships(image, mask) self.tissue_memberships.append(tissue_membership) elif mask is not None: tissue_membership = self._fix_tissue_membership(image, mask) else: msg = "If 'modality' != 't1', you must provide a " msg += "tissue membership array in the mask argument." raise ValueError(msg) tissue_means = self.tissue_means(image, tissue_membership) sf = self.scaling_factor(tissue_means) return sf
def _fit( self, images: collections.abc.Sequence[intnormt.ImageLike], /, masks: collections.abc.Sequence[intnormt.ImageLike] | None = None, *, modality: intnormt.Modality = intnormt.Modality.T1, **kwargs: typing.Any, ) -> None: image = images[0] # only need one image to fit this method mask = masks[0] if masks is not None else None tissue_membership: intnormt.ImageLike if not isinstance(mask, np.ndarray) and mask is not None: raise ValueError("Mask must be either none or be like a numpy array.") if modality == intnormt.Modality.T1: tissue_membership = intnormtm.find_tissue_memberships(image, mask) elif mask is not None: logger.debug("Assuming 'masks' contains tissue memberships.") tissue_membership = self._fix_tissue_membership(image, mask) else: msg = "If 'modality' != 't1', you must provide a " msg += "tissue membership array in the mask argument." raise ValueError(msg) csf_mean = np.average(image, weights=tissue_membership[..., 0]) norm_image: intnormt.ImageLike = (image / csf_mean) * self.norm_value self.standard_tissue_means = self.tissue_means( norm_image, tissue_membership, ) def _fix_tissue_membership( self, image: intnormt.ImageLike, tissue_membership: S ) -> S: image_ndim = int(image.ndim) tm_ndim = int(tissue_membership.ndim) if tissue_membership.shape[:image_ndim] != image.shape and tm_ndim == 4: # try to swap last axes b/c sitk, if still doesn't match then fail tissue_membership = tissue_membership.transpose(3, 0, 1, 2) if tissue_membership.shape[:image_ndim] != image.shape: msg = "If masks provided, need to have same spatial shape as image." raise intnorme.NormalizationError(msg) return tissue_membership
[docs] @staticmethod def tissue_means( image: intnormt.ImageLike, /, tissue_membership: intnormt.ImageLike ) -> npt.NDArray: n_tissues = tissue_membership.shape[-1] weighted_avgs = [ np.average(image, weights=tissue_membership[..., i]) for i in range(n_tissues) ] return np.asarray([weighted_avgs]).T
[docs] def scaling_factor(self, tissue_means: npt.NDArray) -> float: numerator = tissue_means.T @ tissue_means denominator = tissue_means.T @ self.standard_tissue_means sf: float = (numerator / denominator).item() return sf
[docs] @staticmethod def name() -> str: return "lsq"
[docs] @staticmethod def fullname() -> str: return "Least Squares"
[docs] @staticmethod def description() -> str: desc = "Minimize distance between tissue means (CSF/GM/WM) in a " desc += "least squares-sense within a set of MR images." return desc
[docs] def save_additional_info( self, args: argparse.Namespace, **kwargs: typing.Any, ) -> None: normed = kwargs["normalized"] image_fns = kwargs["image_filenames"] if not self.tissue_memberships: logger.debug("'tissue_memberships' empty. Skipping saving.") return if len(self.tissue_memberships) != len(image_fns): msg = f"'tissue_memberships' ({len(self.tissue_memberships)}) " msg += f"and 'image_filenames' ({len(image_fns)}) " msg += "must be in correspondence." raise RuntimeError(msg) if len(self.tissue_memberships) != len(normed): msg = f"'tissue_memberships' ({len(self.tissue_memberships)}) " msg += f"and 'normalized' ({len(normed)}) " msg += "must be in correspondence." raise RuntimeError(msg) for memberships, norm, fn in zip(self.tissue_memberships, normed, image_fns): if hasattr(norm, "affine"): tissue_memberships: mioi.Image = mioi.Image(memberships, norm.affine) elif hasattr(memberships, "affine"): tissue_memberships = mioi.Image(memberships, memberships.affine) else: tissue_memberships = mioi.Image(memberships, None) base, name, ext = intnormio.split_filename(fn) new_name = name + "_tissue_memberships" + ext if args.output_dir is None: output = base / new_name else: output = pathlib.Path(args.output_dir) / new_name tissue_memberships.to_filename(output) del self.tissue_memberships if args.save_standard_tissue_means is not None: self.save_standard_tissue_means(args.save_standard_tissue_means)
[docs] def save_standard_tissue_means(self, filename: intnormt.PathLike, /) -> None: if self.standard_tissue_means is None: msg = "Fit required before saving standard tissue means." raise intnorme.NormalizationError(msg) np.save(filename, self.standard_tissue_means)
[docs] def load_standard_tissue_means(self, filename: intnormt.PathLike, /) -> None: data = np.load(filename) self.standard_tissue_means = data
[docs] @classmethod def from_argparse_args(cls, args: argparse.Namespace, /) -> LeastSquaresNormalize: out = cls(norm_value=args.norm_value) return out
[docs] def call_from_argparse_args( self, args: argparse.Namespace, /, **kwargs: typing.Any ) -> None: if args.load_standard_tissue_means is not None: self.load_standard_tissue_means(args.load_standard_tissue_means) self.fit = lambda *args, **kwargs: None # type: ignore[method-assign] args.modality = intnormt.Modality.from_string(args.modality) use_masks = True if args.mask_dir is not None: if args.modality != intnormt.Modality.T1: msg = f"If brain masks provided, 'modality' must be 't1'. Got '{args.modality}'." # noqa: E501 raise ValueError(msg) elif args.tissue_membership_dir is not None: use_masks = False args.mask_dir = args.tissue_membership_dir super().call_from_argparse_args(args, use_masks_in_plot=use_masks)
[docs] @classmethod def get_parent_parser( cls, desc: str, valid_modalities: frozenset[str] = intnorm.VALID_MODALITIES, **kwargs: typing.Any, ) -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description=desc, formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "image_dir", type=intnormt.dir_path(), help="Path of directory containing images to normalize.", ) parser.add_argument( "-o", "--output-dir", type=intnormt.dir_path(), default=None, help="Path of directory in which to save normalized images.", ) parser.add_argument( "-mo", "--modality", type=str, default="t1", choices=intnorm.VALID_MODALITIES, help="Modality of the images.", ) parser.add_argument( "-n", "--norm-value", type=intnormt.positive_float(), default=1.0, help="Reference value for normalization.", ) parser.add_argument( "-e", "--extension", type=str, default="nii*", help="Extension of images (must be nibabel readable).", ) parser.add_argument( "-p", "--plot-histogram", action="store_true", help="Plot the histogram of the normalized image.", ) parser.add_argument( "-v", "--verbosity", action="count", default=0, help="Increase output verbosity (e.g., -vv is more than -v).", ) parser.add_argument( "--version", action="store_true", help="Print the version of intensity-normalization.", ) return parser
[docs] @staticmethod def add_method_specific_arguments( parent_parser: argparse.ArgumentParser, ) -> argparse.ArgumentParser: parser = parent_parser.add_argument_group("method-specific arguments") parser.add_argument( "-sstm", "--save-standard-tissue-means", default=None, type=intnormt.save_file_path(), help="Save the standard tissue means fit by the method.", ) parser.add_argument( "-lstm", "--load-standard-tissue-means", default=None, type=intnormt.file_path(), help="Load a standard tissue means previously fit by the method.", ) exclusive = parent_parser.add_argument_group( "mutually exclusive optional arguments" ) group = exclusive.add_mutually_exclusive_group(required=False) group.add_argument( "-m", "--mask-dir", type=intnormt.dir_path(), default=None, help="Path to a foreground mask for the image. " "Provide this if not providing a tissue mask " "(if image is not skull-stripped).", ) group.add_argument( "-tm", "--tissue-membership-dir", type=intnormt.dir_path(), help="Path to a mask of a tissue memberships. " "Provide this if not providing the foreground mask.", ) return parent_parser