# Source code for ewoksndreg.tests.test_nmi_loss

import pytest

from ewoksndreg.io import data_for_registration

try:
    import torch

    from ewoksndreg.intensities.torch_metrics import nmi_loss
except ImportError:
    torch = None
import numpy


@pytest.mark.skipif(torch is None, reason="python version not supported by torch")
def test_rand_nmi():
    """Check how ``nmi_loss`` ranks random inputs that share more or less data.

    Three comparisons against the same random tensor are made: with itself,
    with a tensor whose first 10 entries along the last axis are copied from
    it, and with an independent random tensor. The loss is expected to
    increase strictly in that order for each of the 3 returned values.
    """
    shape = (3, 1, 20, 20)
    first = torch.rand(shape)
    second = torch.rand(shape)

    # Independent data, except for the first 10 entries of the last axis,
    # which are taken from ``first``.
    half_shared = second.clone()
    half_shared[..., :10] = first[..., :10]

    loss_identical = nmi_loss(first, first)
    loss_half = nmi_loss(first, half_shared)
    loss_independent = nmi_loss(first, second)

    # assert_close also verifies shape and dtype, so each comparison must
    # yield a boolean tensor with exactly 3 entries.
    all_true = torch.ones(3, dtype=torch.bool)
    torch.testing.assert_close(loss_identical < loss_half, all_true)
    torch.testing.assert_close(loss_half < loss_independent, all_true)
@pytest.mark.skipif(torch is None, reason="python version not supported by torch")
def test_img_nmi():
    """Check that ``nmi_loss`` against the first image grows along the stack.

    A stack of 8 images is generated from a single image with the
    ``"translation"`` transformation type. The loss between the first image
    and every image of the stack (the first one included) must be strictly
    increasing with the stack index.
    """
    image = data_for_registration.generate_image()
    images, _, _ = data_for_registration.generate_image_stack(
        image, "translation", nimages=8
    )

    # Stack into a single tensor and insert a singleton axis at position 1.
    # NOTE(review): presumably this is the channel axis nmi_loss expects -
    # confirm against its signature.
    images = torch.tensor(numpy.array(images))
    images = images[:, None, ...]

    # Slices (not integer indices) keep the leading axis on both operands.
    reference = images[0:1]

    # Iterate over the whole stack. The bound is len(images); writing it as
    # len(images - 1) would subtract 1 from every pixel only to obtain the
    # same length.
    losses = [nmi_loss(reference, images[i : i + 1]) for i in range(len(images))]

    # Every loss must be strictly smaller than its successor.
    assert all(current < following for current, following in zip(losses, losses[1:]))