mrpro.utils.smap

mrpro.utils.smap(function: Callable[[Tensor], Tensor], tensor: Tensor, passed_dimensions: Sequence[int] | int = (-1,)) → Tensor

Apply a function to a tensor serially along multiple dimensions.

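The function is called serially, once per slice along the batched (non-passed) dimensions, and each call receives a tensor without those batch dimensions. Compared to torch.vmap, this works with arbitrary functions (for example, ones with data-dependent Python control flow), but it is slower.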
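To make the serial semantics concrete, here is a minimal sketch in plain PyTorch of how such a serial map can work: move the batched dimensions to the front, flatten them, call the function on each slice in a Python loop, stack the results, and undo the reshaping and permutation. The name smap_sketch is hypothetical and this is not mrpro's actual implementation; it assumes the passed dimensions are handed to the function in the order given and that the function keeps the number of dimensions.

```python
from __future__ import annotations

from collections.abc import Callable, Sequence

import torch


def smap_sketch(
    function: Callable[[torch.Tensor], torch.Tensor],
    tensor: torch.Tensor,
    passed_dimensions: Sequence[int] | int = (-1,),
) -> torch.Tensor:
    """Hypothetical re-implementation of serial mapping (not mrpro's code)."""
    ndim = tensor.ndim
    if isinstance(passed_dimensions, int):
        # An integer n means: pass the last n dimensions to the function.
        passed = tuple(range(ndim - passed_dimensions, ndim))
    else:
        # Normalize negative indices to non-negative ones.
        passed = tuple(d % ndim for d in passed_dimensions)
    batched = tuple(d for d in range(ndim) if d not in passed)

    # Batched dimensions first, passed dimensions last (in the given order).
    order = (*batched, *passed)
    permuted = tensor.permute(*order)
    batch_shape = permuted.shape[: len(batched)]
    flat = permuted.reshape(-1, *permuted.shape[len(batched):])

    # Serial application: one Python-level call per slice, no vmap involved.
    stacked = torch.stack([function(x) for x in flat])
    out = stacked.reshape(*batch_shape, *stacked.shape[1:])

    # Invert the permutation so every dimension returns to its original position.
    inverse = [0] * ndim
    for new_pos, old_pos in enumerate(order):
        inverse[old_pos] = new_pos
    return out.permute(*inverse)


def flip_if_negative(x: torch.Tensor) -> torch.Tensor:
    # Branching on a tensor value via .item() is not supported inside torch.vmap,
    # but it is fine in a serial loop.
    return -x if x.sum().item() < 0 else x


data = torch.randn(2, 3, 4)
result = smap_sketch(flip_if_negative, data, passed_dimensions=(-1,))
print(result.shape)  # torch.Size([2, 3, 4])
```

The Python loop is what makes arbitrary functions work, and it is also why this approach is slower than a vectorized torch.vmap call.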

Parameters:
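  • function (Callable[[Tensor], Tensor]) – Function to apply to the tensor. It receives a tensor containing only the passed dimensions (len(passed_dimensions) dimensions, or n dimensions if passed_dimensions is an integer n) and must not change the number of dimensions.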

  • tensor (Tensor) – Tensor to apply the function to.

  • passed_dimensions (Sequence[int] | int, default: (-1,)) – Dimensions that are passed to the function, i.e. the dimensions that are NOT batched over. Either a sequence of dimension indices such as a tuple (negative indices are supported) or an integer n, meaning the last n dimensions are passed to the function.
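The two accepted forms of passed_dimensions can be illustrated with a small helper that converts either form into non-negative dimension indices and derives the remaining, batched dimensions. Both helper names are hypothetical, and rejecting out-of-range or duplicate indices is an assumption made for this sketch, not documented mrpro behavior.

```python
from __future__ import annotations

from collections.abc import Sequence


def normalize_passed_dimensions(
    passed_dimensions: Sequence[int] | int, ndim: int
) -> tuple[int, ...]:
    """Hypothetical helper: convert passed_dimensions to non-negative indices."""
    if isinstance(passed_dimensions, int):
        # An integer n selects the last n dimensions.
        if not 0 <= passed_dimensions <= ndim:
            raise ValueError(f"cannot pass {passed_dimensions} of {ndim} dimensions")
        return tuple(range(ndim - passed_dimensions, ndim))
    normalized = []
    for dim in passed_dimensions:
        # Negative indices count from the end, as in PyTorch.
        if not -ndim <= dim < ndim:
            raise IndexError(f"dimension {dim} out of range for {ndim} dimensions")
        normalized.append(dim % ndim)
    if len(set(normalized)) != len(normalized):
        raise ValueError("passed_dimensions contains duplicates")
    return tuple(normalized)


def batched_dimensions(passed_dimensions: Sequence[int] | int, ndim: int) -> tuple[int, ...]:
    """Hypothetical helper: the remaining dimensions, iterated over serially."""
    passed = set(normalize_passed_dimensions(passed_dimensions, ndim))
    return tuple(dim for dim in range(ndim) if dim not in passed)


# For a 4-dimensional tensor:
print(normalize_passed_dimensions((-1,), 4))    # (3,)    the default
print(normalize_passed_dimensions(2, 4))        # (2, 3)  the last two dimensions
print(normalize_passed_dimensions((0, -2), 4))  # (0, 2)
print(batched_dimensions((0, -2), 4))           # (1, 3)
```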