mrpro.utils.smap
- mrpro.utils.smap(function: Callable[[Tensor], Tensor], tensor: Tensor, passed_dimensions: Sequence[int] | int = (-1,)) → Tensor
Apply a function to a tensor serially along multiple dimensions.
The function is applied serially, one slice at a time, without a batch dimension. Compared to torch.vmap, it works with arbitrary functions but is slower.
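As a minimal illustration of why the serial approach is more general, the sketch below uses a function whose control flow depends on tensor values. torch.vmap does not support such data-dependent control flow, whereas calling the function once per slice and stacking the results does. This is a hand-written loop over the first dimension, not a call to smap, and `clip_if_positive_mean` is a hypothetical example function.

```python
import torch


def clip_if_positive_mean(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical function: it branches on the *values* of x (a Python `if`),
    # which torch.vmap cannot handle.
    if x.mean() > 0:
        return x.clamp(max=1.0)
    return x


data = torch.tensor([[0.5, 2.0, 3.0], [-4.0, 1.0, 0.5]])
# Serial application: call the function once per row (no batch dimension), then stack.
result = torch.stack([clip_if_positive_mean(row) for row in data])
print(result)
```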
- Parameters:
function – Function to apply to the tensor. It should handle len(passed_dimensions) dimensions (or n dimensions if passed_dimensions is an integer n) and must not change the number of dimensions.
tensor – Tensor to apply the function to.
passed_dimensions – Dimensions NOT to be batched, i.e., the dimensions that are passed to the function. Either a tuple of dimension indices (negative indices are supported) or an integer; an integer n means the last n dimensions are passed to the function.
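The semantics of passed_dimensions can be sketched in plain PyTorch as follows. This is an assumption-based reimplementation for illustration, not mrpro's actual code: `serial_map` is a hypothetical name, and the sketch assumes the function keeps the number of dimensions (their sizes may change). The batched dimensions are moved to the front and flattened, the function is called once per slice, and the original dimension order is restored.

```python
from typing import Callable, Sequence, Union

import torch


def serial_map(
    function: Callable[[torch.Tensor], torch.Tensor],
    tensor: torch.Tensor,
    passed_dimensions: Union[Sequence[int], int] = (-1,),
) -> torch.Tensor:
    """Hypothetical sketch: apply `function` serially to every slice spanned by `passed_dimensions`."""
    # An integer n means "the last n dimensions are passed to the function".
    if isinstance(passed_dimensions, int):
        passed = list(range(tensor.ndim - passed_dimensions, tensor.ndim))
    else:
        # Normalize negative indices to non-negative ones.
        passed = [d % tensor.ndim for d in passed_dimensions]
    batched = [d for d in range(tensor.ndim) if d not in passed]
    # Move the batched dimensions to the front and flatten them into a single one.
    moved = tensor.permute(*batched, *passed)
    batch_shape = moved.shape[: len(batched)]
    flat = moved.reshape(-1, *moved.shape[len(batched) :])
    # Call the function once per slice; each call sees only the passed dimensions.
    results = torch.stack([function(x) for x in flat])
    # Restore the batch shape, then undo the permutation.
    results = results.reshape(*batch_shape, *results.shape[1:])
    inverse = [0] * tensor.ndim
    for position, dim in enumerate((*batched, *passed)):
        inverse[dim] = position
    return results.permute(*inverse)


x = torch.arange(24.0).reshape(2, 3, 4)
# Normalize each length-4 vector along the last dimension separately.
y = serial_map(lambda v: v / v.norm(), x)
# Pass dimensions 0 and 2 to the function; dimension 1 is batched.
z = serial_map(lambda m: m - m.mean(), x, passed_dimensions=(0, 2))
```

In this sketch, passing the integer 2 is equivalent to passed_dimensions=(-2, -1), matching the parameter description above.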