# Source code for brainlit.map_neurons.diffeo_gen

import numpy as np
import torch
import matplotlib.pyplot as plt


def interp(x, I, phii, **kwargs):
    """Sample a function at transformed coordinates.

    The coordinates in ``phii`` are rescaled so that the first and last
    entries of each axis in ``x`` map to -1 and +1, then handed to
    ``torch.nn.functional.grid_sample`` (``align_corners=True``,
    ``padding_mode="border"``).

    Args:
        x (list): Original pixel locations of image, one sequence per
            spatial axis (three axes are read).
        I (torch.tensor): Function to be interpolated. Either 4D, in which
            case a batch axis is added and removed internally, or 5D.
        phii (torch.tensor): Transformed coordinates, with the three
            coordinate components on the last axis. Not modified.
        **kwargs: Extra keyword arguments forwarded to ``grid_sample``.

    Raises:
        Exception: Function should be four or five dimensional.

    Returns:
        torch.tensor: Interpolated function.
    """
    # Components of phii live on the last axis, components of I on the first.
    # Normalise a copy so the caller's tensor is left untouched.
    grid = torch.clone(phii)
    for axis in range(3):
        first = x[axis][0]
        last = x[axis][-1]
        grid[..., axis] -= first
        grid[..., axis] /= last - first
    # [0, 1] -> [-1, 1]
    grid *= 2.0
    grid -= 1.0

    ndim = I.ndim
    if ndim != 4 and ndim != 5:
        raise Exception("Image should be 4 or 5 dim")
    needs_batch = ndim == 4

    image = I[None] if needs_batch else I
    if needs_batch:
        grid = grid[None]

    # grid_sample wants the components ordered xyz rather than zyx,
    # hence the flip along the last axis.
    sampled = torch.nn.functional.grid_sample(
        image,
        torch.flip(grid, (-1,)),
        align_corners=True,
        padding_mode="border",
        **kwargs
    )

    # Strip the batch axis again if we were the ones who added it.
    return sampled[0] if needs_batch else sampled


