Source code for intensity_normalization.typing

"""Project-specific types
Author: Jacob Reinhold <jcreinhold@gmail.com>
Created on: 01 Jun 2021
"""

from __future__ import annotations

# Public names exported by ``from intensity_normalization.typing import *``.
# NOTE(review): ShapeLike, Float, Int and NewParseType are defined in this
# module but not listed here -- confirm whether they are meant to be public.
__all__ = [
    "allowed_interpolators",
    "allowed_metrics",
    "allowed_orientations",
    "allowed_transforms",
    "ArgType",
    "dir_path",
    "file_path",
    "ImageLike",
    "interp_type_dict",
    "Modality",
    "new_parse_type",
    "nonnegative_float",
    "nonnegative_int",
    "PathLike",
    "positive_float",
    "positive_int",
    "positive_int_or_none",
    "positive_odd_int_or_none",
    "probability_float",
    "probability_float_or_none",
    "save_file_path",
    "SplitFilename",
    "TissueType",
]

import argparse
import collections.abc
import enum
import functools
import os
import pathlib
import typing

import numpy as np
import numpy.typing as npt

import intensity_normalization as intnorm

# Either an already-parsed ``argparse.Namespace``, a list of argument
# strings, or ``None``.
ArgType = typing.Optional[typing.Union[argparse.Namespace, list[str]]]

# A filesystem path given as a plain string or any ``os.PathLike`` object.
PathLike = typing.Union[str, os.PathLike]

# A shape: a single index-like integer, or a sequence of them.
_IndexSequence = collections.abc.Sequence[typing.SupportsIndex]
ShapeLike = typing.Union[typing.SupportsIndex, _IndexSequence]

# (upper-cased name, original value) pairs for every entry of
# ``intnorm.VALID_MODALITIES``, in sorted order -- e.g. ("T1", "t1").
_MODALITIES = [
    (modality_name.upper(), modality_name)
    for modality_name in sorted(intnorm.VALID_MODALITIES)
]


class Modality(enum.Enum):
    """Image modalities; the values mirror ``intnorm.VALID_MODALITIES``."""

    FLAIR: str = "flair"
    MD: str = "md"
    OTHER: str = "other"
    PD: str = "pd"
    T1: str = "t1"
    T2: str = "t2"

    @classmethod
    def from_string(cls, string: str | Modality) -> Modality:
        """Convert ``string`` (or an existing Modality) into a Modality.

        Matching is case-insensitive, consistent with
        ``TissueType.from_string`` (``"T1"`` and ``"t1"`` both give
        ``Modality.T1``). A Modality instance is returned unchanged.

        Raises:
            ValueError: if ``string`` does not name any modality.
        """
        if isinstance(string, cls):
            return string
        if isinstance(string, str):
            key = string.lower()
            # Iterate the enum's own members rather than a parallel table so
            # the lookup can never drift from the member definitions.
            for member in cls:
                if member.value == key:
                    return member
        msg = f"'string' must be one of {intnorm.VALID_MODALITIES}. Got '{string}'."
        raise ValueError(msg)
# The enum above duplicates intnorm.VALID_MODALITIES on purpose: the
# functional Enum API would avoid the repetition but gives worse IDE support
# and trips flake8. Fail fast at import time if the two ever diverge.
_enum_values = {member.value for member in Modality}
if _enum_values != set(intnorm.VALID_MODALITIES):
    raise RuntimeError("Modalities enum out of sync with VALID_MODALITIES.")
del _enum_values
class TissueType(enum.Enum):
    """Tissue classes: cerebrospinal fluid, grey matter and white matter."""

    CSF: str = "csf"
    GM: str = "gm"
    WM: str = "wm"

    @classmethod
    def from_string(cls, string: str | TissueType) -> TissueType:
        """Parse a tissue name into a TissueType.

        Args:
            string: ``"csf"``, ``"gm"`` or ``"wm"`` in any letter case, or an
                existing TissueType (returned unchanged, matching the
                behaviour of ``Modality.from_string``).

        Raises:
            ValueError: if ``string`` does not name a tissue class.
        """
        if isinstance(string, cls):
            return string
        key = string.lower()  # lower-case once instead of per comparison
        for member in cls:
            if member.value == key:
                return member
        raise ValueError(f"'string' must be 'csf', 'gm', or 'wm'. Got '{string}'.")

    def to_int(self) -> int:
        """Return the integer code of this class: CSF=0, GM=1, WM=2."""
        # The mapping covers every member, so no fallback branch is needed.
        return {TissueType.CSF: 0, TissueType.GM: 1, TissueType.WM: 2}[self]

    def to_fullname(self) -> str:
        """Return the human-readable name of this tissue class."""
        return {
            TissueType.CSF: "Cerebrospinal fluid",
            TissueType.GM: "Grey matter",
            TissueType.WM: "White matter",
        }[self]
