
Helper functions for Equinox filtering

To make full use of JAX, it is highly recommended to learn about equinox. cryoJAX implements its models as pytrees using equinox.Module, so they can be operated on like any other pytree in JAX (e.g. with jax.tree.map). Complementary to the equinox.Module interface, equinox introduces the idea of filtering: separating pytree leaves into different groups, which is a core task in JAX (e.g. separating traced and static arguments to jax.jit). This grouping is done with the functions equinox.partition and equinox.combine. This documentation describes utilities in cryojax for working with these two functions.

cryojax.utils.get_filter_spec(pytree: PyTree, where: Callable[[PyTree], Union[Any, Sequence[Any]]], *, inverse: bool = False, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree[bool]

A lightweight wrapper around equinox for creating a "filter specification".

A filter specification, or filter_spec, is a pytree whose leaves are either True or False. These are commonly used with equinox filtering.

In cryojax, it is common to need fine-grained control over which leaves a JAX transformation is taken with respect to. This is done with a where function: a function that takes the pytree and returns the leaf, or sequence of leaves, to select. See the equinox documentation for an example.

Returns:

The filter specification. This is a pytree with the same structure as pytree, with True at the leaves the where function points to and False everywhere else (or the opposite if inverse=True).