mrpro.utils.smap

mrpro.utils.smap(function: Callable[[Tensor], Tensor], tensor: Tensor, passed_dimensions: Sequence[int] | int = (-1,)) -> Tensor

Apply a function to a tensor serially along multiple dimensions.

The function is applied serially, one slice at a time, without a batch dimension. Compared to torch.vmap, it works with arbitrary functions but is slower.

Parameters:
  • function – Function to apply to the tensor. It receives a tensor with len(passed_dimensions) dimensions (or n dimensions if passed_dimensions is an integer n) and must return a tensor with the same number of dimensions.

  • tensor – Tensor to apply the function to.

  • passed_dimensions – Dimensions that are NOT batched, i.e. the dimensions passed to the function. Either a tuple of dimension indices (negative indices are supported) or an integer n, meaning the last n dimensions are passed to the function.
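To make the semantics above concrete, here is a minimal sketch in plain PyTorch. `smap_sketch` is a hypothetical name and this is not mrpro's implementation; it assumes the behaviour can be modelled by moving the passed dimensions to the end, flattening all other dimensions into one loop, applying the function to each slice in a Python loop, and moving the passed dimensions back. The passed dimensions may change size, but not their number.

```python
from __future__ import annotations

from collections.abc import Callable, Sequence

import torch


def smap_sketch(
    function: Callable[[torch.Tensor], torch.Tensor],
    tensor: torch.Tensor,
    passed_dimensions: Sequence[int] | int = (-1,),
) -> torch.Tensor:
    """Hypothetical re-implementation of the documented behaviour (not mrpro's code)."""
    if isinstance(passed_dimensions, int):
        # An integer n selects the last n dimensions.
        passed_dimensions = tuple(range(-passed_dimensions, 0))
    # Normalize negative indices to positive ones.
    passed = tuple(d % tensor.ndim for d in passed_dimensions)
    k = len(passed)
    trailing = tuple(range(tensor.ndim - k, tensor.ndim))
    # Move the passed dimensions to the end, keeping their given order.
    moved = tensor.movedim(passed, trailing)
    batch_shape = moved.shape[: tensor.ndim - k]
    # Collapse all batched dimensions into a single loop dimension.
    flat = moved.reshape(-1, *moved.shape[tensor.ndim - k :])
    # Apply the function serially, one slice at a time (no vmap).
    results = torch.stack([function(x) for x in flat])
    # Restore the batch shape; the passed dimensions may have changed size.
    results = results.reshape(*batch_shape, *results.shape[1:])
    # Move the passed dimensions back to their original positions.
    return results.movedim(trailing, passed)


x = torch.arange(24.0).reshape(2, 3, 4)
# Default: only the last dimension is passed, so each call sees a 1D slice.
y = smap_sketch(lambda v: v.cumsum(0), x)
# Data-dependent Python control flow is fine because every slice is a plain call.
z = smap_sketch(lambda v: v if v.sum() > 30 else -v, x)
```

With the real library the calls would look the same, e.g. `mrpro.utils.smap(function, tensor, passed_dimensions=2)` to pass the last two dimensions to `function`.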