# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2021 The NiPreps Developers <nipreps@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
"""A robust ANTs T1-to-MNI registration workflow with fallback retry."""
from multiprocessing import cpu_count
from os import path as op
import numpy as np
from nipype.interfaces.ants import AffineInitializer
from nipype.interfaces.ants.registration import RegistrationOutputSpec
from nipype.interfaces.base import (
BaseInterface,
BaseInterfaceInputSpec,
File,
Str,
isdefined,
traits,
)
from packaging.version import Version
from templateflow.api import get as get_template
from .. import NIWORKFLOWS_LOG, __version__
from ..data import load as load_data
from .fixes import FixHeaderRegistration as Registration
niworkflows_version = Version(__version__)
class _SpatialNormalizationInputSpec(BaseInterfaceInputSpec):
# Enable deprecation
package_version = niworkflows_version
# Moving image.
moving_image = File(exists=True, mandatory=True, desc='image to apply transformation to')
# Reference image (optional).
reference_image = File(exists=True, desc='override the reference image')
# Moving mask (optional).
moving_mask = File(exists=True, desc='moving image mask')
# Reference mask (optional).
reference_mask = File(exists=True, desc='reference image mask')
# Lesion mask (optional).
lesion_mask = File(exists=True, desc='lesion mask image')
# Number of threads to use for ANTs/ITK processes.
num_threads = traits.Int(
cpu_count(), usedefault=True, nohash=True, desc='Number of ITK threads to use'
)
# ANTs parameter set to use.
flavor = traits.Enum(
'precise',
'testing',
'fast',
usedefault=True,
desc='registration settings parameter set',
)
# Template orientation.
orientation = traits.Enum(
'RAS',
'LAS',
mandatory=True,
usedefault=True,
desc='modify template orientation (should match input image)',
)
# Modality of the reference image.
reference = traits.Enum(
'T1w',
'T2w',
'boldref',
'PDw',
mandatory=True,
usedefault=True,
desc='set the reference modality for registration',
)
# T1 or EPI registration?
moving = traits.Enum(
'T1w', 'boldref', usedefault=True, mandatory=True, desc='registration type'
)
# Template to use as the default reference image.
template = traits.Str(
'MNI152NLin2009cAsym', usedefault=True, desc='define the template to be used'
)
# Load other settings from file.
settings = traits.List(File(exists=True), desc='pass on the list of settings files')
# Resolution of the default template.
template_spec = traits.Dict(Str, desc='template specifications')
template_resolution = traits.Enum(1, 2, None, desc='(DEPRECATED) template resolution')
# Use explicit masking?
explicit_masking = traits.Bool(
True,
usedefault=True,
desc="""\
Set voxels outside the masks to zero thus creating an artificial border
that can drive the registration. Requires reliable and accurate masks.
See https://sourceforge.net/p/advants/discussion/840261/thread/27216e69/#c7ba\
""",
)
initial_moving_transform = File(exists=True, desc='transform for initialization')
use_histogram_matching = traits.Bool(desc='determine use of histogram matching')
float = traits.Bool(False, usedefault=True, desc='use single precision calculations')
class _SpatialNormalizationOutputSpec(RegistrationOutputSpec):
reference_image = File(exists=True, desc='reference image used for registration target')
[docs]
class SpatialNormalization(BaseInterface):
"""
An interface to robustly run T1-to-MNI spatial normalization.
Several settings are sequentially tried until some work.
"""
input_spec = _SpatialNormalizationInputSpec
output_spec = _SpatialNormalizationOutputSpec
def _list_outputs(self):
outputs = self.norm._list_outputs()
outputs['reference_image'] = self._reference_image
return outputs
def __init__(self, **inputs):
self.norm = None
self._reference_image = None
self.retry = 1
self.terminal_output = 'file'
super().__init__(**inputs)
def _get_settings(self):
"""
Return any settings defined by the user, as well as any pre-defined
settings files that exist for the image modalities to be registered.
"""
# If user-defined settings exist...
if isdefined(self.inputs.settings):
# Note this in the log and return those settings.
NIWORKFLOWS_LOG.info('User-defined settings, overriding defaults')
return self.inputs.settings
data_dir = load_data()
# Get a list of settings files that match the flavor.
return sorted(
[
str(path)
for path in data_dir.glob(
f'{self.inputs.moving.lower()}-mni_registration_{self.inputs.flavor}_*.json'
)
]
)
def _run_interface(self, runtime):
# Get a list of settings files.
settings_files = self._get_settings()
ants_args = self._get_ants_args()
if not isdefined(self.inputs.initial_moving_transform):
NIWORKFLOWS_LOG.info('Estimating initial transform using AffineInitializer')
init = AffineInitializer(
fixed_image=ants_args['fixed_image'],
moving_image=ants_args['moving_image'],
num_threads=self.inputs.num_threads,
)
init.resource_monitor = False
init.terminal_output = 'allatonce'
init_result = init.run()
# Save outputs (if available)
init_out = _write_outputs(init_result.runtime, '.nipype-init')
if init_out:
NIWORKFLOWS_LOG.info(
'Terminal outputs of initialization saved (%s).',
', '.join(init_out),
)
ants_args['initial_moving_transform'] = init_result.outputs.out_file
# For each settings file...
for ants_settings in settings_files:
NIWORKFLOWS_LOG.info('Loading settings from file %s.', ants_settings)
# Configure an ANTs run based on these settings.
self.norm = Registration(from_file=ants_settings, **ants_args)
if isdefined(self.inputs.use_histogram_matching):
# Most (all?) configuration files use histogram matching, so more important
# to allow disabling, such as in the case of intermodality normalization
NIWORKFLOWS_LOG.info(
'Overriding (%sabling) histogram matching for file %s',
'en' if self.inputs.use_histogram_matching else 'dis',
ants_settings,
)
self.norm.inputs.use_histogram_matching = self.inputs.use_histogram_matching
self.norm.resource_monitor = False
self.norm.terminal_output = self.terminal_output
cmd = self.norm.cmdline
# Print the retry number and command line call to the log.
NIWORKFLOWS_LOG.info('Retry #%d, commandline: \n%s', self.retry, cmd)
self.norm.ignore_exception = True
with open('command.txt', 'w') as cmdfile:
print(cmd + '\n', file=cmdfile)
# Try running registration.
interface_result = self.norm.run()
if interface_result.runtime.returncode != 0:
NIWORKFLOWS_LOG.warning('Retry #%d failed.', self.retry)
# Save outputs (if available)
term_out = _write_outputs(interface_result.runtime, f'.nipype-{self.retry:04d}')
if term_out:
NIWORKFLOWS_LOG.warning('Log of failed retry saved (%s).', ', '.join(term_out))
else:
runtime.returncode = 0
# Note this in the log.
NIWORKFLOWS_LOG.info('Successful spatial normalization (retry #%d).', self.retry)
# Break out of the retry loop.
return runtime
self.retry += 1
# If all tries fail, raise an error.
raise RuntimeError(f'Robust spatial normalization failed after {self.retry - 1} retries.')
def _get_ants_args(self):
args = {
'moving_image': self.inputs.moving_image,
'num_threads': self.inputs.num_threads,
'float': self.inputs.float,
'terminal_output': 'file',
'write_composite_transform': True,
'initial_moving_transform': self.inputs.initial_moving_transform,
}
"""
Moving image handling - The following truth table maps out the intended action
sequence. Future refactoring may more directly encode this.
moving_mask and lesion_mask are files
True = file
False = None
| moving_mask | explicit_masking | lesion_mask | action
|-------------|------------------|-------------|-------------------------------------------
| True | True | True | Update `moving_image` after applying
| | | | mask.
| | | | Set `moving_image_masks` applying
| | | | `create_cfm` with `global_mask=True`.
|-------------|------------------|-------------|-------------------------------------------
| True | True | False | Update `moving_image` after applying
| | | | mask.
|-------------|------------------|-------------|-------------------------------------------
| True | False | True | Set `moving_image_masks` applying
| | | | `create_cfm` with `global_mask=False`
|-------------|------------------|-------------|-------------------------------------------
| True | False | False | args['moving_image_masks'] = moving_mask
|-------------|------------------|-------------|-------------------------------------------
| False | * | True | Set `moving_image_masks` applying
| | | | `create_cfm` with `global_mask=True`
|-------------|------------------|-------------|-------------------------------------------
| False | * | False | No action
"""
# If a moving mask is provided...
if isdefined(self.inputs.moving_mask):
# If explicit masking is enabled...
if self.inputs.explicit_masking:
# Mask the moving image.
# Do not use a moving mask during registration.
args['moving_image'] = mask(
self.inputs.moving_image,
self.inputs.moving_mask,
'moving_masked.nii.gz',
)
# If explicit masking is disabled...
else:
# Use the moving mask during registration.
# Do not mask the moving image.
args['moving_image_masks'] = self.inputs.moving_mask
# If a lesion mask is also provided...
if isdefined(self.inputs.lesion_mask):
# Create a cost function mask with the form:
# [global mask - lesion mask] (if explicit masking is enabled)
# [moving mask - lesion mask] (if explicit masking is disabled)
# Use this as the moving mask.
args['moving_image_masks'] = create_cfm(
self.inputs.moving_mask,
lesion_mask=self.inputs.lesion_mask,
global_mask=self.inputs.explicit_masking,
)
# If no moving mask is provided...
# But a lesion mask *IS* provided...
elif isdefined(self.inputs.lesion_mask):
# Create a cost function mask with the form: [global mask - lesion mask]
# Use this as the moving mask.
args['moving_image_masks'] = create_cfm(
self.inputs.moving_image,
lesion_mask=self.inputs.lesion_mask,
global_mask=True,
)
"""
Reference image handling - The following truth table maps out the intended action
sequence. Future refactoring may more directly encode this.
reference_mask and lesion_mask are files
True = file
False = None
| reference_mask | explicit_masking | lesion_mask | action
|----------------|------------------|-------------|----------------------------------------
| True | True | True | Update `fixed_image` after applying
| | | | mask.
| | | | Set `fixed_image_masks` applying
| | | | `create_cfm` with `global_mask=True`.
|----------------|------------------|-------------|----------------------------------------
| True | True | False | Update `fixed_image` after applying
| | | | mask.
|----------------|------------------|-------------|----------------------------------------
| True | False | True | Set `fixed_image_masks` applying
| | | | `create_cfm` with `global_mask=False`
|----------------|------------------|-------------|----------------------------------------
| True | False | False | args['fixed_image_masks'] = fixed_mask
|----------------|------------------|-------------|----------------------------------------
| False | * | True | Set `fixed_image_masks` applying
| | | | `create_cfm` with `global_mask=True`
|----------------|------------------|-------------|----------------------------------------
| False | * | False | No action
"""
# If a reference image is provided...
if isdefined(self.inputs.reference_image):
# Use the reference image as the fixed image.
args['fixed_image'] = self.inputs.reference_image
self._reference_image = self.inputs.reference_image
# If a reference mask is provided...
if isdefined(self.inputs.reference_mask):
# If explicit masking is enabled...
if self.inputs.explicit_masking:
# Mask the reference image.
# Do not use a fixed mask during registration.
args['fixed_image'] = mask(
self.inputs.reference_image,
self.inputs.reference_mask,
'fixed_masked.nii.gz',
)
# If a lesion mask is also provided...
if isdefined(self.inputs.lesion_mask):
# Create a cost function mask with the form: [global mask]
# Use this as the fixed mask.
args['fixed_image_masks'] = create_cfm(
self.inputs.reference_mask,
lesion_mask=None,
global_mask=True,
)
# If a reference mask is provided...
# But explicit masking is disabled...
else:
# Use the reference mask as the fixed mask during registration.
# Do not mask the fixed image.
args['fixed_image_masks'] = self.inputs.reference_mask
# If no reference mask is provided...
# But a lesion mask *IS* provided ...
elif isdefined(self.inputs.lesion_mask):
# Create a cost function mask with the form: [global mask]
# Use this as the fixed mask
args['fixed_image_masks'] = create_cfm(
self.inputs.reference_image, lesion_mask=None, global_mask=True
)
# If no reference image is provided, fall back to the default template.
else:
from ..utils.misc import get_template_specs
# Raise an error if the user specifies an unsupported image orientation.
if self.inputs.orientation == 'LAS':
raise NotImplementedError
template_spec = (
self.inputs.template_spec if isdefined(self.inputs.template_spec) else {}
)
default_resolution = {'precise': 1, 'fast': 2, 'testing': 2}[self.inputs.flavor]
# Set the template resolution.
if isdefined(self.inputs.template_resolution):
NIWORKFLOWS_LOG.warning('The use of ``template_resolution`` is deprecated')
template_spec['res'] = self.inputs.template_resolution
template_spec['suffix'] = self.inputs.reference
template_spec['desc'] = None
ref_template, template_spec = get_template_specs(
self.inputs.template,
template_spec=template_spec,
default_resolution=default_resolution,
fallback=True,
)
# Set reference image
self._reference_image = ref_template
if not op.isfile(self._reference_image):
raise ValueError(
f"""\
The registration reference must be an existing file, but path "{ref_template}" \
cannot be found."""
)
# Get the template specified by the user.
ref_mask = get_template(
self.inputs.template, desc='brain', suffix='mask', **template_spec
) or get_template(self.inputs.template, label='brain', suffix='mask', **template_spec)
# Default is explicit masking disabled
args['fixed_image'] = ref_template
# Use the template mask as the fixed mask.
args['fixed_image_masks'] = str(ref_mask)
# Overwrite defaults if explicit masking
if self.inputs.explicit_masking:
# Mask the template image with the template mask.
args['fixed_image'] = mask(ref_template, str(ref_mask), 'fixed_masked.nii.gz')
# Do not use a fixed mask during registration.
args.pop('fixed_image_masks', None)
# If a lesion mask is provided...
if isdefined(self.inputs.lesion_mask):
# Create a cost function mask with the form: [global mask]
# Use this as the fixed mask.
args['fixed_image_masks'] = create_cfm(
str(ref_mask), lesion_mask=None, global_mask=True
)
return args
[docs]
def mask(in_file, mask_file, new_name):
"""
Apply a binary mask to an image.
Parameters
----------
in_file : str
Path to a NIfTI file to mask
mask_file : str
Path to a binary mask
new_name : str
Path/filename for the masked output image.
Returns
-------
str
Absolute path of the masked output image.
Notes
-----
in_file and mask_file must be in the same
image space and have the same dimensions.
"""
import os
import nibabel as nb
# Load the input image
in_nii = nb.load(in_file)
# Load the mask image
mask_nii = nb.load(mask_file)
# Set all non-mask voxels in the input file to zero.
data = in_nii.get_fdata()
data[np.asanyarray(mask_nii.dataobj) == 0] = 0
# Save the new masked image.
new_nii = nb.Nifti1Image(data, in_nii.affine, in_nii.header)
new_nii.to_filename(new_name)
return os.path.abspath(new_name)
[docs]
def create_cfm(in_file, lesion_mask=None, global_mask=True, out_path=None):
"""
Create a mask to constrain registration.
Parameters
----------
in_file : str
Path to an existing image (usually a mask).
If global_mask = True, this is used as a size/dimension reference.
out_path : str
Path/filename for the new cost function mask.
lesion_mask : str, optional
Path to an existing binary lesion mask.
global_mask : bool
Create a whole-image mask (True) or limit to reference mask (False)
A whole image-mask is 1 everywhere
Returns
-------
str
Absolute path of the new cost function mask.
Notes
-----
in_file and lesion_mask must be in the same
image space and have the same dimensions
"""
import os
import nibabel as nb
import numpy as np
from nipype.utils.filemanip import fname_presuffix
if out_path is None:
out_path = fname_presuffix(in_file, suffix='_cfm', newpath=os.getcwd())
else:
out_path = os.path.abspath(out_path)
if not global_mask and not lesion_mask:
NIWORKFLOWS_LOG.warning(
'No lesion mask was provided and global_mask not requested, '
'therefore the original mask will not be modified.'
)
# Load the input image
in_img = nb.load(in_file)
# If we want a global mask, create one based on the input image.
data = np.ones(in_img.shape, dtype=np.uint8) if global_mask else np.asanyarray(in_img.dataobj)
if set(np.unique(data)) - {0, 1}:
raise ValueError('`global_mask` must be true if `in_file` is not a binary mask')
# If a lesion mask was provided, combine it with the secondary mask.
if lesion_mask is not None:
# Reorient the lesion mask and get the data.
lm_img = nb.as_closest_canonical(nb.load(lesion_mask))
# Subtract lesion mask from secondary mask, set negatives to 0
data = np.fmax(data - lm_img.dataobj, 0)
# Cost function mask will be created from subtraction
# Otherwise, CFM will be created from global mask
cfm_img = nb.Nifti1Image(data, in_img.affine, in_img.header)
# Save the cost function mask.
cfm_img.set_data_dtype(np.uint8)
cfm_img.to_filename(out_path)
return out_path
def _write_outputs(runtime, out_fname=None):
if out_fname is None:
out_fname = '.nipype'
out_files = []
for name in ['stdout', 'stderr', 'merged']:
stream = getattr(runtime, name, '')
if stream:
out_file = op.join(runtime.cwd, name + out_fname)
with open(out_file, 'w') as outf:
print(stream, file=outf)
out_files.append(out_file)
return out_files