from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy
from ewokscore.model import BaseInputModel
from ewokscore.model import BaseOutputModel
from ewokscore.task import Task
from pydantic import Field
from ..features import registration
from ..features.features.base import Features
from ..io.input_stack import InputDataType
from ..io.input_stack import input_context
from ..registry import RegistryIdType
from ..transformation.base import Transformation
from ..transformation.types import TransformationType
class Outputs(BaseOutputModel):
    """Outputs of the Reg2DFeatures task."""

    image_stacks: InputDataType = Field(
        description="Dictionary of image stacks in memory or URIs."
    )
    transformations: Dict[str, List[Transformation]] = Field(
        description="Transformation between the images of each stack."
    )
    reference_stack: Optional[str] = Field(
        description="Transformations of all stacks is based on the image registration of this stack.",
        examples=["stack1", "stack2"],
    )
    # Keyed by stack name, like `transformations`.
    features: Dict[str, List[Features]] = Field(
        description="Detected features for each stack."
    )
    # Keyed by stack name; each entry is a pair of (possibly missing) matched features.
    matches: Dict[str, List[Tuple[Optional[Features], Optional[Features]]]] = Field(
        description="Matched feature pairs for each stack."
    )
    output_configuration: Optional[Dict[str, Any]] = Field(
        description="Registration configuration parameters to be saved.",
        examples=[{"param1": 0, "param2": 1}],
    )
class Reg2DFeatures(Task, input_model=Inputs, output_model=Outputs):
    """Use a feature-based registration method to calculate transformations between the images in one or more stacks."""

    def run(self):
        """Detect and match features, then derive the transformations of each stack.

        When ``reference_stack`` is given, only that stack is registered, and its
        features, matches and transformations are reused for every stack in
        ``image_stacks``.

        :raises ValueError: if ``reference_stack`` is not a key of ``image_stacks``.
        """
        # Instantiate the registered detector / matcher / mapper implementations
        # selected by the task inputs.
        detector = registration.FeatureDetector.get_subclass(self.inputs.detector)(
            mask=self.inputs.mask
        )
        matcher = registration.FeatureMatching.get_subclass(self.inputs.matcher)()
        mapper = registration.FeatureMapping.get_subclass(self.inputs.mapper)(
            self.inputs.transformation_type
        )

        all_stacks = self.inputs.image_stacks
        reference_stack = self.inputs.reference_stack
        stacks_to_align = all_stacks
        if reference_stack:
            if reference_stack not in all_stacks:
                raise ValueError(
                    f"{reference_stack=} must be in {list(all_stacks)}"
                )
            # Only register the reference stack; its results are copied to all
            # other stacks after the registration.
            stacks_to_align = {reference_stack: all_stacks[reference_stack]}

        with input_context(stacks_to_align) as stacks:
            features = registration.detect_features(stacks, detector)
            matches = registration.match_features(
                stacks,
                features,
                matcher,
                reference_image=self.inputs.reference_image,
            )
            transformations = registration.transformations_from_features(
                matches, mapper
            )

        if reference_stack:
            # Every stack gets the reference stack's results. The per-stack
            # lists are shared objects, not copies.
            names = list(all_stacks)
            transformations = {name: transformations[reference_stack] for name in names}
            features = {name: features[reference_stack] for name in names}
            matches = {name: matches[reference_stack] for name in names}

        self.outputs.transformations = transformations
        self.outputs.reference_stack = reference_stack
        self.outputs.features = features
        self.outputs.matches = matches
        self.outputs.image_stacks = all_stacks

        # Work on a shallow copy: writing the keys below into the object returned
        # by get_input_value would mutate the task's (and the caller's) input value.
        output_configuration = dict(
            self.get_input_value("output_configuration") or {}
        )
        output_configuration["detector"] = str(detector.get_subclass_id())
        output_configuration["matcher"] = str(matcher.get_subclass_id())
        output_configuration["mapper"] = str(mapper.get_subclass_id())
        output_configuration["transformation_type"] = mapper.transformation_type.value
        output_configuration["reference_image"] = self.inputs.reference_image
        output_configuration["reference_stack"] = reference_stack
        self.outputs.output_configuration = output_configuration