Helper functions for Equinox filtering
To make full use of JAX, it is highly recommended to learn about equinox. cryoJAX implements its models as pytrees using equinox.Module classes, so they can be operated on like any other pytree in JAX (e.g. with jax.tree.map). Complementary to the equinox.Module interface, equinox introduces the idea of filtering, which separates pytree leaves into different groups. This is a core task in using JAX (e.g. separating traced and static arguments to jax.jit). In particular, this grouping is achieved with the functions equinox.partition and equinox.combine. This documentation describes utilities in cryojax for working with equinox.partition and equinox.combine.
cryojax.utils.get_filter_spec(pytree: PyTree, where: Callable[[PyTree], Union[Any, Sequence[Any]]], *, inverse: bool = False, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree[bool]
A lightweight wrapper around equinox for creating a "filter specification".
A filter specification, or filter_spec, is a pytree whose leaves are either True or False. These are commonly used with equinox filtering.
In cryojax, it is often necessary to specify precisely which leaves JAX transformations should be taken with respect to. This is done with a function that points to individual leaves, referred to as a where function. See the equinox documentation for an example.
Returns:
The filter specification. This is a pytree of the same structure as pytree, with True at the leaves the where function points to and False elsewhere (or the opposite, if inverse = True).