Simulate a batch of images
In this tutorial, we will simulate a naive model of a micrograph. In particular, we will simulate a batch of images of the same particle at random poses, then sum over them.
The goal of this tutorial is to learn how to vmap using cryojax's recommended pattern, which relies on lightweight cryojax wrappers around equinox.
# Jax imports
import jax
import jax.numpy as jnp
import numpy as np
# Plotting imports and functions
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
def plot_image(image, fig, ax, cmap="gray", label=None, **kwargs):
    im = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax)
    if label is not None:
        ax.set(title=label)
    return fig, ax
# CryoJAX imports
from jaxtyping import install_import_hook
with install_import_hook("cryojax", "typeguard.typechecked"):
    import cryojax.simulator as cxs
    from cryojax.io import read_array_with_spacing_from_mrc
    from cryojax.rotations import SO3
First, we will build the image formation modeling components that we do not want to vmap over.
# First, load the scattering potential and projection method
filename = "./data/groel_5w0s_scattering_potential.mrc"
real_voxel_grid, voxel_size = read_array_with_spacing_from_mrc(filename)
potential = cxs.FourierVoxelGridPotential.from_real_voxel_grid(
    real_voxel_grid, voxel_size, pad_scale=2
)
# ... now the projection method
potential_integrator = cxs.FourierSliceExtraction(interpolation_order=1)
# ... and the contrast transfer theory
transfer_theory = cxs.ContrastTransferTheory(
    ctf=cxs.ContrastTransferFunction(
        defocus_in_angstroms=10000.0,
        astigmatism_in_angstroms=0.0,
    )
)
# ... finally, the instrument_config
shape = (400, 600)
pixel_size = potential.voxel_size  # Angstroms
voltage_in_kilovolts = 300.0
instrument_config = cxs.InstrumentConfig(
    shape, pixel_size, voltage_in_kilovolts, pad_scale=1.1
)
image_size = np.asarray(shape) * pixel_size
Now we will construct a ContrastImagingPipeline by batching over a set of random number generator keys.
from functools import partial
import equinox as eqx
import equinox.internal as eqxi
from jaxtyping import PRNGKeyArray, PyTree
@partial(eqx.filter_vmap, in_axes=(0, None), out_axes=eqxi.if_mapped(axis=0))
def make_imaging_pipeline(
    key: PRNGKeyArray, no_vmap: tuple[PyTree, ...]
) -> cxs.ContrastImagingPipeline:
    config, potential, potential_integrator = no_vmap
    # ... split the key so the rotation and translation are drawn independently
    rotation_key, translation_key = jax.random.split(key)
    # ... instantiate rotations
    rotation = SO3.sample_uniform(rotation_key)
    # ... now in-plane translation
    ny, nx = config.shape
    in_plane_offset_in_angstroms = (
        jax.random.uniform(translation_key, (2,), minval=-0.45, maxval=0.45)
        * jnp.asarray((nx, ny))
        * config.pixel_size
    )
    # ... convert 2D in-plane translation to 3D, setting the out-of-plane translation to
    # zero
    offset_in_angstroms = jnp.pad(in_plane_offset_in_angstroms, ((0, 1),))
    # ... build the pose
    pose = cxs.QuaternionPose.from_rotation_and_translation(rotation, offset_in_angstroms)
    # ... build the ensemble
    structural_ensemble = cxs.SingleStructureEnsemble(potential, pose)
    # ... and finally the scattering theory (transfer_theory is closed over) and return
    theory = cxs.WeakPhaseScatteringTheory(
        structural_ensemble, potential_integrator, transfer_theory
    )
    return cxs.ContrastImagingPipeline(config, theory)
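A note on randomness, as a minimal JAX sketch independent of cryojax: JAX's random functions are deterministic functions of their key, so two draws made from the same key are not independent. The snippet below (all variable names are illustrative) shows that repeating a call with one key returns identical values, whereas splitting the key first with jax.random.split gives independent subkeys. This is why it is safer to split key before drawing both the rotation and the in-plane translation in a function like make_imaging_pipeline.

```python
import jax

demo_key = jax.random.PRNGKey(0)

# Two draws with the same key and the same call are identical
same_1 = jax.random.uniform(demo_key, (2,))
same_2 = jax.random.uniform(demo_key, (2,))

# Splitting first gives independent subkeys, hence independent draws
subkey_a, subkey_b = jax.random.split(demo_key)
draw_a = jax.random.uniform(subkey_a, (2,))
draw_b = jax.random.uniform(subkey_b, (2,))
```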
What's with the out_axes=eqxi.if_mapped(axis=0)?
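When we create a pytree with eqx.filter_vmap (or jax.vmap), out_axes should have the same structure as the output pytree (or be a prefix of it), and it says where the batch dimension of each output leaf should go. If out_axes is set to None at a particular leaf, this says that we do not want to broadcast that leaf, i.e. it is returned without a batch dimension (of course, this only works for unmapped leaves). By default, jax.vmap sets out_axes=0, so all unmapped leaves get broadcast along a new leading axis. equinox allows us to pass out_axes=eqxi.if_mapped(axis=0) (from equinox.internal), which specifies not to broadcast pytree leaves unless the leaves are directly mapped. In make_imaging_pipeline, this means the pose, which is computed from key, gets a batch dimension, while the potential and the other components are stored only once.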
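The broadcasting behavior described above can be checked with plain jax.vmap. In this minimal sketch, the function double_first and its inputs are illustrative and not part of cryojax: x is mapped over, while y is passed through unmapped.

```python
import jax
import jax.numpy as jnp


def double_first(x, y):
    # x is mapped over; y is passed through untouched
    return 2.0 * x, y


xs = jnp.arange(4.0)
y = jnp.ones(3)

# Default out_axes=0: the unmapped output y is broadcast along a new leading axis
out_x, out_y = jax.vmap(double_first, in_axes=(0, None))(xs, y)

# out_axes=(0, None): y is returned without a batch dimension
out_x2, out_y2 = jax.vmap(double_first, in_axes=(0, None), out_axes=(0, None))(xs, y)
```

With the default out_axes=0, out_y has shape (4, 3), i.e. four copies of y; with out_axes=(0, None) it keeps its original shape (3,).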
# Generate RNG keys
number_of_poses = 20
keys = jax.random.split(jax.random.PRNGKey(12345), number_of_poses)
# ... instantiate the imaging_pipeline
imaging_pipeline = make_imaging_pipeline(
    keys, (instrument_config, potential, potential_integrator)
)
This may be a little odd at first. We have constructed an imaging_pipeline whose render method would not work if we called it directly. Think of it this way: because we created our imaging_pipeline with a vmap, its pose leaves carry a leading batch dimension, so its methods can only be called after crossing a vmap boundary that strips that dimension off again. There is very good reason for this! To learn more, read the section of the equinox documentation on model ensembling.
Now that we have a ContrastImagingPipeline with a batched set of poses, we need some way of telling our vmap exactly which pytree leaves have batch dimensions. One way equinox does this is with a function that points to particular pytree leaves (here, a lambda selecting the pose), which is used to create what is called a filter_spec.
import cryojax as cx
# ... specify which leaves we would like to vmap over
where = lambda p: p.scattering_theory.structural_ensemble.pose
# ... use a cryojax wrapper to return a filter_spec
filter_spec = cx.get_filter_spec(imaging_pipeline, where)
Here, filter_spec is a pytree of booleans with the same structure as imaging_pipeline. Its values are True at the leaves that we do want to vmap over and False where we don't. Filtered transformations are a cornerstone of equinox, and it is highly recommended to learn about them; see the equinox documentation on filtering for further reading.
Above we have used a cryojax utility routine for creating a filter_spec, called cryojax.get_filter_spec. Next, we will finally define functions to batch and sum over images! To do this, we will again use a cryojax wrapper around equinox, called cryojax.filter_vmap_with_spec. This batches over a pytree only at the leaves specified by filter_spec.
import equinox as eqx
@partial(cx.filter_vmap_with_spec, filter_spec=filter_spec)
def compute_image_stack(imaging_pipeline):
    """Compute a batch of images at different poses,
    specified by the `filter_spec`.
    """
    return imaging_pipeline.render()
@eqx.filter_jit
def compute_micrograph(imaging_pipeline):
    """Sum together the image stack."""
    return jnp.sum(compute_image_stack(imaging_pipeline), axis=0)
# Compute the image and plot
fig, ax = plt.subplots(figsize=(5.5, 5.5))
micrograph = compute_micrograph(imaging_pipeline)
plot_image(
    micrograph,
    fig,
    ax,
    label="Image contrast for a sum of random poses",
    interpolation=None,
)
What next?
It is highly recommended to learn about pytree manipulation in equinox. In particular, read about eqx.partition and eqx.combine. cryojax has made the choice to hide away some of these details with its utility routines, but this is only meant to ease users into using equinox for themselves.
This thread on the equinox GitHub may also be useful: https://github.com/patrick-kidger/equinox/issues/618.