def expR(xv, v0, K, n=10, visualize=False, return_forward=True):
    """Riemannian exponential.

    Starting from the initial velocity ``v0``, takes ``n`` steps. Each step
    transports the initial momentum with the current inverse map, converts
    it back to a velocity with the kernel ``K``, and composes the inverse
    map with a small displacement of ``-v / n``. When the forward map is
    requested, the velocities saved along the way are replayed in reverse
    order with displacements of ``+v / n``.

    Args:
        xv (list of arrays): Location of pixels in v, one 1D coordinate
            array per spatial axis. Each needs at least two entries
            (the grid spacing is taken from the first two).
        v0 (array): velocity at time 0. Recall shape is
            row x col x slice x 3. A 5D input is treated as batched, with
            the batch on the leading axis.
        K (array): kernel in fft domain
        n (int, optional): number of timesteps. Defaults to 10.
        visualize (bool, optional): Whether to plot the output. Only has an
            effect for unbatched input. Defaults to False.
        return_forward (bool, optional): Direction of exponential. If True
            the forward map is returned, otherwise the inverse map.
            Defaults to True.

    Returns:
        torch.tensor: Generated diffeomorphism.
    """

    use_batch = v0.ndim == 5
    # permute0 moves the vector components to the channel position expected
    # by interp; permute1 moves them back to the last axis.
    if not use_batch:
        permute0 = (-1, -4, -3, -2)
        permute1 = (-3, -2, -1, -4)
    else:
        permute0 = (0, -1, -4, -3, -2)
        permute1 = (0, -3, -2, -1, -4)

    # initialize p at time 0 by dividing by the kernel in the fft domain
    p0 = torch.fft.ifftn(
        torch.fft.fftn(v0, dim=(-2, -3, -4)) / K[..., None], dim=(-2, -3, -4)
    ).real
    # identity coordinates, components on the last axis
    XV = torch.stack(
        torch.meshgrid([torch.as_tensor(x) for x in xv], indexing="ij"), -1
    )
    # grid spacing along each axis
    dv = [x[1].item() - x[0].item() for x in xv]

    # initialize phii (the inverse map) at time 0 as the identity
    phii = XV.clone()
    if use_batch:
        phii = phii[None].repeat(v0.shape[0], 1, 1, 1, 1)
        XV = XV[None].repeat(v0.shape[0], 1, 1, 1, 1)

    if visualize and not use_batch:
        fig, ax = plt.subplots(1, 3)

    # we take n timesteps

    if return_forward:
        # velocities are kept so the forward map can be built afterwards
        vsave = []
    for t in range(n):
        # we need to calculate p at time t
        # first we just deform it
        p = interp(xv, p0.permute(*permute0), phii).permute(*permute1)
        # then we need the jacobian
        Dphii = torch.stack(torch.gradient(phii, dim=(-4, -3, -2), spacing=dv), -1)
        # and the determinant (over the last two axes)
        detDphii = torch.linalg.det(Dphii)
        # then we will multiply

        p = (Dphii.transpose(-1, -2) @ p[..., None])[..., 0] * detDphii[..., None]
        # now we calculate v by multiplying with the kernel in the fft domain
        v = torch.fft.ifftn(
            torch.fft.fftn(p, dim=(-2, -3, -4)) * K[..., None], dim=(-2, -3, -4)
        ).real
        if return_forward:
            vsave.append(v)
        # now we update phii: phii <- phii o (id - v / n)
        Xs = XV - v / n
        phii = interp(xv, (phii - XV).permute(*permute0), Xs).permute(*permute1) + Xs

        if visualize and not use_batch:
            # Show the middle slice along the first axis. The midpoint must
            # come from that same axis' size, otherwise the index can fall
            # outside the array when the first axis is short.
            mid = p0.shape[0] // 2

            pshow = np.array(p[mid, :, :, :])
            pshow -= np.min(pshow, axis=(0, 1, 2))
            pshow /= np.max(pshow, axis=(0, 1, 2))
            ax[0].cla()
            ax[0].imshow(pshow)

            vshow = np.array(v[mid, :, :, :])
            vshow -= np.min(vshow, axis=(0, 1, 2))
            vshow /= np.max(vshow, axis=(0, 1, 2))
            ax[1].cla()
            ax[1].imshow(vshow)

            fig.canvas.draw()
    if not return_forward:
        return phii

    # Build the forward map by composing with the saved velocities in
    # reverse order: phi <- phi o (id + v / n). The running map being
    # resampled here must be phi itself (not phii), otherwise every
    # iteration but the last is discarded.
    phi = XV.clone()
    for v in reversed(vsave):
        Xs = XV + v / n
        phi = interp(xv, (phi - XV).permute(*permute0), Xs).permute(*permute1) + Xs
    return phi


def diffeo_gen_ara(sigma):
    """Return random diffeomorphism generated by sampling Gaussian noise then
    passing through Riemannian exponential.

    The velocity field is sampled on a fixed 132 x 80 x 114 grid with
    100 micron spacing, centred on the origin.

    Args:
        sigma (float): standard deviation of noise in microns.

    Returns:
        List: list of sampled points in spatial domain (one 1D array per axis).
        np.array: range of diffeomorphism at sampled points.
    """
    # a domain for sampling your velocity and deformation
    dv = np.array([100.0, 100.0, 100.0])  # units are every 100 microns
    nv = np.array([132, 80, 114])
    xv = [np.arange(n) * d - (n - 1) * d / 2 for n, d in zip(nv, dv)]

    # a frequency domain
    fv = [np.arange(n) / n / d for n, d in zip(nv, dv)]
    FV = np.stack(np.meshgrid(*fv, indexing="ij"), -1)

    # smoothing kernel K = 1 / LL, built in the frequency domain
    a = 100.0
    p = 2.0
    LL = (
        1.0 - 2.0 * a**2 * np.sum(((np.cos(2.0 * np.pi * FV * dv) - 1)) / dv**2, -1)
    ) ** (2 * p)
    K = 1.0 / LL

    # sample white noise, one 3-vector per grid point
    Lm = np.random.randn(*FV.shape) * sigma
    # smooth it (with K rather than sqrt(K), to be a bit smoother)
    v = np.fft.ifftn(
        np.fft.fftn(Lm, axes=(0, 1, 2)) * K[..., None], axes=(0, 1, 2)
    ).real

    # shoot it with the Riemannian exponential; expR is called with its
    # default return_forward=True
    phi = expR([torch.tensor(x) for x in xv], torch.tensor(v), K, n=10)
    phi = phi.detach().cpu().numpy()

    return xv, phi