# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2021 The NiPreps Developers <nipreps@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
"""ITK files handling."""
import os
from mimetypes import guess_type
from tempfile import TemporaryDirectory
import nibabel as nb
import nitransforms as nt
import numpy as np
from nipype import logging
from nipype.interfaces.base import (
BaseInterfaceInputSpec,
File,
InputMultiObject,
OutputMultiObject,
SimpleInterface,
TraitedSpec,
isdefined,
traits,
)
from nipype.utils.filemanip import fname_presuffix
from .fixes import _FixTraitApplyTransformsInputSpec
# Logger used by the interfaces defined in this module.
LOGGER = logging.getLogger('nipype.interface')


class _MCFLIRT2ITKInputSpec(BaseInterfaceInputSpec):
    """Inputs accepted by the MCFLIRT-to-ITK conversion interface."""

    # FSL-format affine files (the interface loads them with ``fmt='fsl'``).
    in_files = InputMultiObject(
        File(exists=True),
        mandatory=True,
        desc='list of MAT files from MCFLIRT',
    )
    # The two images that the FSL affines relate to each other.
    in_reference = File(
        exists=True,
        mandatory=True,
        desc='input image for spatial reference',
    )
    in_source = File(
        exists=True,
        mandatory=True,
        desc='input image for spatial source',
    )
    # Only retained for backwards compatibility; setting it triggers a
    # deprecation warning at run time.
    num_threads = traits.Int(
        nohash=True,
        desc='number of parallel processes',
    )
class _MCFLIRT2ITKOutputSpec(TraitedSpec):
    """Outputs produced by the MCFLIRT-to-ITK conversion interface."""

    # Path of the single text file that gathers all converted affines.
    out_file = File(
        desc='the output ITKTransform file',
    )
class MCFLIRT2ITK(SimpleInterface):
    """Convert a list of MAT files from MCFLIRT into an ITK Transform file.

    Each FSL-format affine in ``in_files`` is loaded with NiTransforms against
    the ``in_reference``/``in_source`` image pair. All resulting matrices are
    stacked into one ITK linear transform array, which is written to
    ``mat2itk.txt`` inside the interface's working directory.
    """

    input_spec = _MCFLIRT2ITKInputSpec
    output_spec = _MCFLIRT2ITKOutputSpec

    def _run_interface(self, runtime):
        if isdefined(self.inputs.num_threads):
            LOGGER.warning('Multithreading is deprecated. Remove the num_threads input.')

        # NiTransforms needs both images to interpret FSL-format matrices;
        # the resulting ``.matrix`` is what gets handed to ``from_ras`` below.
        source = nb.load(self.inputs.in_source)
        reference = nb.load(self.inputs.in_reference)

        ras_matrices = np.stack(
            [
                nt.linear.load(mat, fmt='fsl', reference=reference, moving=source).matrix
                for mat in self.inputs.in_files
            ],
            axis=0,
        )
        affarray = nt.io.itk.ITKLinearTransformArray.from_ras(ras_matrices)

        self._results['out_file'] = os.path.join(runtime.cwd, 'mat2itk.txt')
        affarray.to_filename(self._results['out_file'])
        return runtime
class _MultiApplyTransformsInputSpec(_FixTraitApplyTransformsInputSpec):
    """Inputs for applying ANTs transforms to a series split into volumes."""

    # A list of files rather than a single image; presumably this overrides
    # the parent spec's ``input_image`` trait (confirm in ``.fixes``).
    input_image = InputMultiObject(
        File(exists=True),
        mandatory=True,
        desc='input time-series as a list of volumes after splitting through the fourth dimension',
    )
    num_threads = traits.Int(
        1,
        usedefault=True,
        nohash=True,
        desc='number of parallel processes',
    )
    save_cmd = traits.Bool(
        True,
        usedefault=True,
        desc='write a log of command lines that were applied',
    )
    copy_dtype = traits.Bool(
        False,
        usedefault=True,
        desc='copy dtype from inputs to outputs',
    )
class _MultiApplyTransformsOutputSpec(TraitedSpec):
    """Outputs for applying ANTs transforms to a series split into volumes."""

    # The resampled images produced by applying the transforms; these are
    # image files, not ITK transform files as the previous desc claimed.
    out_files = OutputMultiObject(File(), desc='the transformed output images')
    log_cmdline = File(desc='a list of command lines used to apply transforms')
