Source code for ewoksndreg.intensities.torch_metrics
import torch
import torch.nn.functional as F
[docs]
def nmi_loss(
source: torch.Tensor, target: torch.Tensor, bandwidth: int = 0.01, reduction="none"
):
"""
Calculates the Normalized Mutual Information between batch source and batch target
param source: Batch of Tensors
param target: Batch of Tensors with same size as source
param bandwidth: The size of the kernel function, determines the sample rate of the histogramm(default: 0.01)
param reduction: If None: returns Tensor of NMI for every element in batch. If Sum or Mean: Applies operation to the resulting tensor
"""
eps = 1e-10
m = source.flatten(start_dim=1)
t = target.flatten(start_dim=1)
batch = m.size()[0]
bins = round(0.5 / bandwidth)
samples = torch.linspace(0, 1, bins)
# samples has size [bins]
samples = samples.unsqueeze(0).expand((batch, -1))
# m has size [N,items], samples has size [N, bins]
m_outer = m[..., None] - samples[:, None, ...]
m_outer = quartic_kernel(m_outer / bandwidth)
t_outer = t[:, None, ...] - samples[..., None]
t_outer = quartic_kernel(t_outer / bandwidth)
# m_outer and t_outer have shape [N,bins, items] and [N,items,bins]
# result has dimensions[N,bins,bins]
joint_hgram = torch.matmul(t_outer, m_outer)
joint_hgram = F.normalize(joint_hgram, p=1, dim=(1, 2))
m_hgram = torch.sum(joint_hgram, dim=1)
t_hgram = torch.sum(joint_hgram, dim=2)
m_hgram[m_hgram <= eps] = eps
t_hgram[t_hgram <= eps] = eps
joint_hgram[joint_hgram <= eps] = eps
m_ent = (m_hgram * torch.log2(m_hgram)).nansum(dim=1)
t_ent = (t_hgram * torch.log2(t_hgram)).nansum(dim=1)
joint_ent = (joint_hgram * torch.log2(joint_hgram)).nansum(dim=(1, 2))
mut_inf = -((m_ent + t_ent) / joint_ent)
if reduction == "none":
return mut_inf
elif reduction == "mean":
return mut_inf.mean()
elif reduction == "sum":
return mut_inf.sum()
else:
raise ValueError(f"reduction must either be none, mean or sum. Got {reduction}")
[docs]
def quartic_kernel(data: torch.Tensor) -> torch.Tensor:
result = torch.zeros_like(data)
result[torch.logical_and(data < 1, data > -1)] = (
15 / 16 * (1 - data[torch.logical_and(data < 1, data > -1)] ** 2) ** 2
)
return result
[docs]
def block_kernel(data: torch.Tensor) -> torch.Tensor:
result = torch.zeros_like(data)
result[torch.logical_and(data < 1, data > -1)] = 1
return result