# Source code for niworkflows.interfaces.norm

# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2021 The NiPreps Developers <nipreps@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
#     https://www.nipreps.org/community/licensing/
#
"""A robust ANTs T1-to-MNI registration workflow with fallback retry."""
from os import path as op

from multiprocessing import cpu_count
from packaging.version import Version
import numpy as np

from nipype.interfaces.ants.registration import RegistrationOutputSpec
from nipype.interfaces.ants import AffineInitializer
from nipype.interfaces.base import (
    traits,
    isdefined,
    BaseInterface,
    BaseInterfaceInputSpec,
    File,
)

from templateflow.api import get as get_template
from .. import NIWORKFLOWS_LOG, __version__
from ..data import load as load_data
from .fixes import FixHeaderRegistration as Registration


# Parsed version of this package; the input spec below exposes it as
# ``package_version`` (its own comment says this is to "enable deprecation").
niworkflows_version = Version(__version__)


class _SpatialNormalizationInputSpec(BaseInterfaceInputSpec):
    # Inputs accepted by the SpatialNormalization interface.

    # Package version attached to the spec (enables trait deprecation).
    package_version = niworkflows_version

    # --- Images and masks ---------------------------------------------------
    # The image that will be registered (required).
    moving_image = File(
        exists=True,
        mandatory=True,
        desc="image to apply transformation to",
    )
    # Optional registration target replacing the default template.
    reference_image = File(
        exists=True,
        desc="override the reference image",
    )
    # Optional mask of the moving image.
    moving_mask = File(
        exists=True,
        desc="moving image mask",
    )
    # Optional mask of the reference image.
    reference_mask = File(
        exists=True,
        desc="reference image mask",
    )
    # Optional lesion mask.
    lesion_mask = File(
        exists=True,
        desc="lesion mask image",
    )

    # --- Execution settings -------------------------------------------------
    # Thread count for ANTs/ITK; defaults to the CPU count evaluated when the
    # class body runs, and is excluded from hashing (nohash=True).
    num_threads = traits.Int(
        cpu_count(),
        usedefault=True,
        nohash=True,
        desc="Number of ITK threads to use",
    )
    # Which ANTs parameter set to load.
    flavor = traits.Enum(
        "precise",
        "testing",
        "fast",
        usedefault=True,
        desc="registration settings parameter set",
    )
    # Orientation of the template.
    orientation = traits.Enum(
        "RAS",
        "LAS",
        mandatory=True,
        usedefault=True,
        desc="modify template orientation (should match input image)",
    )
    # Modality of the reference image.
    reference = traits.Enum(
        "T1w",
        "T2w",
        "boldref",
        "PDw",
        mandatory=True,
        usedefault=True,
        desc="set the reference modality for registration",
    )
    # Kind of moving image: T1 or EPI registration.
    moving = traits.Enum(
        "T1w",
        "boldref",
        usedefault=True,
        mandatory=True,
        desc="registration type",
    )

    # --- Template selection -------------------------------------------------
    # Name of the template used as the default reference.
    template = traits.Str(
        "MNI152NLin2009cAsym",
        usedefault=True,
        desc="define the template to be used",
    )
    # Explicit settings files that take the place of the packaged defaults.
    settings = traits.List(
        File(exists=True),
        desc="pass on the list of settings files",
    )
    # Specifications (including resolution) of the default template.
    template_spec = traits.DictStrAny(
        desc="template specifications",
    )
    template_resolution = traits.Enum(
        1,
        2,
        None,
        desc="(DEPRECATED) template resolution",
    )

    # --- Masking and initialization -----------------------------------------
    # Whether voxels outside the masks are zeroed before registration.
    explicit_masking = traits.Bool(
        True,
        usedefault=True,
        desc=(
            "Set voxels outside the masks to zero thus creating an artificial border\n"
            "that can drive the registration. Requires reliable and accurate masks.\n"
            "See https://sourceforge.net/p/advants/discussion/840261/thread/27216e69/#c7ba"
        ),
    )
    initial_moving_transform = File(
        exists=True,
        desc="transform for initialization",
    )
    float = traits.Bool(
        False,
        usedefault=True,
        desc="use single precision calculations",
    )


