Source code for mrpro.utils.smap

"""Smap utility function."""

from collections.abc import Callable, Sequence

import torch


[docs] def smap( function: Callable[[torch.Tensor], torch.Tensor], tensor: torch.Tensor, passed_dimensions: Sequence[int] | int = (-1,), ) -> torch.Tensor: """Apply a function to a tensor serially along multiple dimensions. The function is applied serially without a batch dimensions. Compared to torch.vmap, it works with arbitrary functions, but is slower. Parameters ---------- function Function to apply to the tensor. Should handle len(fun_dims) dimensions and not change the number of dimensions. tensor Tensor to apply the function to. passed_dimensions Dimensions NOT to be batched / dimensions that are passed to the function tuple of dimension indices (negative indices are supported) or an integer an integer n means the last n dimensions are passed to the function """ if isinstance(passed_dimensions, int): # use the last fun_dims dimensions for the function moved = tensor first_fun_dim = -passed_dimensions else: # Move fun_dims to the end fun_dims_dst = tuple(range(-len(passed_dimensions), 0)) moved = tensor.moveaxis(tuple(passed_dimensions), fun_dims_dst) first_fun_dim = fun_dims_dst[0] reshaped = moved.flatten(end_dim=first_fun_dim - 1) # shape: (prod(batch_dims), fun_dim_1, ..., fun_dim_n) result_reshaped = torch.stack([function(x) for x in reshaped]) result = result_reshaped.reshape(moved.shape[:first_fun_dim] + result_reshaped.shape[1:]) if not isinstance(passed_dimensions, int): # Move fun_dims back to their original position if we moved them result = result.moveaxis(fun_dims_dst, tuple(passed_dimensions)) return result