Source code for ewoksndreg.tasks.reg2d_features
from typing import Optional
from typing import Union
import numpy
from ewokscore.model import BaseInputModel
from ewokscore.task import Task
from pydantic import Field
from ..features import registration
from ..io.input_stack import InputDataType
from ..io.input_stack import input_context
from ..registry import RegistryIdType
from ..transformation.types import TransformationType
[docs]
class Inputs(BaseInputModel):
    """Inputs of the ``Reg2DFeatures`` task.

    A feature-based registration is assembled from three components selected by
    registry id: a ``detector`` (finds and describes feature points), a
    ``matcher`` (pairs feature points between two images) and a ``mapper``
    (estimates the parameters of ``transformation_type`` from the matches).
    """

    # Required: the stacks to register.
    image_stacks: InputDataType = Field(
        ...,
        description="Image stacks as a dictionary of numpy arrays or list of HDF5 dataset URIs.",
        examples=[
            {
                "stack1": "/path/to/file.h5::/entry/process/results/parameters/Ca-K",
                "stack2": "/path/to/file.h5::/entry/process/results/parameters/Fe-K",
            },
            {"stack1": [[0, 0, 0], [1, 1, 1], [2, 2, 2]]},
        ],
    )
    # Required: registry ids of the three pipeline components.
    detector: RegistryIdType = Field(
        ...,
        description="Method to detect and describe feature points in an image.",
        examples=["Sift-Silx", "Sift-SciKitImage", "Orb-SciKitImage"],
    )
    matcher: RegistryIdType = Field(
        ...,
        description="Method to build correspondence between the feature points in two images.",
        examples=["Descriptor-Silx", "Descriptor-SciKitImage"],
    )
    mapper: RegistryIdType = Field(
        ...,
        description="Method to find parameters of the transformation between the matches.",
        examples=["LstSq-Numpy", "LstSq-SciPy", "Ransac-SciKitImage"],
    )
    transformation_type: TransformationType = Field(
        ...,
        description="Type of transformation between the matches.",
        examples=["Translation", "Rigid", "Affine"],
    )
    # Optional parameters.
    reference_image: Union[int, float] = Field(
        0,
        # The two literals are concatenated implicitly: keep the trailing space
        # so the sentences do not run together.
        description="The index of the reference image in the stack (0.5 is the middle of the stack). "
        "The calculated transformations will be relative to this image.",
        examples=[0, -1, 0.5],
    )
    reference_stack: Optional[str] = Field(
        None,
        description="Transformations of all stacks are based on the image registration of this stack.",
        examples=["stack1", "stack2"],
    )
    mask: Optional[numpy.ndarray] = Field(
        None,
        description="Boolean image mask applied to the image before calculating the transformation (False means masked-off).",
        examples=[[[True, True, True], [True, True, True], [False, True, True]]],
    )
    output_configuration: Optional[dict] = Field(
        None,
        description="Registration configuration parameters to be saved.",
        examples=[{"param1": 0, "param2": 1}],
    )
[docs]
class Reg2DFeatures(
    Task,
    input_model=Inputs,
    output_names=[
        "image_stacks",
        "transformations",
        "reference_stack",
        "features",
        "matches",
        "output_configuration",
    ],
):
    """Use a feature-based registration method to calculate transformations
    between the images in one or more stacks."""

    def run(self):
        """Detect, match and map feature points to register the image stacks.

        Outputs:

        - ``transformations``, ``features``, ``matches``: registration results
          as returned by the ``registration`` helpers. When ``reference_stack``
          is set, only that stack is registered and its results are assigned
          to every name in ``image_stacks``.
        - ``image_stacks``: the input stacks, passed through unchanged.
        - ``reference_stack``: the input value, passed through.
        - ``output_configuration``: a copy of the input configuration (empty
          if none was given) extended with the registration parameters used.

        :raises ValueError: if ``reference_stack`` is set but is not one of the
            ``image_stacks`` names.
        """
        # Instantiate the pipeline components from their registry ids.
        detector = registration.FeatureDetector.get_subclass(self.inputs.detector)(
            mask=self.inputs.mask
        )
        matcher = registration.FeatureMatching.get_subclass(self.inputs.matcher)()
        mapper = registration.FeatureMapping.get_subclass(self.inputs.mapper)(
            self.inputs.transformation_type
        )

        stacks_to_align = self.inputs.image_stacks
        reference_stack = self.inputs.reference_stack
        if reference_stack:
            if reference_stack not in stacks_to_align:
                raise ValueError(
                    f"{reference_stack=} must be in {list(stacks_to_align)}"
                )
            # Only the reference stack is registered; its results are reused
            # for all stacks below.
            stacks_to_align = {reference_stack: stacks_to_align[reference_stack]}

        with input_context(stacks_to_align) as stacks:
            features = registration.detect_features(stacks, detector)
            matches = registration.match_features(
                stacks,
                features,
                matcher,
                reference_image=self.inputs.reference_image,
            )
            transformations = registration.transformations_from_features(
                matches, mapper
            )

        if reference_stack:
            # Broadcast the reference-stack results to every input stack name.
            # Note: all names share the same result objects (no copies).
            names = list(self.inputs.image_stacks)
            transformations = {name: transformations[reference_stack] for name in names}
            features = {name: features[reference_stack] for name in names}
            matches = {name: matches[reference_stack] for name in names}

        self.outputs.transformations = transformations
        self.outputs.reference_stack = reference_stack
        self.outputs.features = features
        self.outputs.matches = matches
        self.outputs.image_stacks = self.inputs.image_stacks

        # Work on a shallow copy: writing into the input dict directly would
        # mutate an object owned by the caller (or by an upstream task's output).
        output_configuration = dict(
            self.get_input_value("output_configuration") or dict()
        )
        output_configuration["detector"] = str(detector.get_subclass_id())
        output_configuration["matcher"] = str(matcher.get_subclass_id())
        output_configuration["mapper"] = str(mapper.get_subclass_id())
        output_configuration["transformation_type"] = mapper.transformation_type.value
        output_configuration["reference_image"] = self.inputs.reference_image
        output_configuration["reference_stack"] = reference_stack
        self.outputs.output_configuration = output_configuration