class _SpatialNormalizationOutputSpec(RegistrationOutputSpec):
    # Adds the registration target to the inherited registration outputs.
    reference_image = File(
        desc="reference image used for registration target",
        exists=True,
    )


class SpatialNormalization(BaseInterface):
    """
    An interface to robustly run T1-to-MNI spatial normalization.

    Several settings are sequentially tried until some work.
    """

    input_spec = _SpatialNormalizationInputSpec
    output_spec = _SpatialNormalizationOutputSpec

    def _list_outputs(self):
        """
        Return the outputs of the last registration attempt.

        The outputs of the wrapped registration interface (``self.norm``) are
        extended with ``reference_image``, the fixed image selected by
        :meth:`_get_ants_args`.
        """
        outputs = self.norm._list_outputs()
        outputs["reference_image"] = self._reference_image
        return outputs

    def __init__(self, **inputs):
        """Initialize per-instance state, then defer to the base interface."""
        # Registration interface of the most recent attempt (set while running).
        self.norm = None
        # Path of the fixed image, resolved by ``_get_ants_args``.
        self._reference_image = None
        # 1-based counter of the registration attempt being run.
        self.retry = 1
        # Terminal output mode handed to every registration attempt.
        self.terminal_output = "file"
        super().__init__(**inputs)

    def _get_settings(self):
        """
        Return any settings defined by the user, as well as any pre-defined
        settings files that exist for the image modalities to be registered.

        Returns
        -------
        list
            ``inputs.settings`` when defined; otherwise the sorted paths of the
            packaged ``<moving>-mni_registration_<flavor>_*.json`` files.
        """
        # User-defined settings take precedence over the packaged ones.
        if isdefined(self.inputs.settings):
            NIWORKFLOWS_LOG.info("User-defined settings, overriding defaults")
            return self.inputs.settings

        # Filename prefix built from the moving-image modality and the flavor.
        filestart = "{}-mni_registration_{}_".format(
            self.inputs.moving.lower(), self.inputs.flavor
        )

        data_dir = load_data()
        # Collect the settings files matching the prefix.
        filenames = [
            path.name
            for path in data_dir.iterdir()
            if path.name.startswith(filestart) and path.name.endswith(".json")
        ]
        # Sorting fixes the order in which the settings are attempted.
        return [str(data_dir / f) for f in sorted(filenames)]

    def _run_interface(self, runtime):
        """
        Try each settings file in turn until one registration succeeds.

        When no ``initial_moving_transform`` is given, one is estimated first
        with ``AffineInitializer``. The command line of every attempt is
        written to ``command.txt``, and the terminal outputs of failed
        attempts are saved through ``_write_outputs``.

        Returns
        -------
        runtime
            The ``runtime`` passed in, with ``returncode`` set to 0 after the
            first successful attempt.

        Raises
        ------
        RuntimeError
            If every settings file fails (or none is available).
        """
        # Restart the attempt counter so that running the same instance again
        # reports correct retry numbers and failure counts.
        self.retry = 1

        settings_files = self._get_settings()
        ants_args = self._get_ants_args()

        if not isdefined(self.inputs.initial_moving_transform):
            NIWORKFLOWS_LOG.info("Estimating initial transform using AffineInitializer")
            init = AffineInitializer(
                fixed_image=ants_args["fixed_image"],
                moving_image=ants_args["moving_image"],
                num_threads=self.inputs.num_threads,
            )
            init.resource_monitor = False
            init.terminal_output = "allatonce"
            init_result = init.run()
            # Save outputs (if available)
            init_out = _write_outputs(init_result.runtime, ".nipype-init")
            if init_out:
                NIWORKFLOWS_LOG.info(
                    "Terminal outputs of initialization saved (%s).",
                    ", ".join(init_out),
                )

            ants_args["initial_moving_transform"] = init_result.outputs.out_file

        # For each settings file...
        for ants_settings in settings_files:
            NIWORKFLOWS_LOG.info("Loading settings from file %s.", ants_settings)
            # Configure an ANTs run based on these settings.
            self.norm = Registration(from_file=ants_settings, **ants_args)
            self.norm.resource_monitor = False
            self.norm.terminal_output = self.terminal_output

            cmd = self.norm.cmdline
            # Print the retry number and command line call to the log.
            NIWORKFLOWS_LOG.info("Retry #%d, commandline: \n%s", self.retry, cmd)
            # NOTE(review): presumably makes a failed attempt surface through
            # the return code checked below rather than an exception -- confirm
            # against the nipype version in use.
            self.norm.ignore_exception = True
            with open("command.txt", "w") as cmdfile:
                print(cmd + "\n", file=cmdfile)

            # Try running registration.
            interface_result = self.norm.run()

            if interface_result.runtime.returncode != 0:
                NIWORKFLOWS_LOG.warning("Retry #%d failed.", self.retry)
                # Save outputs (if available)
                term_out = _write_outputs(
                    interface_result.runtime, ".nipype-%04d" % self.retry
                )
                if term_out:
                    NIWORKFLOWS_LOG.warning(
                        "Log of failed retry saved (%s).", ", ".join(term_out)
                    )
            else:
                runtime.returncode = 0
                # Note this in the log.
                NIWORKFLOWS_LOG.info(
                    "Successful spatial normalization (retry #%d).", self.retry
                )
                # Break out of the retry loop.
                return runtime

            self.retry += 1

        # If all tries fail, raise an error.
        raise RuntimeError(
            "Robust spatial normalization failed after %d retries."
            % (self.retry - 1)
        )

    def _get_ants_args(self):
        """
        Build the keyword arguments for the registration interface.

        Selects (and, with explicit masking, masks) the moving and fixed
        images, and selects the masks to be used during registration. As a
        side effect, ``self._reference_image`` is set to the fixed image, and
        masked images / cost-function masks may be written through ``mask``
        and ``create_cfm``.

        Returns
        -------
        dict
            Keyword arguments for ``Registration``.

        Raises
        ------
        NotImplementedError
            If no reference image is given and ``orientation`` is ``"LAS"``.
        ValueError
            If the resolved template reference is not an existing file.
        """
        args = {
            "moving_image": self.inputs.moving_image,
            "num_threads": self.inputs.num_threads,
            "float": self.inputs.float,
            "terminal_output": "file",
            "write_composite_transform": True,
            "initial_moving_transform": self.inputs.initial_moving_transform,
        }

        # Moving image handling - The following truth table maps out the intended action
        # sequence. Future refactoring may more directly encode this.
        #
        # moving_mask and lesion_mask are files
        # True = file
        # False = None
        #
        # | moving_mask | explicit_masking | lesion_mask | action
        # |-------------|------------------|-------------|-------------------------------------------
        # | True        | True             | True        | Update `moving_image` after applying
        # |             |                  |             | mask.
        # |             |                  |             | Set `moving_image_masks` applying
        # |             |                  |             | `create_cfm` with `global_mask=True`.
        # |-------------|------------------|-------------|-------------------------------------------
        # | True        | True             | False       | Update `moving_image` after applying
        # |             |                  |             | mask.
        # |-------------|------------------|-------------|-------------------------------------------
        # | True        | False            | True        | Set `moving_image_masks` applying
        # |             |                  |             | `create_cfm` with `global_mask=False`
        # |-------------|------------------|-------------|-------------------------------------------
        # | True        | False            | False       | args['moving_image_masks'] = moving_mask
        # |-------------|------------------|-------------|-------------------------------------------
        # | False       | *                | True        | Set `moving_image_masks` applying
        # |             |                  |             | `create_cfm` with `global_mask=True`
        # |-------------|------------------|-------------|-------------------------------------------
        # | False       | *                | False       | No action

        # If a moving mask is provided...
        if isdefined(self.inputs.moving_mask):
            # If explicit masking is enabled...
            if self.inputs.explicit_masking:
                # Mask the moving image.
                # Do not use a moving mask during registration.
                args["moving_image"] = mask(
                    self.inputs.moving_image,
                    self.inputs.moving_mask,
                    "moving_masked.nii.gz",
                )

            # If explicit masking is disabled...
            else:
                # Use the moving mask during registration.
                # Do not mask the moving image.
                args["moving_image_masks"] = self.inputs.moving_mask

            # If a lesion mask is also provided...
            if isdefined(self.inputs.lesion_mask):
                # Create a cost function mask with the form:
                # [global mask - lesion mask] (if explicit masking is enabled)
                # [moving mask - lesion mask] (if explicit masking is disabled)
                # Use this as the moving mask.
                args["moving_image_masks"] = create_cfm(
                    self.inputs.moving_mask,
                    lesion_mask=self.inputs.lesion_mask,
                    global_mask=self.inputs.explicit_masking,
                )

        # If no moving mask is provided...
        # But a lesion mask *IS* provided...
        elif isdefined(self.inputs.lesion_mask):
            # Create a cost function mask with the form: [global mask - lesion mask]
            # Use this as the moving mask.
            args["moving_image_masks"] = create_cfm(
                self.inputs.moving_image,
                lesion_mask=self.inputs.lesion_mask,
                global_mask=True,
            )

        # Reference image handling - The following truth table maps out the intended action
        # sequence. Future refactoring may more directly encode this.
        #
        # reference_mask and lesion_mask are files
        # True = file
        # False = None
        #
        # | reference_mask | explicit_masking | lesion_mask | action
        # |----------------|------------------|-------------|----------------------------------------
        # | True           | True             | True        | Update `fixed_image` after applying
        # |                |                  |             | mask.
        # |                |                  |             | Set `fixed_image_masks` applying
        # |                |                  |             | `create_cfm` with `global_mask=True`.
        # |----------------|------------------|-------------|----------------------------------------
        # | True           | True             | False       | Update `fixed_image` after applying
        # |                |                  |             | mask.
        # |----------------|------------------|-------------|----------------------------------------
        # | True           | False            | True        | Set `fixed_image_masks` applying
        # |                |                  |             | `create_cfm` with `global_mask=False`
        # |----------------|------------------|-------------|----------------------------------------
        # | True           | False            | False       | args['fixed_image_masks'] = fixed_mask
        # |----------------|------------------|-------------|----------------------------------------
        # | False          | *                | True        | Set `fixed_image_masks` applying
        # |                |                  |             | `create_cfm` with `global_mask=True`
        # |----------------|------------------|-------------|----------------------------------------
        # | False          | *                | False       | No action

        # If a reference image is provided...
        if isdefined(self.inputs.reference_image):
            # Use the reference image as the fixed image.
            args["fixed_image"] = self.inputs.reference_image
            self._reference_image = self.inputs.reference_image

            # If a reference mask is provided...
            if isdefined(self.inputs.reference_mask):
                # If explicit masking is enabled...
                if self.inputs.explicit_masking:
                    # Mask the reference image.
                    # Do not use a fixed mask during registration.
                    args["fixed_image"] = mask(
                        self.inputs.reference_image,
                        self.inputs.reference_mask,
                        "fixed_masked.nii.gz",
                    )

                    # If a lesion mask is also provided...
                    if isdefined(self.inputs.lesion_mask):
                        # Create a cost function mask with the form: [global mask]
                        # Use this as the fixed mask.
                        args["fixed_image_masks"] = create_cfm(
                            self.inputs.reference_mask,
                            lesion_mask=None,
                            global_mask=True,
                        )

                # If a reference mask is provided...
                # But explicit masking is disabled...
                else:
                    # Use the reference mask as the fixed mask during registration.
                    # Do not mask the fixed image.
                    args["fixed_image_masks"] = self.inputs.reference_mask

            # If no reference mask is provided...
            # But a lesion mask *IS* provided ...
            elif isdefined(self.inputs.lesion_mask):
                # Create a cost function mask with the form: [global mask]
                # Use this as the fixed mask
                args["fixed_image_masks"] = create_cfm(
                    self.inputs.reference_image, lesion_mask=None, global_mask=True
                )

        # If no reference image is provided, fall back to the default template.
        else:
            from ..utils.misc import get_template_specs

            # Raise an error if the user specifies an unsupported image orientation.
            if self.inputs.orientation == "LAS":
                raise NotImplementedError

            # Work on a copy: keys are added below, and the caller-supplied
            # ``inputs.template_spec`` must not be modified as a side effect.
            template_spec = (
                dict(self.inputs.template_spec)
                if isdefined(self.inputs.template_spec)
                else {}
            )

            default_resolution = {"precise": 1, "fast": 2, "testing": 2}[
                self.inputs.flavor
            ]

            # Set the template resolution.
            if isdefined(self.inputs.template_resolution):
                NIWORKFLOWS_LOG.warning(
                    "The use of ``template_resolution`` is deprecated"
                )
                template_spec["res"] = self.inputs.template_resolution

            template_spec["suffix"] = self.inputs.reference
            template_spec["desc"] = None
            ref_template, template_spec = get_template_specs(
                self.inputs.template,
                template_spec=template_spec,
                default_resolution=default_resolution,
                fallback=True,
            )

            # Set reference image
            self._reference_image = ref_template
            if not op.isfile(self._reference_image):
                raise ValueError(
                    """\
The registration reference must be an existing file, but path "%s" \
cannot be found."""
                    % ref_template
                )

            # Get the brain mask of the template; the ``label="brain"`` query
            # is only issued when the ``desc="brain"`` one returns a falsy value.
            ref_mask = get_template(
                self.inputs.template, desc="brain", suffix="mask", **template_spec
            ) or get_template(
                self.inputs.template, label="brain", suffix="mask", **template_spec
            )

            # Default is explicit masking disabled
            args["fixed_image"] = ref_template
            # Use the template mask as the fixed mask.
            args["fixed_image_masks"] = str(ref_mask)

            # Overwrite defaults if explicit masking
            if self.inputs.explicit_masking:
                # Mask the template image with the template mask.
                args["fixed_image"] = mask(
                    ref_template, str(ref_mask), "fixed_masked.nii.gz"
                )
                # Do not use a fixed mask during registration.
                args.pop("fixed_image_masks", None)

                # If a lesion mask is provided...
                if isdefined(self.inputs.lesion_mask):
                    # Create a cost function mask with the form: [global mask]
                    # Use this as the fixed mask.
                    args["fixed_image_masks"] = create_cfm(
                        str(ref_mask), lesion_mask=None, global_mask=True
                    )

        return args