class SplitFilename(typing.NamedTuple):
    """A filename broken into directory, base name and extension.

    NOTE(review): field meanings below are inferred from the names; confirm
    against the code that constructs these tuples.
    """

    # presumably the directory containing the file
    path: pathlib.Path
    # presumably the file name without its extension
    base: str
    # presumably the file extension
    ext: str
# Interpolation name -> integer code.
# NOTE(review): presumably these match ANTsPy's integer ``interp_type`` codes
# (e.g. for ``resample_image``) -- confirm against the ANTsPy version in use.
interp_type_dict = dict(
    linear=0,
    nearest_neighbor=1,
    gaussian=2,
    windowed_sinc=3,
    bspline=4,
)

# copied from:
# https://github.com/ANTsX/ANTsPy/blob/5b4b8273815b681b0542a3dc8846713e2ebb786e/ants/registration/reorient_image.py
allowed_orientations = frozenset(
    {
        "RIP", "LIP", "RSP", "LSP", "RIA", "LIA", "RSA", "LSA",
        "IRP", "ILP", "SRP", "SLP", "IRA", "ILA", "SRA", "SLA",
        "RPI", "LPI", "RAI", "LAI", "RPS", "LPS", "RAS", "LAS",
        "PRI", "PLI", "ARI", "ALI", "PRS", "PLS", "ARS", "ALS",
        "IPR", "SPR", "IAR", "SAR", "IPL", "SPL", "IAL", "SAL",
        "PIR", "PSR", "AIR", "ASR", "PIL", "PSL", "AIL", "ASL",
    }
)

# copied from:
# https://github.com/ANTsX/ANTsPy/blob/4474f894d184da98a099cd9c852795c384fa3b8f/ants/registration/interface.py
allowed_transforms = frozenset(
    {
        "SyNBold",
        "SyNBoldAff",
        "ElasticSyN",
        "Elastic",
        "SyN",
        "SyNRA",
        "SyNOnly",
        "SyNAggro",
        "SyNCC",
        "TRSAA",
        "SyNabp",
        "SyNLessAggro",
        "TV[1]",
        "TV[2]",
        "TV[3]",
        "TV[4]",
        "TV[5]",
        "TV[6]",
        "TV[7]",
        "TV[8]",
        "TVMSQ",
        "TVMSQC",
        "Rigid",
        "Similarity",
        "Translation",
        "Affine",
        "AffineFast",
        "BOLDAffine",
        "QuickRigid",
        "DenseRigid",
        "BOLDRigid",
        "antsRegistrationSyN[r]",
        "antsRegistrationSyN[t]",
        "antsRegistrationSyN[a]",
        "antsRegistrationSyN[b]",
        "antsRegistrationSyN[s]",
        "antsRegistrationSyN[br]",
        "antsRegistrationSyN[sr]",
        "antsRegistrationSyN[bo]",
        "antsRegistrationSyN[so]",
        "antsRegistrationSyNQuick[r]",
        "antsRegistrationSyNQuick[t]",
        "antsRegistrationSyNQuick[a]",
        "antsRegistrationSyNQuick[b]",
        "antsRegistrationSyNQuick[s]",
        "antsRegistrationSyNQuick[br]",
        "antsRegistrationSyNQuick[sr]",
        "antsRegistrationSyNQuick[bo]",
        "antsRegistrationSyNQuick[so]",
        "antsRegistrationSyNRepro[r]",
        "antsRegistrationSyNRepro[t]",
        "antsRegistrationSyNRepro[a]",
        "antsRegistrationSyNRepro[b]",
        "antsRegistrationSyNRepro[s]",
        "antsRegistrationSyNRepro[br]",
        "antsRegistrationSyNRepro[sr]",
        "antsRegistrationSyNRepro[bo]",
        "antsRegistrationSyNRepro[so]",
        "antsRegistrationSyNQuickRepro[r]",
        "antsRegistrationSyNQuickRepro[t]",
        "antsRegistrationSyNQuickRepro[a]",
        "antsRegistrationSyNQuickRepro[b]",
        "antsRegistrationSyNQuickRepro[s]",
        "antsRegistrationSyNQuickRepro[br]",
        "antsRegistrationSyNQuickRepro[sr]",
        "antsRegistrationSyNQuickRepro[bo]",
        "antsRegistrationSyNQuickRepro[so]",
    }
)

