# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2021 The NiPreps Developers <nipreps@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
"""Visualization tools."""
import nibabel as nb
import numpy as np
from nipype.interfaces.base import (
BaseInterfaceInputSpec,
File,
SimpleInterface,
TraitedSpec,
isdefined,
traits,
)
from nipype.utils.filemanip import fname_presuffix
from niworkflows.utils.timeseries import _cifti_timeseries, _nifti_timeseries
from niworkflows.viz.plots import (
compcor_variance_plot,
confounds_correlation_plot,
fMRIPlot,
)
class _FMRISummaryInputSpec(BaseInterfaceInputSpec):
    """Inputs of the fMRI summary (carpet + confounds) plot."""

    in_func = File(
        exists=True,
        mandatory=True,
        desc='functional time series (NIfTI or CIFTI) to summarize',
    )
    in_spikes_bg = File(
        exists=True,
        desc='spikes file, forwarded to fMRIPlot as ``spikes_files``',
    )
    fd = File(
        exists=True,
        desc='framewise displacement file; its first column is read '
        'after skipping a one-line header',
    )
    dvars = File(
        exists=True,
        desc='DVARS file; the non-standardized DVARS (second column) is '
        'read after skipping a one-line header',
    )
    outliers = File(
        exists=True,
        desc='outliers file; its first column is plotted (units: %)',
    )
    in_segm = File(
        exists=True,
        desc='segmentation passed to the NIfTI time-series extraction '
        '(ignored for CIFTI inputs)',
    )
    tr = traits.Either(None, traits.Float, usedefault=True, desc='the TR')
    fd_thres = traits.Float(
        0.2,
        usedefault=True,
        desc='FD threshold, drawn as a vertical line on the FD trace',
    )
    drop_trs = traits.Int(0, usedefault=True, desc='dummy scans')
class _FMRISummaryOutputSpec(TraitedSpec):
    """Outputs of :class:`FMRISummary`."""

    out_file = File(exists=True, desc='written file path')
class FMRISummary(SimpleInterface):
    """Prepare an fMRI summary plot for the report.

    A carpet plot of ``in_func`` (NIfTI or CIFTI) is combined with confound
    traces (outliers, DVARS, FD) when all three confound files are given.
    The figure is saved as ``<in_func stem>_fmriplot.svg`` in the working
    directory and reported as ``out_file``.
    """

    input_spec = _FMRISummaryInputSpec
    output_spec = _FMRISummaryOutputSpec

    def _run_interface(self, runtime):
        import pandas as pd

        self._results['out_file'] = fname_presuffix(
            self.inputs.in_func,
            suffix='_fmriplot.svg',
            use_ext=False,
            newpath=runtime.cwd,
        )

        # Confound traces are only plotted when every source is available.
        dataframe = None
        if (
            isdefined(self.inputs.outliers)
            and isdefined(self.inputs.dvars)
            and isdefined(self.inputs.fd)
        ):
            # ndmin=1 keeps single-row files as 1-D arrays, so ``.tolist()``
            # always returns a list that can be concatenated below.
            dataframe = pd.DataFrame(
                {
                    'outliers': np.loadtxt(
                        self.inputs.outliers, usecols=[0], ndmin=1
                    ).tolist(),
                    # Pick non-standardize dvars (col 1)
                    # First timepoint is NaN (difference)
                    'DVARS': [np.nan]
                    + np.loadtxt(
                        self.inputs.dvars, skiprows=1, usecols=[1], ndmin=1
                    ).tolist(),
                    # First timepoint is zero (reference volume)
                    'FD': [0.0]
                    + np.loadtxt(
                        self.inputs.fd, skiprows=1, usecols=[0], ndmin=1
                    ).tolist(),
                }
            )

        input_data = nb.load(self.inputs.in_func)
        seg_file = self.inputs.in_segm if isdefined(self.inputs.in_segm) else None
        dataset, segments = (
            _cifti_timeseries(input_data)
            if isinstance(input_data, nb.Cifti2Image)
            else _nifti_timeseries(input_data, seg_file)
        )

        # ``tr`` uses ``usedefault=True`` with a default of None, so it is
        # always "defined"; treat None as "not provided" and fall back on the
        # repetition time stored in the image header.
        tr = self.inputs.tr
        if not isdefined(tr) or tr is None:
            tr = _get_tr(input_data)

        fig = fMRIPlot(
            dataset,
            segments=segments,
            spikes_files=(
                [self.inputs.in_spikes_bg] if isdefined(self.inputs.in_spikes_bg) else None
            ),
            tr=tr,
            confounds=dataframe,
            units={'outliers': '%', 'FD': 'mm'},
            vlines={'FD': [self.inputs.fd_thres]},
            nskip=self.inputs.drop_trs,
        ).plot()
        fig.savefig(self._results['out_file'], bbox_inches='tight')
        return runtime
class _CompCorVariancePlotInputSpec(BaseInterfaceInputSpec):
    """Inputs of the CompCor explained-variance plot."""

    # Required: one or more component-metadata files.
    metadata_files = traits.List(
        File(exists=True),
        desc='List of files containing component metadata',
        mandatory=True,
    )
    # Optional names of the decompositions that produced ``metadata_files``.
    metadata_sources = traits.List(
        traits.Str,
        desc=(
            'List of names of decompositions (e.g., aCompCor, tCompCor) '
            'yielding the arguments in `metadata_files`'
        ),
    )
    # Default thresholds: 50%, 70% and 90% of explained variance.
    variance_thresholds = traits.Tuple(
        *(traits.Float(level) for level in (0.5, 0.7, 0.9)),
        usedefault=True,
        desc='Levels of explained variance to include in plot',
    )
    # None (the default) lets the interface derive the output file name.
    out_file = traits.Either(
        None,
        File,
        value=None,
        usedefault=True,
        desc='Path to save plot',
    )