def _applytfms(args):
    """
    Apply ANTs' antsApplyTransforms to a single input image.

    All inputs are zipped in one tuple to make it digestible by
    multiprocessing's map.

    Parameters
    ----------
    args : tuple
        ``(in_file, in_xform, ifargs, index, newpath)``:
        ``in_file`` is the image to transform, ``in_xform`` the transform(s)
        passed as ``transforms``, ``ifargs`` a mapping of extra
        ``ApplyTransforms`` inputs (it may also contain a ``copy_dtype``
        flag, which is consumed here), ``index`` is embedded in the output
        file suffix, and ``newpath`` is the output directory.

    Returns
    -------
    tuple
        ``(out_file, cmdline)``: the path of the transformed image and the
        command line that was run to produce it.

    Notes
    -----
    ``ifargs`` is left untouched, so one mapping can be shared by every call.
    """
    # Function-local imports keep the function self-contained, presumably so
    # it can be dispatched to workers without relying on module globals.
    import nibabel as nb
    from nipype.utils.filemanip import fname_presuffix
    from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms

    in_file, in_xform, ifargs, index, newpath = args
    out_file = fname_presuffix(
        in_file, suffix=f'_xform-{index:05d}', newpath=newpath, use_ext=True
    )

    # Work on a private copy: popping from the caller's mapping would remove
    # ``copy_dtype`` for every later call that shares the same dict.
    ifargs = dict(ifargs)
    copy_dtype = ifargs.pop('copy_dtype', False)

    xfm = ApplyTransforms(
        input_image=in_file, transforms=in_xform, output_image=out_file, **ifargs
    )
    xfm.terminal_output = 'allatonce'
    xfm.resource_monitor = False
    runtime = xfm.run().runtime

    if copy_dtype:
        nii = nb.load(out_file, mmap=False)
        in_dtype = nb.load(in_file).get_data_dtype()
        # Overwrite only if the dtypes don't match
        if in_dtype != nii.get_data_dtype():
            nii.set_data_dtype(in_dtype)
            nii.to_filename(out_file)

    return (out_file, runtime.cmdline)
def _arrange_xfms(transforms, num_files, tmp_folder):
"""
Convenience method to arrange the list of transforms that should be applied
to each input file
"""
base_xform = ['#Insight Transform File V1.0', '#Transform 0']
# Initialize the transforms matrix
xfms_T = []
for i, tf_file in enumerate(transforms):
if tf_file == 'identity':
xfms_T.append([tf_file] * num_files)
continue
# If it is a deformation field, copy to the tfs_matrix directly
if guess_type(tf_file)[0] != 'text/plain':
xfms_T.append([tf_file] * num_files)
continue
with open(tf_file) as tf_fh:
tfdata = tf_fh.read().strip()
# If it is not an ITK transform file, copy to the tfs_matrix directly
if not tfdata.startswith('#Insight Transform File'):
xfms_T.append([tf_file] * num_files)
continue
# Count number of transforms in ITK transform file
nxforms = tfdata.count('#Transform')
# Remove first line
tfdata = tfdata.split('\n')[1:]
# If it is a ITK transform file with only 1 xform, copy to the tfs_matrix directly
if nxforms == 1:
xfms_T.append([tf_file] * num_files)
continue
if nxforms != num_files:
raise RuntimeError(
f'Number of transforms ({nxforms}) found in the ITK file does not'
f' match the number of input image files ({num_files}).'
)
# At this point splitting transforms will be necessary, generate a base name
out_base = fname_presuffix(
tf_file, suffix=f'_pos-{i:03d}_xfm-{{:05d}}', newpath=tmp_folder.name
).format
# Split combined ITK transforms file
split_xfms = []
for xform_i in range(nxforms):
# Find start token to extract
startidx = tfdata.index(f'#Transform {xform_i}')
next_xform = base_xform + tfdata[startidx + 1 : startidx + 4] + ['']
xfm_file = out_base(xform_i)
with open(xfm_file, 'w') as out_xfm:
out_xfm.write('\n'.join(next_xform))
split_xfms.append(xfm_file)
xfms_T.append(split_xfms)
# Transpose back (only Python 3)
return list(map(list, zip(*xfms_T)))