# copied from:
# https://github.com/ANTsX/ANTsPy/blob/4474f894d184da98a099cd9c852795c384fa3b8f/ants/registration/apply_transforms.py
allowed_interpolators = frozenset(
    {
        "linear",
        "nearestNeighbor",
        "multiLabel",
        "gaussian",
        "bSpline",
        "cosineWindowedSinc",
        "welchWindowedSinc",
        "hammingWindowedSinc",
        "lanczosWindowedSinc",
        "genericLabel",
    }
)

# copied from:
# https://github.com/ANTsX/ANTsPy/blob/f2aec7283d26d914d98e2b440e4d2badff78da38/ants/registration/interface.py
allowed_metrics = frozenset(
    {
        "CC",
        "mattes",
        "meansquares",
        "demons",
    }
)


def return_none(
    func: typing.Callable[[typing.Any, typing.Any], typing.Any]
) -> typing.Callable[[typing.Any, typing.Any], typing.Any]:
    """Decorate a ``__call__(self, string)`` parser to accept "no value".

    The wrapped method returns ``None`` without calling ``func`` when
    ``string`` is ``None`` or a string equal to "none"/"null" in any letter
    case; every other value is passed through to ``func``.
    """

    # functools.wraps keeps __name__/__doc__/__wrapped__ of the decorated
    # method so introspection and generated docs still see the original.
    @functools.wraps(func)
    def new_func(self: object, string: typing.Any) -> typing.Any:
        if string is None:
            return None
        if isinstance(string, str) and string.lower() in ("none", "null"):
            return None
        return func(self, string)

    return new_func


class _ParseType:
    """Base class for the argparse ``type=`` callables in this module.

    argparse reads ``__name__`` from a ``type=`` callable when formatting its
    "invalid <type> value" error. Instances have no ``__name__`` of their own,
    so expose the concrete subclass name instead.
    """

    @property
    def __name__(self) -> str:
        name = self.__class__.__name__
        assert isinstance(name, str)  # narrows the type for static checkers
        return name

    def __str__(self) -> str:
        return self.__name__
class save_file_path(_ParseType):
    """argparse ``type=`` for an output path; the path need not exist yet."""

    def __call__(self, string: str) -> pathlib.Path:
        # Only the characters are validated; existence is not checked.
        if string.isprintable():
            return pathlib.Path(string)
        raise argparse.ArgumentTypeError(
            f"'{string}' must only contain printable characters."
        )
class dir_path(_ParseType):
    """argparse ``type=`` for an existing directory; returns its resolved path."""

    def __call__(self, string: str) -> str:
        candidate = pathlib.Path(string)
        if candidate.is_dir():
            return str(candidate.resolve())
        raise argparse.ArgumentTypeError(f"'{string}' is not a valid directory path.")
class file_path(_ParseType):
    """argparse ``type=`` for an existing file; returns the path as given."""

    def __call__(self, string: str) -> str:
        candidate = pathlib.Path(string)
        if candidate.is_file():
            # Not resolved, unlike dir_path -- the caller's spelling is kept.
            return str(candidate)
        raise argparse.ArgumentTypeError(f"'{string}' is not a valid file path.")