def mask(in_file, mask_file, new_name):
    """
    Zero out every voxel of an image that lies outside a binary mask.

    Parameters
    ----------
    in_file : str
        Path to the NIfTI file to be masked
    mask_file : str
        Path to the binary mask
    new_name : str
        Path/filename where the masked image is written.

    Returns
    -------
    str
        Absolute path of the masked output image.

    Notes
    -----
    in_file and mask_file must be in the same image space and have the same
    dimensions.

    """
    import os
    import nibabel as nb

    source_img = nb.load(in_file)
    mask_img = nb.load(mask_file)

    # Voxels where the mask equals zero are set to zero in the image data.
    masked_data = source_img.get_fdata()
    outside = np.asanyarray(mask_img.dataobj) == 0
    masked_data[outside] = 0

    # Write the result, reusing the affine and header of the input image.
    masked_img = nb.Nifti1Image(masked_data, source_img.affine, source_img.header)
    masked_img.to_filename(new_name)
    return os.path.abspath(new_name)
def create_cfm(in_file, lesion_mask=None, global_mask=True, out_path=None):
    """
    Create a mask to constrain registration.

    Parameters
    ----------
    in_file : str
        Path to an existing image (usually a mask).
        If global_mask = True, this is used as a size/dimension reference.
    lesion_mask : str, optional
        Path to an existing binary lesion mask.
    global_mask : bool
        Create a whole-image mask (True) or limit to reference mask (False)
        A whole image-mask is 1 everywhere
    out_path : str, optional
        Path/filename for the new cost function mask. When omitted, the name
        is derived from ``in_file`` with a ``_cfm`` suffix and placed in the
        current working directory.

    Returns
    -------
    str
        Absolute path of the new cost function mask.

    Raises
    ------
    ValueError
        If ``global_mask`` is False and ``in_file`` holds values other than
        0 and 1.

    Notes
    -----
    in_file and lesion_mask must be in the same image space and have the same
    dimensions

    """
    import os
    import numpy as np
    import nibabel as nb
    from nipype.utils.filemanip import fname_presuffix

    if out_path is None:
        out_path = fname_presuffix(in_file, suffix="_cfm", newpath=os.getcwd())
    else:
        out_path = os.path.abspath(out_path)

    if not global_mask and not lesion_mask:
        NIWORKFLOWS_LOG.warning(
            "No lesion mask was provided and global_mask not requested, "
            "therefore the original mask will not be modified."
        )

    # Load the input image
    in_img = nb.load(in_file)

    # If we want a global mask, create one based on the input image.
    data = (
        np.ones(in_img.shape, dtype=np.uint8)
        if global_mask
        else np.asanyarray(in_img.dataobj)
    )
    if set(np.unique(data)) - {0, 1}:
        raise ValueError("`global_mask` must be true if `in_file` is not a binary mask")

    # If a lesion mask was provided, combine it with the secondary mask.
    if lesion_mask is not None:
        # Reorient the lesion mask and get the data.
        lm_img = nb.as_closest_canonical(nb.load(lesion_mask))
        lesion_data = np.asanyarray(lm_img.dataobj)

        # Subtract lesion mask from secondary mask, set negatives to 0.
        # ``data`` is binary at this point; it is converted to a signed type
        # first because subtracting two unsigned arrays wraps around
        # (e.g. 0 - 1 == 255 for uint8) instead of going negative, which would
        # leave lesion voxels switched on in the cost function mask.
        data = np.fmax(data.astype(np.int16) - lesion_data, 0)
        # Cost function mask will be created from subtraction
    # Otherwise, CFM will be created from global mask

    cfm_img = nb.Nifti1Image(data, in_img.affine, in_img.header)

    # Save the cost function mask.
    cfm_img.set_data_dtype(np.uint8)
    cfm_img.to_filename(out_path)

    return out_path
def _write_outputs(runtime, out_fname=None):
    """
    Dump the non-empty terminal streams of a runtime object to text files.

    For each of the ``stdout``, ``stderr`` and ``merged`` attributes of
    ``runtime`` that is present and truthy, the content is printed to
    ``<runtime.cwd>/<stream name><out_fname>``.

    Parameters
    ----------
    runtime : object
        Object exposing ``cwd`` and, optionally, the three stream attributes.
    out_fname : str, optional
        Suffix appended to the stream name (``".nipype"`` when omitted).

    Returns
    -------
    list of str
        Paths of the files that were written, in stream order.

    """
    suffix = ".nipype" if out_fname is None else out_fname

    written = []
    for stream_name in ("stdout", "stderr", "merged"):
        content = getattr(runtime, stream_name, "")
        if not content:
            # Missing or empty stream: nothing to save.
            continue

        target = op.join(runtime.cwd, stream_name + suffix)
        with open(target, "w") as handle:
            print(content, file=handle)
        written.append(target)

    return written