# Source code for niworkflows.engine.splicer

"""Workflow splicing operations."""

import logging
import typing as ty
from functools import wraps

import nipype.pipeline.engine as pe
from nipype.pipeline.engine.base import EngineBase


def tag(tag: str) -> ty.Callable:
    """
    Build a decorator that labels whatever a factory function returns.

    The decorated factory (typically an ``init_...wf`` function) runs
    unchanged, and the object it returns (a node or workflow) gets a
    ``_tag`` attribute holding ``tag``. The splicing helpers in this module
    look for that attribute to decide which components to replace.

    Parameters
    ----------
    tag : str
        Label stored on the returned object's ``_tag`` attribute.

    Returns
    -------
    callable
        Decorator that wraps a factory and preserves its metadata.
    """
    # ``tag`` also names this function; give the value a distinct local name.
    label = tag

    def _decorator(func) -> ty.Callable:
        # String annotation: not evaluated when the decorator is applied, and
        # ``wraps`` replaces ``__annotations__`` with ``func``'s regardless.
        @wraps(func)
        def _tag(*args, **kwargs) -> 'EngineBase':
            built = func(*args, **kwargs)
            built._tag = label
            return built

        return _tag

    return _decorator
def splice_workflow(
    root_wf: pe.Workflow,
    replacements: dict[str, EngineBase],
    *,
    write_graph: bool = False,
    debug: bool = False,
):
    """
    Swap tagged nodes/workflows inside ``root_wf`` for alternatives.

    Components carrying a ``_tag`` attribute (set by :func:`tag`) whose tag
    is a key of ``replacements`` are replaced by the corresponding value, and
    the connections that touched them are re-created on the replacement.

    Parameters
    ----------
    root_wf : Workflow
        Workflow to modify in place.
    replacements : dict
        Mapping of tag name to the node or workflow that should take its place.
    write_graph : bool
        If true, write ``pre-slice.dot`` before and ``post-slice.dot`` after
        splicing (PNG output, ``colored`` graph).
    debug : bool
        Forwarded to the splicing routine.

    Returns
    -------
    Workflow
        ``root_wf`` itself, after modification.
    """
    graph_options = {'format': 'png', 'graph2use': 'colored'}

    if write_graph:
        root_wf.write_graph('pre-slice.dot', **graph_options)

    _splice_components(
        root_wf,
        _get_substitutions(root_wf, replacements),
        debug=debug,
    )

    if write_graph:
        root_wf.write_graph('post-slice.dot', **graph_options)
    return root_wf
def _get_substitutions(
    workflow: pe.Workflow,
    replacements: dict[str, EngineBase],
) -> dict[EngineBase, EngineBase]:
    """
    Pair each tagged component of ``workflow`` with its requested replacement.

    Parameters
    ----------
    workflow : Workflow
        Workflow whose (possibly nested) nodes are searched for ``_tag``.
    replacements : dict
        Mapping of tag name to the replacement node or workflow.

    Returns
    -------
    dict
        Mapping of each tagged node/workflow found in ``workflow`` to the
        replacement registered for its tag. Tags with no entry in
        ``replacements`` are skipped.
    """
    return {
        node: replacements[tag_name]
        for tag_name, node in _fetch_tags(workflow).items()
        if tag_name in replacements
    }


def _fetch_tags(wf: pe.Workflow) -> dict[str, EngineBase]:
    """
    Recursively collect tagged nodes and workflows from ``wf``.

    Every node of ``wf``'s graph that has a ``_tag`` attribute is recorded,
    and every sub-workflow is searched recursively (including tagged ones).
    When several components share a tag, the one visited last wins.

    Returns
    -------
    dict
        Mapping of tag name to the tagged node or workflow.
    """
    tagged = {}
    for node in wf._graph.nodes:
        if hasattr(node, '_tag'):
            tagged[node._tag] = node
        if isinstance(node, pe.Workflow):
            tagged.update(_fetch_tags(node))
    return tagged


def _splice_components(
    workflow: pe.Workflow,
    substitutions: dict[EngineBase, EngineBase],
    debug: bool = False,
) -> None:
    """
    Replace substituted components of ``workflow`` (recursively) and rewire edges.

    Every node of ``workflow``'s graph that is a key of ``substitutions`` is
    removed and its replacement added in its place; the replacement inherits
    the replaced node's ``_hierarchy``. Each edge touching a replaced node is
    re-created against the replacement(s) with the same edge data, including
    edges whose two endpoints are both replaced. Every sub-workflow that is
    not itself being replaced is processed recursively, regardless of how (or
    whether) it is connected.

    The graph is edited in place; nothing is returned.

    Parameters
    ----------
    workflow : Workflow
        Workflow whose graph is edited in place.
    substitutions : dict
        Mapping of existing node/workflow to its replacement, as produced by
        ``_get_substitutions``.
    debug : bool
        Forwarded to recursive calls. Logging always goes to ``logger.debug``.
    """
    # Decide which nodes of *this* graph are swapped. Membership uses the same
    # key semantics as the ``substitutions`` lookup, so only the tagged
    # components themselves are replaced (a node that merely shares a
    # ``fullname`` can no longer trigger a KeyError).
    swaps = {}
    for node in workflow._graph.nodes:
        if node in substitutions:
            swaps[node] = substitutions[node]
        elif isinstance(node, pe.Workflow):
            # Visit every surviving sub-workflow so that no nested tagged
            # component is missed because of how its edges are laid out.
            _splice_components(node, substitutions, debug=debug)

    for old, new in swaps.items():
        # The replacement inherits the replaced node's ``_hierarchy``.
        new._hierarchy = old._hierarchy

    edge_removals = []
    edge_connects = []
    for src, dst in workflow._graph.edges:
        if src in swaps or dst in swaps:
            edge_removals.append((src, dst))
            # Swap whichever endpoints are replaced (possibly both).
            edge_connects.append(
                (
                    swaps.get(src, src),
                    swaps.get(dst, dst),
                    workflow._graph.get_edge_data(src, dst),
                )
            )

    logger = logging.getLogger('nipype.workflow')
    logger.debug(
        'Workflow: %s, \n- edge_removals: %s, \n+ edge_connects: %s',
        workflow,
        edge_removals,
        edge_connects,
    )
    if not swaps:
        return

    # Order matters: drop old edges/nodes before adding replacements (which
    # may reuse the same names), then reconnect.
    workflow._graph.remove_edges_from(edge_removals)
    workflow.remove_nodes(set(swaps))
    workflow.add_nodes(set(swaps.values()))
    workflow._graph.add_edges_from(edge_connects)