class positive_float(_ParseType):
    """argparse ``type=`` accepting only floats strictly greater than zero.

    Raises:
        ValueError: if ``string`` is not a float literal (from ``float()``).
        argparse.ArgumentTypeError: if the value is not > 0, including NaN.
    """

    def __call__(self, string: str) -> float:
        num = float(string)
        # ``not (num > 0.0)`` rather than ``num <= 0.0``: float("nan") parses
        # successfully and every comparison with NaN is False, so the old
        # check let NaN through as "positive".
        if not (num > 0.0):
            msg = f"'{string}' needs to be a positive float."
            raise argparse.ArgumentTypeError(msg)
        return num
class positive_int(_ParseType):
    """argparse ``type=`` accepting only integers strictly greater than zero."""

    def __call__(self, string: str) -> int:
        value = int(string)
        if value > 0:
            return value
        raise argparse.ArgumentTypeError(f"'{string}' needs to be a positive integer.")
class positive_odd_int_or_none(_ParseType):
    """argparse ``type=`` for a positive odd integer, or None.

    ``None`` and the strings "none"/"null" (any case) yield ``None``.
    """

    @return_none
    def __call__(self, string: str) -> int | None:
        value = int(string)
        is_positive_odd = value > 0 and value % 2 == 1
        if is_positive_odd:
            return value
        raise argparse.ArgumentTypeError(
            f"'{string}' needs to be a positive odd integer."
        )
class positive_int_or_none(_ParseType):
    """argparse ``type=`` for a positive integer, or None.

    ``None`` and the strings "none"/"null" (any case) yield ``None``; anything
    else is delegated to ``positive_int``.
    """

    @return_none
    def __call__(self, string: str) -> int | None:
        parse_positive_int = positive_int()
        return parse_positive_int(string)
class nonnegative_int(_ParseType):
    """argparse ``type=`` accepting integers greater than or equal to zero."""

    def __call__(self, string: str) -> int:
        value = int(string)
        if value >= 0:
            return value
        raise argparse.ArgumentTypeError(
            f"'{string}' needs to be a non-negative integer."
        )
class nonnegative_float(_ParseType):
    """argparse ``type=`` accepting floats greater than or equal to zero.

    Raises:
        ValueError: if ``string`` is not a float literal (from ``float()``).
        argparse.ArgumentTypeError: if the value is negative or NaN.
    """

    def __call__(self, string: str) -> float:
        num = float(string)
        # ``not (num >= 0.0)`` rather than ``num < 0.0`` so that NaN, for
        # which every comparison is False, is rejected instead of accepted.
        if not (num >= 0.0):
            msg = f"'{string}' needs to be a non-negative float."
            raise argparse.ArgumentTypeError(msg)
        return num
class probability_float(_ParseType):
    """argparse ``type=`` accepting floats in the closed interval [0, 1].

    Raises:
        ValueError: if ``string`` is not a float literal (from ``float()``).
        argparse.ArgumentTypeError: if the value lies outside [0, 1] or is NaN.
    """

    def __call__(self, string: str) -> float:
        num = float(string)
        # A chained range check is False for NaN, so NaN is rejected; the old
        # ``num < 0.0 or num > 1.0`` test let it through as a probability.
        if not (0.0 <= num <= 1.0):
            msg = f"'{string}' needs to be between 0 and 1."
            raise argparse.ArgumentTypeError(msg)
        return num
class probability_float_or_none(_ParseType):
    """argparse ``type=`` for a float in [0, 1], or None.

    ``None`` and the strings "none"/"null" (any case) yield ``None``; anything
    else is delegated to ``probability_float``.
    """

    @return_none
    def __call__(self, string: str) -> float | None:
        parse_probability = probability_float()
        return parse_probability(string)
class NewParseType:
    """Wrap an arbitrary one-argument callable as a named argparse ``type=``.

    Args:
        func: called with the raw value; its result is returned unchanged.
        name: label shown by ``str()`` and exposed as ``__name__``.
    """

    def __init__(self, func: typing.Callable[[typing.Any], typing.Any], name: str):
        self.name = name
        self.func = func

    @property
    def __name__(self) -> str:
        # argparse formats "invalid <type> value" errors with
        # getattr(type_func, "__name__", repr(type_func)); without this the
        # message would show an object repr instead of ``name`` (same reason
        # _ParseType defines __name__).
        return self.name

    def __str__(self) -> str:
        return self.name

    def __call__(self, val: typing.Any) -> typing.Any:
        return self.func(val)
