Simulate a batch of images

In this tutorial, we will simulate a naive model of a micrograph. In particular, we will simulate a batch of images of the same particle at random poses, then sum over them.

The goal of this tutorial is to learn cryojax's recommended pattern for using vmap, which relies on cryojax's lightweight wrappers around equinox.

# Jax imports
import jax
import jax.numpy as jnp
import numpy as np
# Plotting imports and functions
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


def plot_image(image, fig, ax, cmap="gray", label=None, **kwargs):
    im = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax)
    if label is not None:
        ax.set(title=label)
    return fig, ax


# CryoJAX imports
from jaxtyping import install_import_hook


with install_import_hook("cryojax", "typeguard.typechecked"):
    import cryojax.simulator as cxs
    from cryojax.io import read_array_with_spacing_from_mrc
    from cryojax.rotations import SO3

First, we will build the image formation modeling components that we do not want to vmap over.

# First, load the scattering potential
filename = "./data/groel_5w0s_scattering_potential.mrc"
real_voxel_grid, voxel_size = read_array_with_spacing_from_mrc(filename)
potential = cxs.FourierVoxelGridPotential.from_real_voxel_grid(
    real_voxel_grid, voxel_size, pad_scale=2
)
# ... now the projection method
potential_integrator = cxs.FourierSliceExtraction(interpolation_order=1)
# ... and the contrast transfer theory
transfer_theory = cxs.ContrastTransferTheory(
    ctf=cxs.ContrastTransferFunction(
        defocus_in_angstroms=10000.0,
        astigmatism_in_angstroms=0.0,
    )
)
# ... finally, the instrument_config
shape = (400, 600)
pixel_size = potential.voxel_size  # Angstroms
voltage_in_kilovolts = 300.0
instrument_config = cxs.InstrumentConfig(
    shape, pixel_size, voltage_in_kilovolts, pad_scale=1.1
)
image_size = np.asarray(shape) * pixel_size

Now we will construct a ContrastImagingPipeline by batching over a set of random number generator keys.

from functools import partial

import equinox as eqx
import equinox.internal as eqxi
from jaxtyping import PRNGKeyArray, PyTree


@partial(eqx.filter_vmap, in_axes=(0, None), out_axes=eqxi.if_mapped(axis=0))
def make_imaging_pipeline(
    key: PRNGKeyArray, no_vmap: tuple[PyTree, ...]
) -> cxs.ContrastImagingPipeline:
    config, potential, potential_integrator = no_vmap
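    # ... split the key so that the rotation and translation are drawn independently
    rotation_key, translation_key = jax.random.split(key)
    # ... instantiate rotations
    rotation = SO3.sample_uniform(rotation_key)
    # ... now in-plane translation
    ny, nx = config.shape
    in_plane_offset_in_angstroms = (
        jax.random.uniform(translation_key, (2,), minval=-0.45, maxval=0.45)
        * jnp.asarray((nx, ny))
        * config.pixel_size
    )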
    # ... convert 2D in-plane translation to 3D, setting the out-of-plane translation to
    # zero
    offset_in_angstroms = jnp.pad(in_plane_offset_in_angstroms, ((0, 1),))
    # ... build the pose
    pose = cxs.QuaternionPose.from_rotation_and_translation(rotation, offset_in_angstroms)
    # ... build the ensemble
    structural_ensemble = cxs.SingleStructureEnsemble(potential, pose)
    # ... and finally the scattering theory and return
    theory = cxs.WeakPhaseScatteringTheory(
        structural_ensemble, potential_integrator, transfer_theory
    )
    return cxs.ContrastImagingPipeline(config, theory)

What's with the out_axes=eqxi.if_mapped(axis=0)?

When we create a pytree with eqx.filter_vmap (or jax.vmap), out_axes should have the same structure as the output pytree (or be a prefix of it). Setting out_axes to None at a particular leaf says that we do not want to broadcast that leaf along a new batch axis (of course, this only works for unmapped leaves). By default, jax.vmap sets out_axes=0, so all unmapped leaves get broadcast. equinox allows us to pass out_axes=eqxi.if_mapped(axis=0), which adds a batch axis at position 0 only to leaves that are actually mapped and leaves every other leaf unbroadcast. Here, this means that only the pose is batched, while unmapped leaves such as the potential's voxel grid are stored once rather than copied for every pose.
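To see what out_axes controls, here is a minimal sketch using plain jax.vmap. The function f and its inputs are purely illustrative: the second output does not depend on the mapped argument, so by default it is broadcast to the batch size, whereas out_axes=None returns it unbatched.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # x is mapped over; y is passed through unchanged (unmapped)
    return 2.0 * x, y


xs = jnp.arange(3.0)  # batch of 3 scalars
y = jnp.ones(4)  # shared, unmapped input

# Default out_axes=0: the unmapped output y is broadcast to shape (3, 4)
doubled, y_broadcast = jax.vmap(f, in_axes=(0, None))(xs, y)
print(y_broadcast.shape)  # (3, 4)

# out_axes=(0, None): the unmapped output keeps its original shape (4,)
doubled, y_unbatched = jax.vmap(f, in_axes=(0, None), out_axes=(0, None))(xs, y)
print(y_unbatched.shape)  # (4,)
```

eqxi.if_mapped(axis=0) effectively makes this choice leaf by leaf: 0 for leaves that are mapped, None for the rest.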

# Generate RNG keys
number_of_poses = 20
keys = jax.random.split(jax.random.PRNGKey(12345), number_of_poses)

# ... instantiate the imaging_pipeline
imaging_pipeline = make_imaging_pipeline(
    keys, (instrument_config, potential, potential_integrator)
)

This may be a little odd at first. We have constructed an imaging_pipeline that would not work if we called its render method directly. Think of it this way: because we created imaging_pipeline with a vmap, its methods can only be called from inside a vmap that maps over its batched leaves. There is a very good reason for this: storing the batch as one pytree whose leaves carry a leading batch dimension lets a single vmapped call render every pose at once, rather than looping over a Python list of pipelines. To learn more, read the section of the equinox documentation on model ensembling.

Now that we have a ContrastImagingPipeline with a batched set of poses, we need some way of telling our vmap exactly which pytree leaves have batch dimensions. One way to do this in equinox is to write a function (here called where) that points to particular pytree leaves, and use it to create what is called a filter_spec.

import cryojax as cx


# ... specify which leaves we would like to vmap over
where = lambda p: p.scattering_theory.structural_ensemble.pose
# ... use a cryojax wrapper to return a filter_spec
filter_spec = cx.get_filter_spec(imaging_pipeline, where)

Here, filter_spec is a pytree of booleans with the same structure as imaging_pipeline. Its values are True at the leaves we want to vmap over and False everywhere else. Filtered transformations are a cornerstone of equinox, and it is highly recommended to learn about them; see the equinox documentation on filtering for more.

Above, we used a cryojax utility routine, cryojax.get_filter_spec, to create filter_spec. Next, we will finally define functions to batch and sum over images! To do this, we will use another cryojax wrapper around equinox, filter_vmap_with_spec, which vmaps over a pytree only at the leaves marked True in filter_spec.

import equinox as eqx


@partial(cx.filter_vmap_with_spec, filter_spec=filter_spec)
def compute_image_stack(imaging_pipeline):
    """Compute a batch of images at different poses,
    specified by the `filter_spec`.
    """
    return imaging_pipeline.render()


@eqx.filter_jit
def compute_micrograph(imaging_pipeline):
    """Sum together the image stack."""
    return jnp.sum(compute_image_stack(imaging_pipeline), axis=0)


# Compute the image and plot
fig, ax = plt.subplots(figsize=(5.5, 5.5))
micrograph = compute_micrograph(imaging_pipeline)
plot_image(
    micrograph,
    fig,
    ax,
    label="Image contrast for a sum of random poses",
    interpolation="none",
)
[Figure: image contrast for a sum of random poses]

What next?

It is highly recommended to learn about pytree manipulation in equinox, in particular eqx.partition and eqx.combine. cryojax hides some of these details behind its utility routines, but these are only meant to ease users into using equinox directly.

This thread on the equinox github may also be useful: https://github.com/patrick-kidger/equinox/issues/618.