JAX transformations
This tutorial demonstrates how to get started with JAX transformations in cryoJAX.
# Plotting imports and function definitions
import math
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
def plot_image_stack(images, cmap="gray", **kwargs):
    # Plot a square grid of images that share one color scale
    n_images_per_side = int(math.sqrt(images.shape[0]))
    # squeeze=False keeps `axes` a 2D array even for a 1x1 grid
    fig, axes = plt.subplots(
        nrows=n_images_per_side, ncols=n_images_per_side, squeeze=False
    )
    vmin, vmax = images.min(), images.max()
    for idx, ax in enumerate(axes.ravel()):
        im = ax.imshow(
            images[idx], cmap=cmap, vmin=vmin, vmax=vmax, origin="lower", **kwargs
        )
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        fig.colorbar(im, cax=cax)
    fig.tight_layout()
# Start by creating an image model
import cryojax.simulator as cxs
from cryojax.io import read_array_from_mrc
# Scattering potential stored in MRC format
filename = "./data/groel_5w0s_scattering_potential.mrc"
# ... read into a FourierVoxelGridPotential
real_voxel_grid, voxel_size = read_array_from_mrc(filename, loads_spacing=True)
potential = cxs.FourierVoxelGridPotential.from_real_voxel_grid(
real_voxel_grid, voxel_size, pad_scale=2
)
# Now, instantiate the pose. Angles are given in degrees
pose = cxs.EulerAnglePose(
offset_x_in_angstroms=5.0,
offset_y_in_angstroms=-3.0,
phi_angle=20.0,
theta_angle=80.0,
psi_angle=-5.0,
)
# First, the contrast transfer theory
ctf = cxs.AberratedAstigmaticCTF(
defocus_in_angstroms=10000.0,
astigmatism_in_angstroms=-100.0,
astigmatism_angle=10.0,
)
transfer_theory = cxs.ContrastTransferTheory(ctf, amplitude_contrast_ratio=0.1)
# Then the configuration. Add padding with respect to the final image shape.
pad_options = dict(shape=potential.shape[0:2])
config = cxs.BasicConfig(
shape=(80, 80),
pixel_size=potential.voxel_size,
voltage_in_kilovolts=300.0,
pad_options=pad_options,
)
# Make the image model. By default, cryoJAX simulates the contrast
# in physical units; here, we normalize the image instead.
image_model = cxs.make_image_model(
potential,
config,
pose,
transfer_theory,
normalizes_signal=True,
)
In this tutorial, we will show how to perform transformations using jax.jit, jax.grad, and jax.vmap. To do this, we will leverage filtered transformations in equinox. Start by defining a function that simulates images.
import equinox as eqx
@eqx.filter_jit
def simulate_fn(image_model):
    return image_model.simulate()
What's with the eqx.filter_jit?
This is an example of a JAX transformation for JIT compilation (i.e. jax.jit). If you aren't familiar with jax.jit, then start by reading the JAX documentation.
In particular, eqx.filter_jit is a filtered transformation from the package equinox. In summary, it is a lightweight wrapper around jax.jit that treats all of the image_model's arrays as traced at compile time and all of its other leaves (Python numbers, strings, functions, and so on) as static. If you aren't familiar with static arguments in JAX, see the JAX documentation for jax.jit.
Filtered transformations are a key feature of equinox, and it is highly recommended to learn about them; the equinox documentation gives an introduction. Note that these transformations are completely optional: the plain jax.jit decorator can also be used if you are willing to write a few more lines of code.
# Compare speed with and without JIT
print("Simulate an image without JIT")
%timeit image_model.simulate().block_until_ready()
print("Simulate an image with JIT")
%timeit simulate_fn(image_model).block_until_ready()
Indeed, there is a large speedup!
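The speedup relies on the compiled code being cached: tracing and compilation happen on the first call for a given input signature (shapes and dtypes), and later calls reuse the result. A minimal sketch with plain jax.jit and a hypothetical function toy_simulate: a Python side effect inside the function only runs while JAX traces it, so counting it shows when (re)compilation occurs.

```python
import jax
import jax.numpy as jnp

n_traces = 0  # counts how many times JAX traces (and then compiles) the function


@jax.jit
def toy_simulate(x):
    global n_traces
    n_traces += 1  # Python side effect: runs while tracing, not on cached calls
    return jnp.sin(x) ** 2


toy_simulate(jnp.ones(4)).block_until_ready()  # first call: trace and compile
toy_simulate(jnp.zeros(4)).block_until_ready()  # same shape and dtype: cached
print(n_traces)  # 1
toy_simulate(jnp.ones(5)).block_until_ready()  # new shape: trace and compile again
print(n_traces)  # 2
```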
Next, let's compute our first gradient. If you haven't yet read about filtered transformations in equinox, stop and do this now. In particular, read about equinox.partition and equinox.combine.
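Naively, we could define a function that computes a gradient in the same way. Since eqx.filter_grad requires a scalar output, the simulated image must first be reduced to a scalar:
import equinox as eqx
import jax.numpy as jnp
@eqx.filter_grad
def gradient_fn(image_model):
    return jnp.sum(image_model.simulate() ** 2)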
The issue is that we typically want to differentiate with respect to only particular parameters, whereas this function computes the gradient with respect to all floating-point JAX arrays in the image_model. To solve this, we need more advanced features of equinox.
Here, we demonstrate taking the gradient with respect to the pose using a simple L2 loss.
import equinox as eqx
import jax.numpy as jnp
from cryojax.jax_util import get_filter_spec
# For illustration, the "observed" image is simulated from the same model, so the
# loss is at its minimum and the gradients computed below are approximately zero
observed_image = simulate_fn(image_model)
# Split the image model into gradient and non-gradient parameters using
# `eqx.partition` and the cryoJAX `get_filter_spec` utility
where_pose = lambda model: model.structure.pose
filter_spec = get_filter_spec(image_model, where_pose)
model_grad, model_nograd = eqx.partition(image_model, filter_spec)
# Define the gradient function, whose first argument holds the parameters to
# differentiate. Use `eqx.combine` after crossing the gradient boundary
@eqx.filter_jit
@eqx.filter_grad
def gradient_fn(model_grad, model_nograd, observed_image):
    image_model = eqx.combine(model_grad, model_nograd)
    return jnp.sum((simulate_fn(image_model) - observed_image) ** 2)
# Compute and extract the gradients. We took the gradient with respect to
# a cryoJAX `EulerAnglePose` object
gradients = gradient_fn(model_grad, model_nograd, observed_image)
eqx.tree_pprint(where_pose(gradients), short_arrays=False)
What do eqx.partition and eqx.combine do?
JAX transformations typically require grouping pytree leaves into different categories: jax.jit distinguishes between traced and static arguments, jax.vmap between mapped and unmapped arguments, and of course jax.grad between differentiated and non-differentiated arguments. equinox makes grouping leaves more elegant with the concept of filtering. In this example, we use a cryojax utility to create what is called a filter_spec, which is a pytree of booleans with the same structure as image_model. Its values are True at the leaves specified by where_pose and False otherwise. Then, eqx.partition is used to split the pytree into differentiated and non-differentiated arguments. Finally, after crossing the jax.grad boundary, eqx.combine is called to recombine the subtrees into a functional image_model. It is recommended to learn more about filtering in the equinox documentation.
With this basic recipe, we can also compute batches of images with jax.vmap. Here, we simulate images with different random poses.
import jax.random as jr
from cryojax.rotations import SO3
# Define function that generates a partitioned `image_model`
@eqx.filter_jit
@eqx.filter_vmap(
in_axes=(eqx.if_array(0), None, None, None), out_axes=(eqx.if_array(0), None)
)
def make_batched_model(rng_key, potential, config, transfer_theory):
    rotation = SO3.sample_uniform(rng_key)
    pose = cxs.EulerAnglePose.from_rotation(rotation)
    image_model = cxs.make_image_model(
        potential, config, pose, transfer_theory, normalizes_signal=True
    )
    where_pose = lambda model: model.structure.pose
    filter_spec = get_filter_spec(image_model, where_pose)
    model_vmap, model_novmap = eqx.partition(image_model, filter_spec)
    return model_vmap, model_novmap
# Define function that simulates an image from a partitioned `image_model`
@eqx.filter_jit
@eqx.filter_vmap(in_axes=(eqx.if_array(0), None))
def simulate_batch_fn(model_vmap, model_novmap):
    image_model = eqx.combine(model_vmap, model_novmap)
    return image_model.simulate()
# Generate `image_model` for different RNG keys and simulate a batch of images
n_poses = 9
rng_keys = jr.split(jr.key(seed=0), n_poses)
model_vmap, model_novmap = make_batched_model(
    rng_keys, potential, config, transfer_theory
)
images = simulate_batch_fn(model_vmap, model_novmap)
plot_image_stack(images)
Here, we use eqx.filter_vmap for both model instantiation and image simulation.
Why do we set out_axes=(eqx.if_array(0), None)?
When we create a pytree with eqx.filter_vmap (or jax.vmap), out_axes should be a prefix of the output pytree. Setting out_axes to None at a particular leaf says that we do not want to broadcast that leaf along the batch axis (of course, this only works for unmapped leaves). By default, jax.vmap sets out_axes=0, so all unmapped leaves get broadcast. In cryojax, we typically only want to batch the leaves that are directly mapped over (here, the pose), so we set out_axes to None for all other leaves.
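This behavior can be checked with plain jax.vmap on a hypothetical function, scale_and_pass_through, whose second output does not depend on the mapped input. With the default out_axes=0 that output is broadcast; with out_axes=(0, None) it is returned as is. Setting out_axes to None for the first, mapped output would instead raise an error.

```python
import jax
import jax.numpy as jnp


def scale_and_pass_through(x, c):
    # `x` is mapped over; `c` is an unmapped constant that is simply returned
    return 2.0 * x, c


xs = jnp.arange(3.0)
c = jnp.asarray(5.0)

# Default out_axes=0: the unmapped output `c` is broadcast along the batch axis
ys, cs = jax.vmap(scale_and_pass_through, in_axes=(0, None))(xs, c)
print(cs.shape)  # (3,)

# out_axes=(0, None): the unmapped output is returned without broadcasting
ys, c_out = jax.vmap(
    scale_and_pass_through, in_axes=(0, None), out_axes=(0, None)
)(xs, c)
print(c_out.shape)  # ()
```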