def new_parse_type(
    func: typing.Callable[[typing.Any], typing.Any], name: str
) -> NewParseType:
    """Build a named argparse ``type=`` callable around ``func``.

    Args:
        func: one-argument callable applied to the raw argument value.
        name: display name for the resulting parser.

    Returns:
        A ``NewParseType`` wrapping ``func`` under ``name``.
    """
    parser = NewParseType(func=func, name=name)
    return parser
# Covariant type parameters of ImageLike. In ImageLike's signatures, T_co
# annotates ``self`` and same-type returns (flatten/reshape/transpose/__iter__),
# U_co is the result of comparison and logical operators (> >= < <= & |), and
# S_co is the result of arithmetic operators (+ - * /).
S_co = typing.TypeVar("S_co", bound="ImageLike", covariant=True)
T_co = typing.TypeVar("T_co", bound="ImageLike", covariant=True)
U_co = typing.TypeVar("U_co", bound="ImageLike", covariant=True)
# Bit-width parameter for numpy scalar types (np.floating[NBit], np.integer[NBit]).
NBit = typing.TypeVar("NBit", bound=npt.NBitBase)
# A numpy scalar of the given precision or the corresponding Python builtin.
Float = typing.Union[np.floating[NBit], float]
Int = typing.Union[np.integer[NBit], int]
class ImageLike(typing.Protocol[S_co, T_co, U_co]):
    """Structural (duck-typed) interface for array-like images.

    Any object implementing the members below satisfies this protocol; no
    inheritance is required. Type parameters:

    * ``T_co`` -- the image type itself (``self`` and same-type returns).
    * ``U_co`` -- result of comparison and logical operators.
    * ``S_co`` -- result of arithmetic operators.

    NOTE(review): presumably numpy arrays and similar image containers are the
    intended implementers -- confirm with a type checker.
    """

    # Comparison and logical operators; results typed as U_co.
    def __gt__(self: T_co, other: typing.Any) -> U_co:
        ...

    def __ge__(self: T_co, other: typing.Any) -> U_co:
        ...

    def __lt__(self: T_co, other: typing.Any) -> U_co:
        ...

    def __le__(self: T_co, other: typing.Any) -> U_co:
        ...

    def __and__(self: T_co, other: typing.Any) -> U_co:
        ...

    def __or__(self: T_co, other: typing.Any) -> U_co:
        ...

    # Arithmetic operators; results typed as S_co.
    def __add__(self: T_co, other: typing.Any) -> S_co:
        ...

    def __sub__(self: T_co, other: typing.Any) -> S_co:
        ...

    def __mul__(self: T_co, other: typing.Any) -> S_co:
        ...

    def __truediv__(self: T_co, other: typing.Any) -> S_co:
        ...

    def __getitem__(self: T_co, item: typing.Any) -> typing.Any:
        ...

    # NOTE(review): annotated to return the image type itself rather than an
    # Iterator -- confirm this is intended.
    def __iter__(self: T_co) -> T_co:
        ...

    # numpy's array-conversion hook (used by np.asarray and friends).
    def __array__(self) -> npt.NDArray:
        ...

    def sum(self) -> Float | Int:
        ...

    @property
    def ndim(self) -> Int:
        ...

    def any(
        self,
        axis: int | tuple[int, ...] | None = None,
    ) -> typing.Any:
        ...

    def nonzero(self) -> typing.Any:
        ...

    def squeeze(self) -> typing.Any:
        ...

    @property
    def shape(self) -> tuple[int, ...]:
        ...

    # Scalar reductions.
    def mean(self) -> float:
        ...

    def std(self) -> float:
        ...

    def min(self) -> float:
        ...

    # Shape manipulation; each returns the same image type (T_co).
    def flatten(self: T_co) -> T_co:
        ...

    def reshape(
        self: T_co,
        *shape: typing.SupportsIndex,
        order: typing.Literal["A", "C", "F"] | None = ...,
    ) -> T_co:
        ...

    def transpose(self: T_co, *axes: int) -> T_co:
        ...