mrpro.utils.smap
- mrpro.utils.smap(function: Callable[[Tensor], Tensor], tensor: Tensor, passed_dimensions: Sequence[int] | int = (-1,)) → Tensor
Apply a function to a tensor serially along multiple dimensions.
The function is applied serially without batch dimensions. Compared to torch.vmap, it works with arbitrary functions, but is slower.
- Parameters:
  - function (Callable[[Tensor], Tensor]) – Function to apply to the tensor. Should handle as many dimensions as are selected by passed_dimensions and not change the number of dimensions.
  - tensor (Tensor) – Tensor to apply the function to.
  - passed_dimensions (Sequence[int] | int, default: (-1,)) – Dimensions NOT to be batched, i.e. the dimensions that are passed to the function. Either a tuple of dimension indices (negative indices are supported) or an integer. An integer n means the last n dimensions are passed to the function.
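
The semantics described above can be illustrated with a minimal sketch. It is not mrpro's implementation: `smap_sketch`, `normalize`, and `seen_shapes` are hypothetical names introduced here, and the sketch assumes that the result keeps the passed dimensions at their original positions. It only shows which slices the function receives for different values of passed_dimensions.

```python
from collections.abc import Callable, Sequence
from typing import Union

import torch


def smap_sketch(
    function: Callable[[torch.Tensor], torch.Tensor],
    tensor: torch.Tensor,
    passed_dimensions: Union[Sequence[int], int] = (-1,),
) -> torch.Tensor:
    """Hypothetical sketch of the documented behaviour (not mrpro's code)."""
    if isinstance(passed_dimensions, int):
        # An integer n selects the last n dimensions.
        passed_dimensions = tuple(range(-passed_dimensions, 0))
    passed = tuple(passed_dimensions)
    n_passed = len(passed)
    last = tuple(range(-n_passed, 0))
    # Move the passed dimensions to the end; all other dimensions are batch dimensions.
    moved = tensor.movedim(passed, last)
    batch_shape = moved.shape[: moved.ndim - n_passed]
    passed_shape = moved.shape[moved.ndim - n_passed :]
    # Call the function once per batch element, without any batch dimension.
    flat = moved.reshape(-1, *passed_shape)
    results = torch.stack([function(x) for x in flat])
    # Restore the batch dimensions and move the passed dimensions back.
    results = results.reshape(*batch_shape, *results.shape[1:])
    return results.movedim(last, passed)


seen_shapes = []


def normalize(x: torch.Tensor) -> torch.Tensor:
    # Record the shape the function actually receives, then normalize to sum 1.
    seen_shapes.append(tuple(x.shape))
    return x / x.sum()


data = torch.arange(1.0, 25.0).reshape(2, 3, 4)

# Default: only the last dimension is passed, so the function sees 6 slices of shape (4,).
out_last = smap_sketch(normalize, data)
shapes_last = list(seen_shapes)

# Integer 2: the last two dimensions are passed, so the function sees 2 slices of shape (3, 4).
seen_shapes.clear()
out_two = smap_sketch(normalize, data, passed_dimensions=2)
shapes_two = list(seen_shapes)

# Tuple (0,): the first dimension is passed, so the function sees 12 slices of shape (2,).
seen_shapes.clear()
out_first = smap_sketch(normalize, data, passed_dimensions=(0,))
shapes_first = list(seen_shapes)
```

In this sketch the Python-level loop over all batch elements is what makes the approach slower than torch.vmap, while placing no restrictions on what the function does internally.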