from typing import Optional
import numpy
from typing_extensions import Literal
from ..dependencies.sitk import sitk
from ..transformation.simpleitk_backend import SimpleITKTransformation
from ..transformation.types import TransformationType
from .base import IntensityMapping
from .types import SitkMetricType
from .types import SitkOptimizerType
[docs]
class SitkOptimizerIntensityMapping(
IntensityMapping,
registry_id=IntensityMapping.RegistryId("Optimization", "SimpleITK"),
):
SUPPORTED_TRANSFORMATIONS = [
"translation",
"rigid",
"similarity",
"affine",
"bspline",
"displacement_field",
]
def __init__(
self,
transfo_type: TransformationType,
metric: SitkMetricType = SitkMetricType.MeanSquares,
optimizer: SitkOptimizerType = SitkOptimizerType.LBFGSB,
order: int = 1,
mask: Optional[numpy.ndarray] = None,
pyramid_levels: int = 1,
mesh_size: Optional[tuple] = None,
sampling: Optional[Literal["random", "regular"]] = None,
sampling_percentage: int = 0.5,
**kw,
) -> None:
self._registration_method = sitk.ImageRegistrationMethod()
self._metric = SitkMetricType(metric)
self._optimizer = SitkOptimizerType(optimizer)
self._mask = mask
self._mesh_size = mesh_size
self._levels = pyramid_levels
if order == 0:
self._registration_method.SetInterpolator(sitk.sitkNearestNeighbor)
elif order == 1:
self._registration_method.SetInterpolator(sitk.sitkLinear)
elif order == 2:
self._registration_method.SetInterpolator(sitk.sitkBSpline1)
elif order == 3:
self._registration_method.SetInterpolator(sitk.sitkBSpline2)
elif order == 4:
self._registration_method.SetInterpolator(sitk.sitkBSpline3)
elif order == 5:
self._registration_method.SetInterpolator(sitk.sitkBSpline4)
self._order = order
super().__init__(transfo_type, **kw)
# set the metric
if self.transformation_type == self.transformation_type.displacement_field:
self._registration_method.SetMetricAsDemons(0.00001)
elif self._metric == self._metric.CrossCorrelation:
self._registration_method.SetMetricAsCorrelation()
elif self._metric == self._metric.ANTSNeighborhoodCorrelation:
self._registration_method.SetMetricAsANTSNeighborhoodCorrelation(4)
elif self._metric == self._metric.JointHistogramMutualInformation:
self._registration_method.SetMetricAsJointHistogramMutualInformation()
elif self._metric == self._metric.MattesMutualInformation:
self._registration_method.SetMetricAsMattesMutualInformation()
elif self._metric == self._metric.MeanSquares:
self._registration_method.SetMetricAsMeanSquares()
else:
raise ValueError(
f"metric {self._metric} not supported by SimpleITK, choose from {list(SitkMetricType.__members__)}"
)
# set the optimizer
if self._optimizer == self._optimizer.Exhaustive:
self._registration_method.SetOptimizerAsExhaustive((10, 10))
elif self._optimizer == self._optimizer.Powell:
self._registration_method.SetOptimizerAsPowell()
elif self._optimizer == self._optimizer.Evolutionary:
self._registration_method.SetOptimizerAsOnePlusOneEvolutionary()
elif self._optimizer == self._optimizer.LBFGSB:
self._registration_method.SetOptimizerAsLBFGSB()
elif self._optimizer == self._optimizer.GradientDescent:
self._registration_method.SetOptimizerAsGradientDescent(
learningRate=1.0,
numberOfIterations=300,
convergenceMinimumValue=1e-6,
convergenceWindowSize=10,
)
elif self._optimizer == self._optimizer.CGLineSearch:
self._registration_method.SetOptimizerAsConjugateGradientLineSearch(
2.0, 300
)
elif self._optimizer == self._optimizer.GradientDescentLineSearch:
self._registration_method.SetOptimizerAsGradientDescentLineSearch(2.0, 300)
else:
raise ValueError(
f"optimizer {self._optimizer} not supported by SimpleITK, choose from {[str(item) for item in list(SitkOptimizerType)]}"
)
# set sampling strategy
if sampling == "random":
self._registration_method.SetMetricSamplingStrategy(
self._registration_method.RANDOM
)
self._registration_method.SetMetricSamplingPercentage(sampling_percentage)
elif sampling == "regular":
self._registration_method.SetMetricSamplingStrategy(
self._registration_method.REGULAR
)
self._registration_method.SetMetricSamplingPercentage(sampling_percentage)
[docs]
def identity(self, dimension: int = 2) -> SimpleITKTransformation:
return SimpleITKTransformation(
transfo_type="identity", passive_matrix=numpy.identity(dimension + 1)
)
[docs]
def calculate(
self,
from_image: numpy.ndarray,
to_image: numpy.ndarray,
) -> SimpleITKTransformation:
if self._mask is not None and numpy.any(self._mask):
mask = self._mask.astype(bool).astype("uint8")
mask = sitk.GetImageFromArray(mask)
self._registration_method.SetMetricMovingMask(mask)
self._registration_method.SetMetricFixedMask(mask)
from_imageitk = sitk.GetImageFromArray(from_image)
to_imageitk = sitk.GetImageFromArray(to_image)
dim = from_imageitk.GetDimension()
if self.transformation_type == self.transformation_type.translation:
tx = sitk.TranslationTransform(dim, dim * [0])
elif self._transfo_type in ["rigid", "similarity", "affine"]:
if self.transformation_type == self.transformation_type.rigid and dim == 2:
tf = sitk.Euler2DTransform()
elif (
self.transformation_type == self.transformation_type.rigid and dim == 3
):
tf = sitk.Euler3DTransform()
elif (
self.transformation_type == self.transformation_type.similarity
and dim == 2
):
tf = sitk.Similarity2DTransform()
elif (
self.transformation_type == self.transformation_type.similarity
and dim == 3
):
tf = sitk.Similarity3DTransform()
elif self.transformation_type == self.transformation_type.affine:
tf = sitk.AffineTransform(dim)
tx = sitk.CenteredTransformInitializer(
to_imageitk,
from_imageitk,
tf,
sitk.CenteredTransformInitializerFilter.GEOMETRY,
)
elif self.transformation_type == self.transformation_type.bspline:
# Use given mesh size or put control point every 20 pixels
if self._mesh_size:
tx = sitk.BSplineTransformInitializer(from_imageitk, self._mesh_size)
else:
tx = sitk.BSplineTransformInitializer(
from_imageitk, [int(size / 20) for size in from_imageitk.GetSize()]
)
elif self.transformation_type == self.transformation_type.displacement_field:
img = sitk.GetImageFromArray(numpy.zeros((*from_image.shape, dim)), True)
tx = sitk.DisplacementFieldTransform(img)
tx.SetSmoothingGaussianOnUpdate(
varianceForUpdateField=0.0, varianceForTotalField=2.0
)
else:
raise ValueError(f"{self._transfo_type} is not supported by SimpleITK")
self._registration_method.SetInitialTransform(tx)
if self._optimizer not in ["Powell", "GradientDescent"]:
factors = [int(2**x) for x in numpy.arange(self._levels - 1, -1, -1)]
self._registration_method.SetShrinkFactorsPerLevel(shrinkFactors=factors)
self._registration_method.SetSmoothingSigmasPerLevel(
smoothingSigmas=numpy.arange(
self._levels - 1, -1, -1, dtype=numpy.float64
)
)
if self._optimizer != self._optimizer.LBFGSB:
self._registration_method.SetOptimizerScalesFromPhysicalShift()
result = self._registration_method.Execute(to_imageitk, from_imageitk)
return SimpleITKTransformation(
transfo_type=self._transfo_type,
transformation=result,
)