Source code for ewoksndreg.tasks.reg2d_transform

import copy
import datetime
import json
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from ewokscore.model import BaseInputModel
from ewokscore.model import BaseOutputModel
from ewokscore.task import Task
from ewokscore.variable import Variable
from pydantic import Field
from pydantic import field_validator
from silx.io import h5py_utils
from silx.io.url import DataUrl

from ..io.input_stack import InputDataType
from ..io.input_stack import input_context
from ..io.nexus import nx_annotate
from ..io.output_stack import OutputDataTypeForInput
from ..io.output_stack import output_context
from ..transformation import apply_transformations
from ..transformation.base import Transformation


class Inputs(BaseInputModel):
    # Inputs of the Reg2DTransform task. Documentation is kept in comments
    # (not a docstring) so the generated pydantic schema is unchanged.

    # Stacks to transform: in-memory arrays or HDF5 dataset URIs per stack name.
    image_stacks: InputDataType = Field(
        examples=[
            {
                "stack1": "/path/to/file.h5::/entry/process1/results/parameters/Ca-K",
                "stack2": "/path/to/file.h5::/entry/process1/results/parameters/Fe-K",
            },
            {"stack1": [[0, 0, 0], [1, 1, 1], [2, 2, 2]]},
        ],
        description="Image stacks as a dictionary of numpy arrays or list of HDF5 dataset URI's.",
    )
    # One list of transformations per stack name.
    transformations: Dict[str, List[Transformation]] = Field(
        description="Transformations for each image in each stack."
    )
    # Where the transformed stacks are written; None keeps them in memory.
    output_root_uri: Union[DataUrl, str, None] = Field(
        examples=["/path/to/file.h5::/entry/process2/results/parameters/"],
        description="URL to save all transformed stacks.",
        default=None,
    )
    # NeXus metadata (silx dictdump schema) written next to the saved stacks.
    image_stacks_nxmetadata: Optional[dict] = Field(
        examples=[{"@NX_class": "NXroot", "entry": {"@NX_class": "NXentry"}}],
        description="HDF5/NeXus metadata relative to the file root following the Silx dictdump schema.",
        default=None,
    )
    # Arbitrary configuration to be stored together with the results.
    output_configuration: Optional[Dict[str, Any]] = Field(
        examples=[{"param1": 0, "param2": 1}],
        description="Registration configuration parameters to be saved.",
        default=None,
    )
    crop: bool = Field(
        description="Crop Nan's at the image edges after alignment.", default=False
    )
    interpolation_order: int = Field(
        description="Interpolation order when transforming an image.", default=1
    )

    @field_validator("output_root_uri", mode="before")
    def coerce_uri(cls, var):
        # An ewoks Variable wrapping a string is converted in place and the
        # wrapper itself is returned; a bare string becomes a DataUrl.
        if isinstance(var, Variable):
            wrapped_value = var.value
            if isinstance(wrapped_value, str):
                var.value = DataUrl(wrapped_value)
            return var
        return DataUrl(var) if isinstance(var, str) else var
class Outputs(BaseOutputModel):
    # Outputs of the Reg2DTransform task. Documentation is kept in comments
    # (not a docstring) so the generated pydantic schema is unchanged.

    # Transformed stacks, as returned by the output stack context.
    image_stacks: OutputDataTypeForInput = Field(
        description="Dictionary of image stacks in memory or URIs.",
    )
    # The input configuration completed with "crop" and "interpolation_order".
    output_configuration: Optional[Dict[str, Any]] = Field(
        examples=[{"param1": 0, "param2": 1}],
        description="Registration configuration parameters to be saved.",
    )
class Reg2DTransform(Task, input_model=Inputs, output_model=Outputs):
    """Apply transformations calculated from image registration to the images of one or more stacks."""

    def run(self):
        """Transform every image stack and publish the aligned stacks.

        When ``output_root_uri`` is given, the aligned stacks are saved there,
        annotated with ``image_stacks_nxmetadata`` (cropped consistently with
        the images when ``crop`` is enabled) and the configuration is stored
        in the enclosing NXprocess group. Without ``output_root_uri`` the
        NeXus metadata is ignored.

        Task inputs are never modified: the metadata and configuration dicts
        are copied before being edited.
        """
        image_stacks = self.inputs.image_stacks
        output_root_uri = self.get_input_value("output_root_uri", None)
        image_stacks_nxmetadata = self.get_input_value("image_stacks_nxmetadata", None)
        if output_root_uri:
            output_filenames = [output_root_uri.file_path()]
        else:
            # Nothing is written to file, hence nothing to annotate.
            output_filenames = None
            image_stacks_nxmetadata = None

        with output_context(output_root_uri) as ostacks:
            with input_context(
                image_stacks, output_filenames=output_filenames
            ) as istacks:
                image_crop_idx = apply_transformations(
                    istacks,
                    ostacks,
                    self.inputs.transformations,
                    crop=self.inputs.crop,
                    interpolation_order=self.inputs.interpolation_order,
                )
                aligned_stacks = ostacks.data_for_input()

        if image_stacks_nxmetadata:
            if image_crop_idx is not None:
                # _crop_nxdata_axes edits the metadata in place: work on a deep
                # copy so the task input (possibly shared) is left untouched.
                image_stacks_nxmetadata = copy.deepcopy(image_stacks_nxmetadata)
                _crop_nxdata_axes(
                    output_root_uri,
                    list(aligned_stacks),
                    image_stacks_nxmetadata,
                    image_crop_idx,
                )
            nx_annotate(image_stacks_nxmetadata, output_root_uri.file_path())

        # Shallow copy: adding keys must not alter the caller's dict.
        output_configuration = dict(
            self.get_input_value("output_configuration") or dict()
        )
        output_configuration["crop"] = self.inputs.crop
        output_configuration["interpolation_order"] = self.inputs.interpolation_order
        self.outputs.output_configuration = output_configuration

        if output_root_uri:
            _save_output_configuration(
                output_root_uri, list(aligned_stacks), output_configuration
            )
        self.outputs.image_stacks = aligned_stacks