class _CompCorVariancePlotOutputSpec(TraitedSpec):
    """Outputs of :class:`CompCorVariancePlot`."""

    out_file = File(exists=True, desc='Path to saved plot')
class CompCorVariancePlot(SimpleInterface):
    """Plot the number of components necessary to explain the specified levels of variance.

    When ``out_file`` is None, the figure is written to
    ``<first metadata file stem>_compcor.svg`` in the working directory.
    """

    input_spec = _CompCorVariancePlotInputSpec
    output_spec = _CompCorVariancePlotOutputSpec

    def _run_interface(self, runtime):
        if self.inputs.out_file is None:
            self._results['out_file'] = fname_presuffix(
                self.inputs.metadata_files[0],
                suffix='_compcor.svg',
                use_ext=False,
                newpath=runtime.cwd,
            )
        else:
            self._results['out_file'] = self.inputs.out_file

        # ``metadata_sources`` is optional and has no default, so nipype leaves
        # it as ``Undefined`` when unset; forward None instead of that marker.
        # NOTE(review): presumably compcor_variance_plot then derives its own
        # source names from None — confirm against niworkflows.viz.plots.
        sources = (
            self.inputs.metadata_sources
            if isdefined(self.inputs.metadata_sources)
            else None
        )
        compcor_variance_plot(
            metadata_files=self.inputs.metadata_files,
            metadata_sources=sources,
            output_file=self._results['out_file'],
            varexp_thresh=self.inputs.variance_thresholds,
        )
        return runtime
class _ConfoundsCorrelationPlotInputSpec(BaseInterfaceInputSpec):
    """Inputs of the confound-regressor correlation plot."""

    confounds_file = File(
        exists=True,
        desc='File containing confound regressors',
        mandatory=True,
    )
    # None (the default) lets the interface derive the output file name.
    out_file = traits.Either(
        None,
        File,
        value=None,
        usedefault=True,
        desc='Path to save plot',
    )
    reference_column = traits.Str(
        'global_signal',
        desc=(
            'Column in the confound file for which all correlation '
            'magnitudes should be ranked and plotted'
        ),
        usedefault=True,
    )
    # Optional whitelist of regressor names.
    columns = traits.List(
        traits.Str,
        desc='Filter out all regressors not found in this list.',
    )
    max_dim = traits.Int(
        20,
        desc=(
            'Maximum number of regressors to include in plot. '
            'Regressors with highest magnitude of correlation with '
            '`reference_column` will be selected.'
        ),
        usedefault=True,
    )
    ignore_initial_volumes = traits.Int(
        0,
        desc='Number of non-steady-state volumes at the beginning of the scan to ignore.',
        usedefault=True,
    )
class _ConfoundsCorrelationPlotOutputSpec(TraitedSpec):
    """Outputs of :class:`ConfoundsCorrelationPlot`."""

    out_file = File(exists=True, desc='Path to saved plot')
class ConfoundsCorrelationPlot(SimpleInterface):
    """Plot the correlation among confound regressors.

    When ``out_file`` is None, the figure is written to
    ``<confounds file stem>_confoundCorrelation.svg`` in the working directory.
    """

    input_spec = _ConfoundsCorrelationPlotInputSpec
    output_spec = _ConfoundsCorrelationPlotOutputSpec

    def _run_interface(self, runtime):
        inputs = self.inputs

        target = inputs.out_file
        if target is None:
            target = fname_presuffix(
                inputs.confounds_file,
                suffix='_confoundCorrelation.svg',
                use_ext=False,
                newpath=runtime.cwd,
            )
        self._results['out_file'] = target

        # An unset column filter is forwarded as None rather than Undefined.
        selected = None
        if isdefined(inputs.columns):
            selected = inputs.columns

        confounds_correlation_plot(
            confounds_file=inputs.confounds_file,
            columns=selected,
            max_dim=inputs.max_dim,
            output_file=target,
            reference=inputs.reference_column,
            ignore_initial_volumes=inputs.ignore_initial_volumes,
        )
        return runtime
def _get_tr(img):
    """
    Attempt to extract repetition time from NIfTI/CIFTI header

    The CIFTI layout is tried first (``series_step`` of index map 0). If that
    fails, the NIfTI layout is used: the fourth voxel size returned by
    ``header.get_zooms()``, which is the time axis.

    Parameters
    ----------
    img : nibabel image
        A loaded CIFTI-2 or NIfTI image.

    Returns
    -------
    tr
        The repetition time as stored in the header (no unit conversion).

    Raises
    ------
    RuntimeError
        If neither header layout yields a repetition time (including NIfTI
        images with fewer than four dimensions).

    Examples
    --------
    >>> _get_tr(nb.load(Path(test_data) /
    ...     'sub-ds205s03_task-functionallocalizer_run-01_bold_volreg.nii.gz'))
    2.2
    >>> _get_tr(nb.load(Path(test_data) /
    ...     'sub-01_task-mixedgamblestask_run-02_space-fsLR_den-91k_bold.dtseries.nii'))
    2.0
    """
    # CIFTI-2 (e.g. dtseries): the time step lives on the first index map.
    try:
        return img.header.matrix.get_index_map(0).series_step
    except AttributeError:
        pass

    # NIfTI-like headers: one zoom per axis; only 4-D+ images have a time axis,
    # otherwise the last zoom would be a spatial voxel size, not a TR.
    try:
        zooms = img.header.get_zooms()
    except AttributeError:
        zooms = ()
    if len(zooms) >= 4:
        return zooms[3]

    raise RuntimeError('Could not extract TR - unknown data structure type')