Useful JAX functions
cryojax provides a collection of useful functions not found in JAX that tend to be important for tasks in cryo-EM. Depending on developments in core JAX/Equinox and other factors, these functions could be removed in future releases of cryojax. Use with caution!
cryojax.utils.batched_map(f: Callable[[PyTree[Shaped[Array, '_ ...'], X]], PyTree[Shaped[Array, '_ ...'], Y]], xs: PyTree[Shaped[Array, '_ ...'], X], *, batch_size: int = 1) -> PyTree[Shaped[Array, '_ ...'], Y]
Like jax.lax.map(..., batch_size=...), except f(x) is already vmapped by the user. In particular, it must be vmapped over the first axis of the arrays of x.
Arguments:
- f: As jax.lax.map with format f(x), except vmapped over the first axis of the arrays of x.
- xs: As jax.lax.map.
- batch_size: Compute a loop of vmaps over xs in chunks of batch_size.

Returns:
As jax.lax.map.
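To make the "already vmapped" contract concrete, here is a minimal sketch of the described behaviour in plain JAX. It is not cryojax's implementation: batched_map_sketch is a hypothetical name, and the handling of leftover elements (one final, smaller call to f when the leading dimension is not a multiple of batch_size) is an assumption. With the documented signature, the equivalent library call would be cryojax.utils.batched_map(square, xs, batch_size=3).

```python
import jax
import jax.numpy as jnp


def batched_map_sketch(f, xs, batch_size=1):
    """Hypothetical sketch of batched_map semantics (not cryojax's code).

    `f` must already be vmapped over the first axis of every array in `xs`.
    """
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    n_batches, remainder = divmod(n, batch_size)
    n_full = n_batches * batch_size

    # Reshape each leaf from (n, ...) to (n_batches, batch_size, ...).
    full_xs = jax.tree_util.tree_map(
        lambda x: x[:n_full].reshape((n_batches, batch_size) + x.shape[1:]), xs
    )
    # lax.map loops over the n_batches axis; each call hands a whole chunk of
    # batch_size elements to the (already vmapped) f.
    ys = jax.lax.map(f, full_xs)
    # Merge (n_batches, batch_size, ...) back into (n_full, ...).
    ys = jax.tree_util.tree_map(lambda y: y.reshape((n_full,) + y.shape[2:]), ys)

    # Assumption: leftover elements (when n % batch_size != 0) go through f
    # in one final, smaller vmapped call.
    if remainder:
        rest_ys = f(jax.tree_util.tree_map(lambda x: x[n_full:], xs))
        ys = jax.tree_util.tree_map(
            lambda a, b: jnp.concatenate([a, b], axis=0), ys, rest_ys
        )
    return ys


# f is vmapped by the user, as batched_map requires.
square = jax.vmap(lambda x: x**2)
xs = jnp.arange(10.0)
ys = batched_map_sketch(square, xs, batch_size=3)  # chunks of 3, 3, 3, then 1
```

The batch_size trades memory for speed: larger chunks mean fewer sequential loop iterations but more elements held in memory by each vmapped call.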
cryojax.utils.batched_scan(f: Callable[[Carry, PyTree[Shaped[Array, '_ ...'], X]], tuple[Carry, PyTree[Shaped[Array, '_ ...'], Y]]], init: Carry, xs: PyTree[Shaped[Array, '_ ...'], X], length: int | None = None, unroll: int | bool = 1, *, batch_size: int = 1) -> tuple[Carry, PyTree[Shaped[Array, '_ ...'], Y]]
Like jax.lax.map(..., batch_size=...), except that it adds a batch_size argument to jax.lax.scan. Additionally, unlike jax.lax.map, it assumes that f(carry, x) is already vmapped over the first axis of the arrays of x.
Arguments:
- f: As jax.lax.scan with format f(carry, x), except vmapped over the first axis of the arrays of x.
- init: As jax.lax.scan.
- xs: As jax.lax.scan.
- length: As jax.lax.scan.
- unroll: As jax.lax.scan.
- batch_size: Compute a loop of vmaps over xs in chunks of batch_size.

Returns:
As jax.lax.scan.
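The behaviour described above can be sketched in plain JAX by reshaping xs into chunks of batch_size and letting jax.lax.scan step over the chunks, handing each one to the already vmapped f. This is not cryojax's implementation: batched_scan_sketch and step are hypothetical names, length and unroll are omitted, and the sketch simply rejects inputs whose leading dimension is not a multiple of batch_size.

```python
import jax
import jax.numpy as jnp


def batched_scan_sketch(f, init, xs, batch_size=1):
    """Hypothetical sketch of batched_scan semantics (not cryojax's code).

    `f(carry, x_batch)` must already be vmapped over the first axis of the
    arrays in `x_batch`. For simplicity this sketch omits `length` and
    `unroll` and requires the leading dimension to be divisible by batch_size.
    """
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    if n % batch_size != 0:
        raise ValueError("sketch assumes len(xs) is a multiple of batch_size")
    n_batches = n // batch_size

    # (n, ...) -> (n_batches, batch_size, ...), so scan steps over chunks.
    chunked = jax.tree_util.tree_map(
        lambda x: x.reshape((n_batches, batch_size) + x.shape[1:]), xs
    )
    carry, ys = jax.lax.scan(f, init, chunked)
    # Flatten the stacked per-chunk outputs back to (n, ...).
    ys = jax.tree_util.tree_map(lambda y: y.reshape((n,) + y.shape[2:]), ys)
    return carry, ys


def step(carry, x_batch):
    # Per-element work is vectorized across the chunk...
    y_batch = jax.vmap(lambda x: 2.0 * x)(x_batch)
    # ...while the carry is updated once per chunk, not once per element.
    return carry + jnp.sum(x_batch), y_batch


xs = jnp.arange(12.0)
carry, ys = batched_scan_sketch(step, jnp.zeros(()), xs, batch_size=4)
```

In this sketch the carry advances once per chunk, so the scan runs n / batch_size sequential steps rather than n; f is responsible for folding a whole chunk into the carry.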