Welcome to cryoJAX!

CryoJAX is a library that simulates cryo-electron microscopy (cryo-EM) images in JAX. Its purpose is to provide tools for building downstream data analysis in external workflows and libraries that leverage the statistical inference and machine learning resources of the JAX scientific computing ecosystem. To achieve this, image simulation in cryoJAX is built for reliability and flexibility: it implements a variety of established models and algorithms, as well as a framework for implementing new models and algorithms downstream. If your application needs cryo-EM image simulation that cannot be built downstream, please open a pull request.

This documentation is currently a work in progress. Your patience while we get this project properly documented is much appreciated! Feel free to get in touch through GitHub issues if you have any questions, bug reports, or feature requests.

Installation

Installing cryojax is simple. To start, we recommend creating a new virtual environment, for example with conda:

conda create -n cryojax-env -c conda-forge python=3.11
conda activate cryojax-env

Note that Python 3.10 or later is required. After creating a new environment, install JAX with either CPU or GPU support, following the JAX installation instructions. Then install cryojax. For the latest stable release, install using pip:

python -m pip install cryojax

To install the latest commit, you can build the repository directly.

git clone https://github.com/mjo22/cryojax
cd cryojax
python -m pip install .

The jax-finufft package is an optional dependency that provides non-uniform fast Fourier transforms (NUFFTs), which cryojax offers as one option for computing image projections. If you plan to use it, we recommend first following the jax_finufft installation instructions and then installing cryojax.
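After installing, a quick way to confirm that JAX sees the intended hardware, and whether the optional jax-finufft dependency is available, is to query both from Python. This is a small sketch rather than a cryoJAX utility; it assumes that jax-finufft installs under the module name `jax_finufft`, and `report_environment` is a hypothetical helper name.

```python
import importlib.util

import jax


def report_environment():
    """Hypothetical helper: summarize the JAX backend and which optional packages are importable."""
    info = {
        "jax_version": jax.__version__,
        "backend": jax.default_backend(),  # e.g. "cpu" or "gpu"
        "devices": [str(device) for device in jax.devices()],
    }
    # find_spec returns None (without importing anything) when a top-level package is absent
    for package in ("cryojax", "jax_finufft"):
        info[package] = importlib.util.find_spec(package) is not None
    return info


for key, value in report_environment().items():
    print(f"{key}: {value}")
```

If `backend` reports `cpu` on a machine with a GPU, the JAX install most likely lacks GPU support.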

Simulating an image

The following is a basic workflow to simulate an image.

import jax
import jax.numpy as jnp
import cryojax.simulator as cxs
from cryojax.io import read_array_with_spacing_from_mrc

# Instantiate the scattering potential from a voxel grid. See the documentation
# for how to generate voxel grids from a PDB file.
filename = "example_scattering_potential.mrc"
real_voxel_grid, voxel_size = read_array_with_spacing_from_mrc(filename)
potential = cxs.FourierVoxelGridPotential.from_real_voxel_grid(real_voxel_grid, voxel_size)
# The pose. Angles are given in degrees.
pose = cxs.EulerAnglePose(
    offset_x_in_angstroms=5.0,
    offset_y_in_angstroms=-3.0,
    phi_angle=20.0,
    theta_angle=80.0,
    psi_angle=-10.0,
)
# The model for the CTF
ctf = cxs.CTF(
    defocus_in_angstroms=9800.0, astigmatism_in_angstroms=200.0, astigmatism_angle=10.0
)
transfer_theory = cxs.ContrastTransferTheory(ctf, amplitude_contrast_ratio=0.1)
# The image configuration
config = cxs.BasicConfig(shape=(320, 320), pixel_size=voxel_size, voltage_in_kilovolts=300.0)
# Instantiate a cryoJAX `image_model` using the `make_image_model` function
image_model = cxs.make_image_model(potential, config, pose, transfer_theory)
# Simulate an image
image = image_model.simulate(outputs_real_space=True)
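For intuition about what the CTF parameters above control, here is a standalone sketch of a textbook weak-phase CTF written directly in JAX. It is not cryoJAX's implementation. The conventions are assumptions made for illustration: defocus is positive for underfocus, astigmatism splits the two principal defoci by plus or minus half its value, the spherical aberration is a typical 2.7 mm, the overall sign follows one common convention, and the 1 Å pixel size is hypothetical.

```python
import jax.numpy as jnp


def electron_wavelength_in_angstroms(voltage_in_kilovolts):
    """Relativistically corrected electron wavelength (standard formula)."""
    voltage_in_volts = voltage_in_kilovolts * 1000.0
    return 12.2643 / jnp.sqrt(voltage_in_volts * (1.0 + 0.978466e-6 * voltage_in_volts))


def textbook_ctf(
    shape,
    pixel_size,
    voltage_in_kilovolts,
    defocus_in_angstroms,
    astigmatism_in_angstroms,
    astigmatism_angle,  # degrees
    amplitude_contrast_ratio,
    spherical_aberration_in_mm=2.7,  # assumed typical value, not taken from the text
):
    """Hypothetical sketch of a CTF on an unshifted FFT frequency grid."""
    wavelength = electron_wavelength_in_angstroms(voltage_in_kilovolts)
    cs = spherical_aberration_in_mm * 1e7  # mm -> Angstroms
    # Spatial frequencies in cycles per Angstrom, laid out like jnp.fft.fft2 output
    ky = jnp.fft.fftfreq(shape[0], d=pixel_size)
    kx = jnp.fft.fftfreq(shape[1], d=pixel_size)
    kx, ky = jnp.meshgrid(kx, ky, indexing="xy")  # both have shape `shape`
    k_sq = kx**2 + ky**2
    azimuth = jnp.arctan2(ky, kx)
    # Astigmatic defocus varies with the azimuthal angle of each frequency
    defocus = defocus_in_angstroms + 0.5 * astigmatism_in_angstroms * jnp.cos(
        2.0 * (azimuth - jnp.deg2rad(astigmatism_angle))
    )
    # Aberration phase: defocus term minus spherical aberration term
    chi = jnp.pi * wavelength * defocus * k_sq - 0.5 * jnp.pi * cs * wavelength**3 * k_sq**2
    ac = amplitude_contrast_ratio
    # Mix phase contrast (sin) and amplitude contrast (cos)
    return -(jnp.sqrt(1.0 - ac**2) * jnp.sin(chi) + ac * jnp.cos(chi))


ctf = textbook_ctf(
    (320, 320),
    pixel_size=1.0,
    voltage_in_kilovolts=300.0,
    defocus_in_angstroms=9800.0,
    astigmatism_in_angstroms=200.0,
    astigmatism_angle=10.0,
    amplitude_contrast_ratio=0.1,
)
```

At zero frequency the phase term vanishes, so this sketch returns minus the amplitude contrast ratio there; at 300 kV the wavelength is roughly 0.0197 Å.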

For more advanced image simulation examples and to understand the many features in this library, see the documentation.

JAX transformations

CryoJAX is built on JAX to make use of JIT-compilation, automatic differentiation, and vectorization for cryo-EM data analysis. JAX implements these operations as function transformations. If you aren't familiar with this concept, see the JAX documentation.
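As a minimal illustration of "function transformations" using only core JAX (no cryoJAX), each of `jax.grad`, `jax.jit`, and `jax.vmap` takes a function and returns a new function, and they compose freely. The toy `loss` function here is purely illustrative.

```python
import jax
import jax.numpy as jnp


def loss(x):
    """A toy scalar function of a vector."""
    return jnp.sum(x**2)


grad_loss = jax.grad(loss)           # new function: x -> d(loss)/dx = 2x
fast_grad_loss = jax.jit(grad_loss)  # same function, compiled with XLA
batched_loss = jax.vmap(loss)        # maps `loss` over a leading batch axis

x = jnp.array([1.0, 2.0, 3.0])
print(fast_grad_loss(x))               # [2. 4. 6.]
print(batched_loss(jnp.ones((4, 3))))  # [3. 3. 3. 3.]
```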

Below are examples of implementing these transformations using equinox, a popular JAX library for PyTorch-like classes that integrate smoothly with JAX functional programming. To learn more about how equinox assists with JAX transformations, see the equinox documentation.

Your first JIT-compiled function

import equinox as eqx

# Define image simulation function using `equinox.filter_jit`
@eqx.filter_jit
def simulate_fn(image_model):
    """Simulate an image with JIT compilation"""
    return image_model.simulate()

# Simulate an image
image = simulate_fn(image_model)

Computing gradients of a loss function

import equinox as eqx
import jax.numpy as jnp
from cryojax.jax_util import get_filter_spec

# Load observed data
observed_image = ...

# Split the `image_model` into differentiated and non-differentiated
# parts
where_pose = lambda model: model.structure.pose
filter_spec = get_filter_spec(image_model, where_pose)
model_grad, model_nograd = eqx.partition(image_model, filter_spec)

@eqx.filter_jit
@eqx.filter_grad
def gradient_fn(model_grad, model_nograd, observed_image):
    """Compute gradients with respect to parameters specified by
    a `where` function.
    """
    image_model = eqx.combine(model_grad, model_nograd)
    return jnp.sum((image_model.simulate() - observed_image)**2)

# Compute gradients
gradients = gradient_fn(model_grad, model_nograd, observed_image)

Vectorizing image simulation

import equinox as eqx
import cryojax.simulator as cxs
from cryojax.jax_util import get_filter_spec

# Vectorize model instantiation
@eqx.filter_jit
@eqx.filter_vmap(in_axes=(0, None, None, None), out_axes=(eqx.if_array(0), None))
def make_image_model_vmap(wxyz, potential, config, transfer_theory):
    pose = cxs.QuaternionPose(wxyz=wxyz)
    image_model = cxs.make_image_model(
        potential, config, pose, transfer_theory, normalizes_signal=True
    )
    where_pose = lambda model: model.structure.pose
    filter_spec = get_filter_spec(image_model, where_pose)
    model_vmap, model_novmap = eqx.partition(image_model, filter_spec)

    return model_vmap, model_novmap


# Define image simulation function
@eqx.filter_jit
@eqx.filter_vmap(in_axes=(eqx.if_array(0), None))
def simulate_fn_vmap(model_vmap, model_novmap):
    image_model = eqx.combine(model_vmap, model_novmap)
    return image_model.simulate()

# Simulate batch of images
wxyz = ...  # ... load quaternions
model_vmap, model_novmap = make_image_model_vmap(wxyz, potential, config, transfer_theory)
images = simulate_fn_vmap(model_vmap, model_novmap)

Acknowledgements

  • cryojax's implementations of several models and algorithms, such as the CTF, Fourier slice extraction, and electrostatic potential computations, have been informed by the open-source cryo-EM software cisTEM.
  • cryojax is built using equinox, a popular JAX library for PyTorch-like classes that integrate smoothly with JAX functional programming. We highly recommend learning about equinox to make full use of the power of JAX.