def _split_data_path(output_root_uri: "DataUrl") -> List[str]:
    """Return the non-empty HDF5 path components of the URL's data path."""
    data_path = output_root_uri.data_path() or ""
    return [part for part in data_path.split("/") if part]


def _crop_nxdata_axes(
    output_root_uri: "DataUrl",
    stack_names: List[str],
    image_stacks_nxmetadata: dict,
    image_crop_idx: Tuple[slice, ...],
) -> None:
    """Crop, in place, the NXdata axis values found in ``image_stacks_nxmetadata``.

    For every group containing one of ``stack_names`` (the stack name without
    its last path component, relative to ``output_root_uri``), the datasets
    named in the group's ``@axes`` attribute are sliced with
    ``image_crop_idx``. The first ``@axes`` entry is skipped (it is not
    cropped), as are NeXus ``"."`` placeholders. Each group is cropped once.

    :param output_root_uri: URL under which the stacks were saved.
    :param stack_names: names of the saved stacks relative to ``output_root_uri``.
    :param image_stacks_nxmetadata: nested metadata dictionary, modified in place.
    :param image_crop_idx: one slice per image dimension.
    :raises ValueError: when ``@axes`` does not contain ``len(image_crop_idx) + 1`` names.
    """
    root_parts = _split_data_path(output_root_uri)
    expected_stack_ndim = len(image_crop_idx) + 1
    modified = set()
    for stack_name in stack_names:
        # Empty components are dropped so equivalent paths map to the same
        # key and a group is never cropped twice.
        nxdata_parts = root_parts + [
            part for part in stack_name.split("/")[:-1] if part
        ]
        nxdata_id = tuple(nxdata_parts)
        if nxdata_id in modified:
            continue
        modified.add(nxdata_id)

        nxdata = image_stacks_nxmetadata
        for name in nxdata_parts:
            nxdata = nxdata.get(name, dict())

        axes = nxdata.get("@axes")
        if axes is None:
            continue
        if len(axes) != expected_stack_ndim:
            raise ValueError(
                f"NXdata axes attribute must contain {expected_stack_ndim} names."
            )
        for axis, axis_idx in zip(axes[1:], image_crop_idx):
            if axis == ".":
                # NeXus placeholder: this dimension has no axis dataset.
                continue
            nxdata[axis] = nxdata[axis][axis_idx]


def _save_output_configuration(
    output_root_uri: "DataUrl", stack_names: List[str], output_configuration: dict
) -> None:
    """Store ``output_configuration`` as JSON in an NXnote named "configuration".

    The note is created in the first group with ``NX_class == "NXprocess"``
    along the path of the first stack; nothing is written when there are no
    stacks or no such group. Existing note fields are replaced so the task
    can be re-run on the same file.

    :param output_root_uri: URL under which the stacks were saved.
    :param stack_names: names of the saved stacks relative to ``output_root_uri``.
    :param output_configuration: JSON-serializable configuration.
    """
    if not stack_names:
        return
    nxdata_parts = _split_data_path(output_root_uri) + [
        part for part in stack_names[0].split("/")[:-1] if part
    ]
    with h5py_utils.File(output_root_uri.file_path(), mode="a") as fh:
        nxprocess = fh
        for name in nxdata_parts:
            nxprocess = nxprocess[name]
            if nxprocess.attrs.get("NX_class") == "NXprocess":
                break
        else:
            return

        nxnote = nxprocess.require_group("configuration")
        nxnote.attrs["NX_class"] = "NXnote"
        fields = {
            "type": "application/json",
            "data": json.dumps(output_configuration, indent=2),
            "date": datetime.datetime.now().astimezone().isoformat(),
        }
        for key, value in fields.items():
            # require_group tolerates an existing group; its datasets must be
            # removed explicitly before they can be written again.
            if key in nxnote:
                del nxnote[key]
            nxnote[key] = value