
Useful JAX functions

cryojax provides a collection of useful functions not found in JAX that are often important for cryo-EM tasks. Depending on developments in core JAX/Equinox and other factors, these functions may be removed in future releases of cryojax. Use with caution!

cryojax.utils.batched_map(f: Callable[[PyTree[Shaped[Array, '_ ...'], X]], PyTree[Shaped[Array, '_ ...'], Y]], xs: PyTree[Shaped[Array, '_ ...'], X], *, batch_size: int = 1) -> PyTree[Shaped[Array, '_ ...'], Y]

Like jax.lax.map(..., batch_size=...), except that f is assumed to be already vmapped by the user. Specifically, f must be vmapped over the first axis of the arrays of x.

Arguments:

  • f: As jax.lax.map with format f(x), except vmapped over the first axis of the arrays of x.
  • xs: As jax.lax.map.
  • batch_size: Loop over xs in chunks of batch_size, passing each chunk to the vmapped f.

Returns:

As jax.lax.map.
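To make the contract concrete, here is a minimal pure-JAX sketch of the behaviour described above. The function name batched_map_sketch is hypothetical and this is not cryojax's implementation; in particular, how cryojax handles a leading axis that is not divisible by batch_size is not stated here, so the sketch assumes leftover elements are processed in one final, smaller call to f.

```python
import jax
import jax.numpy as jnp


def batched_map_sketch(f, xs, *, batch_size=1):
    """Hypothetical stand-in for batched_map, for illustration only.

    Assumes `f` is already vmapped over the leading axis of every array in `xs`.
    """
    # Length of the leading (mapped) axis, read from the first array in xs.
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    n_batches, remainder = divmod(n, batch_size)
    n_full = n_batches * batch_size

    # Reshape the evenly divisible part of xs to (n_batches, batch_size, ...).
    head = jax.tree_util.tree_map(
        lambda x: x[:n_full].reshape(n_batches, batch_size, *x.shape[1:]), xs
    )
    # jax.lax.map loops sequentially over chunks; f vectorizes within a chunk.
    ys = jax.lax.map(f, head)
    # Merge the (n_batches, batch_size) axes back into one leading axis.
    ys = jax.tree_util.tree_map(lambda y: y.reshape(n_full, *y.shape[2:]), ys)

    if remainder > 0:
        # Assumption: leftover elements go through one final, smaller call to f.
        tail = jax.tree_util.tree_map(lambda x: x[n_full:], xs)
        ys = jax.tree_util.tree_map(
            lambda a, b: jnp.concatenate([a, b], axis=0), ys, f(tail)
        )
    return ys


# Usage: f is vmapped by the user, as batched_map requires.
f = jax.vmap(lambda x: jnp.sum(x**2))
xs = jnp.arange(10.0).reshape(5, 2)
ys = batched_map_sketch(f, xs, batch_size=2)  # two chunks of 2, then 1 leftover
# ys matches f(xs), but at most batch_size elements are vectorized at a time.
```

The difference from plain jax.lax.map is only in who applies the vmap: jax.lax.map(..., batch_size=...) takes a per-element function and vectorizes it internally, whereas here the user passes an f that already accepts a batch along the leading axis.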

cryojax.utils.batched_scan(f: Callable[[Carry, PyTree[Shaped[Array, '_ ...'], X]], tuple[Carry, PyTree[Shaped[Array, '_ ...'], Y]]], init: Carry, xs: PyTree[Shaped[Array, '_ ...'], X], length: int | None = None, unroll: int | bool = 1, *, batch_size: int = 1) -> tuple[Carry, PyTree[Shaped[Array, '_ ...'], Y]]

Adds a batch_size argument to jax.lax.scan, in the same spirit as jax.lax.map(..., batch_size=...). Unlike jax.lax.map, however, f(carry, x) is assumed to be already vmapped over the first axis of the arrays of x.

Arguments:

  • f: As jax.lax.scan with format f(carry, x), except vmapped over the first axis of the arrays of x.
  • init: As jax.lax.scan.
  • xs: As jax.lax.scan.
  • length: As jax.lax.scan.
  • unroll: As jax.lax.scan.
  • batch_size: Scan over xs in chunks of batch_size, passing each chunk to the vmapped f.

Returns:

As jax.lax.scan.
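A minimal pure-JAX sketch of the same idea for scans follows. The name batched_scan_sketch is hypothetical and this is not cryojax's implementation; it omits the length and unroll arguments and, as an assumption, processes any leftover elements (when the leading axis is not divisible by batch_size) in one final call to f. The example step function vectorizes per-element work over the batch while the carry accumulates a running total across batches.

```python
import jax
import jax.numpy as jnp


def batched_scan_sketch(f, init, xs, *, batch_size=1):
    """Hypothetical stand-in for batched_scan, for illustration only.

    Assumes `f(carry, x)` is already vmapped over the leading axis of the
    arrays in `x`. The `length` and `unroll` arguments are omitted here.
    """
    # Length of the leading (scanned) axis, read from the first array in xs.
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    n_batches, remainder = divmod(n, batch_size)
    n_full = n_batches * batch_size

    # Chunk xs to (n_batches, batch_size, ...) so each scan step sees a batch.
    head = jax.tree_util.tree_map(
        lambda x: x[:n_full].reshape(n_batches, batch_size, *x.shape[1:]), xs
    )
    carry, ys = jax.lax.scan(f, init, head)
    # Merge the (n_batches, batch_size) axes back into one leading axis.
    ys = jax.tree_util.tree_map(lambda y: y.reshape(n_full, *y.shape[2:]), ys)

    if remainder > 0:
        # Assumption: the leftover chunk is processed in one final call to f.
        tail = jax.tree_util.tree_map(lambda x: x[n_full:], xs)
        carry, ys_tail = f(carry, tail)
        ys = jax.tree_util.tree_map(
            lambda a, b: jnp.concatenate([a, b], axis=0), ys, ys_tail
        )
    return carry, ys


def step(carry, x_batch):
    # Per-element work, vectorized over the leading (batch) axis.
    y_batch = jax.vmap(lambda x: x**2)(x_batch)
    # The carry accumulates a running total across batches.
    return carry + jnp.sum(y_batch), y_batch


xs = jnp.arange(7.0)
total, ys = batched_scan_sketch(step, jnp.zeros(()), xs, batch_size=3)
# total is the sum of squares of 0..6; ys holds the per-element squares.
```

Because the carry only changes once per chunk, this pattern suits reductions (for example, accumulating a sum over many images) where elements within a chunk are independent of each other.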