Read a particle stack

This tutorial demonstrates how to read in a particle stack in cryojax. The particle stack read here is given in RELION STAR file format.

After reading the particle stack, the tutorial shows how to compute a power spectrum using cryojax.

# JAX imports
import equinox as eqx
import jax.numpy as jnp
# Plotting imports and functions
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


def plot_image(image, fig, ax, cmap="gray", label=None, **kwargs):
    im = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax)
    if label is not None:
        ax.set(title=label)
    return fig, ax


# CryoJAX imports
from cryojax.data import RelionParticleParameterDataset, RelionParticleStackDataset

First, we will read in the RELION particle stack using the cryojax particle stack loader RelionParticleStackDataset.

What is a RelionParticleStackDataset?

CryoJAX implements an abstraction of a RELION dataset, called a RelionParticleStackDataset. This object is built from a RelionParticleParameterDataset, which reads the RELION STAR file for a particle stack. Upon accessing an image in the particle stack, a ParticleStack is returned. Specifically, the ParticleStack stores the image(s) in the image stack, as well as the parameters in the STAR file, here represented as a RelionParticleParameters container. The metadata is used to instantiate compatible cryojax objects. For example, the RelionParticleParameters stores cryojax models for the contrast transfer function (the CTF class) and the pose (the EulerAnglePose class).

More generally, a RelionParticleStackDataset is an AbstractDataset. This abstract interface is part of the cryojax public API!
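
To make the indexing pattern concrete, here is a minimal, purely illustrative sketch of a dataset whose __getitem__ returns a container of images and parameters. ToyParticleStack and ToyParticleStackDataset are hypothetical names invented for this example; they are not cryojax classes and say nothing about how cryojax actually reads STAR files.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyParticleStack:
    """Hypothetical container pairing images with per-particle parameters."""

    images: np.ndarray
    defocus_in_angstroms: np.ndarray


class ToyParticleStackDataset:
    """Hypothetical dataset: indexing returns a ToyParticleStack."""

    def __init__(self, images, defocus_in_angstroms):
        self._images = np.asarray(images)
        self._defocus = np.asarray(defocus_in_angstroms)

    def __len__(self):
        return self._images.shape[0]

    def __getitem__(self, index):
        # An integer index returns a single particle; a slice returns a
        # batch whose arrays keep a leading dimension.
        return ToyParticleStack(self._images[index], self._defocus[index])


toy_dataset = ToyParticleStackDataset(
    images=np.zeros((5, 100, 100), dtype=np.float32),
    defocus_in_angstroms=np.full(5, 10000.0, dtype=np.float32),
)
print(toy_dataset[0].images.shape)  # (100, 100)
print(toy_dataset[0:3].images.shape)  # (3, 100, 100)
```

The point of the pattern is that images and their parameters always travel together, so downstream code never has to re-align them by hand.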

# Read in the dataset and plot an image
param_dataset = RelionParticleParameterDataset(
    path_to_starfile="./data/ribosome_4ug0_particles.star",
    path_to_relion_project="./",
    loads_metadata=False,  # Set to `True` to load the raw starfile metadata
    broadcasts_optics_group=True,  # If `True`, this is useful for vmapping
    loads_envelope=False,  # If `False`, assumes b_factor is 0 and CTF amplitude is 1
)
particle_dataset = RelionParticleStackDataset(param_dataset)
# ... get the zeroth entry in the STAR file
relion_particle = particle_dataset[0]

Upon inspecting the zeroth element of the RelionParticleStackDataset, we see that a ParticleStack is returned. We can access a particular image by accessing the ParticleStack.images. Let's normalize this image, plot it, and make sure the mean and standard deviation are zero and one, respectively.

# Get an image, normalize, and plot it
from cryojax.image import normalize_image


fig, ax = plt.subplots(figsize=(4, 4))
observed_image = normalize_image(relion_particle.images)
plot_image(observed_image, fig, ax)
print(
    "Image mean:",
    observed_image.mean(),
    "Image standard deviation:",
    observed_image.std(),
)
Image mean: -1.373291e-08 Image standard deviation: 1.0000002
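
The printed values are what one would expect from a standard zero-mean, unit-variance rescaling. The sketch below shows that operation in plain NumPy on a synthetic image. It assumes normalize_image behaves like this simple rescaling, which is consistent with the output above but is not taken from the cryojax source.

```python
import numpy as np

# Synthetic stand-in for a particle image (not real data)
rng = np.random.default_rng(0)
raw_image = rng.normal(loc=5.0, scale=3.0, size=(100, 100)).astype(np.float32)

# Subtract the mean, then divide by the standard deviation
normalized = (raw_image - raw_image.mean()) / raw_image.std()

print("mean:", normalized.mean(), "std:", normalized.std())
```

As in the output above, the mean is only zero up to float32 round-off.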


This particular image happens to be simulated with cryojax from the structure of the human 80S ribosome (PDB ID 4ug0). In order to simulate the image, a scattering potential was computed and written to a voxel grid using the cisTEM simulation tool.

We can also use slices or fancy indexing to access multiple particles at once.

# Access multiple images in the stack
relion_particles = particle_dataset[0:3]
print(relion_particles.images.shape)
(3, 100, 100)

Now, we see that the images attribute has a leading dimension for each image. We can also inspect the parameters from the STAR file by printing the CTF.

# Inspect the CTF
eqx.tree_pprint(relion_particles.parameters.transfer_theory, short_arrays=False)
ContrastTransferTheory(
  ctf=AberratedAstigmaticCTF(
    defocus_in_angstroms=Array([10025.484, 10025.484, 10025.484], dtype=float32),
    astigmatism_in_angstroms=Array([50.970703, 50.970703, 50.970703], dtype=float32),
    astigmatism_angle=Array([-54.58706, -54.58706, -54.58706], dtype=float32),
    spherical_aberration_in_mm=Array([2.7, 2.7, 2.7], dtype=float32)
  ),
  envelope=None,
  amplitude_contrast_ratio=Array([0.1, 0.1, 0.1], dtype=float32),
  phase_shift=Array([0., 0., 0.], dtype=float32)
)

Notice that all attributes of the CTF have a leading dimension. For those familiar with the RELION STAR format, even the CTF parameters that are not stored on a per-particle basis (those in the optics group) have a leading dimension! This is the effect of broadcasts_optics_group=True, and it is convenient when working with jax.vmap transformations.
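
The sketch below illustrates why a shared leading dimension is convenient for jax.vmap. It uses the defocus and astigmatism values printed above, while major_axis_defocus is a hypothetical toy function written for this example (its formula is illustrative only and is not a cryojax convention).

```python
import jax
import jax.numpy as jnp

# Per-particle parameters with a leading dimension, as in the printed CTF
defocus = jnp.asarray([10025.484, 10025.484, 10025.484])
astigmatism = jnp.asarray([50.970703, 50.970703, 50.970703])


def major_axis_defocus(defocus_in_angstroms, astigmatism_in_angstroms):
    # Toy per-particle function that operates on scalars
    return defocus_in_angstroms + astigmatism_in_angstroms / 2


# Because every argument has the same leading axis, the default in_axes=0
# maps over all of them without any special handling.
batched = jax.vmap(major_axis_defocus)(defocus, astigmatism)
print(batched.shape)  # (3,)
```

If the shared optics-group values were scalars instead, each call would need an explicit in_axes such as (0, None) to mark which arguments are batched.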

# Plot multiple images from the particle stack
fig, axes = plt.subplots(figsize=(10, 5), ncols=3)
for idx in range(3):
    plot_image(relion_particles.images[idx], fig, axes[idx])
plt.tight_layout()

Computing the power spectrum of an image is a common analysis tool in cryo-EM. This can be done in cryojax!

First, we compute the image in Fourier space and a grid of wave vector magnitudes.

from cryojax.image import rfftn


# Get the particle
relion_particle = particle_dataset[0]
# ... and the image in Fourier space
fourier_image = rfftn(relion_particle.images)
# ... and the Cartesian coordinate system
pixel_size = relion_particle.parameters.instrument_config.pixel_size
print(relion_particle.images.shape)
frequency_grid_in_angstroms = (
    relion_particle.parameters.instrument_config.frequency_grid_in_angstroms
)
# ... now, compute a radial coordinate system
radial_frequency_grid_in_angstroms = jnp.linalg.norm(frequency_grid_in_angstroms, axis=-1)
# ... plot the image in Fourier space and the radial frequency grid
fig, axes = plt.subplots(figsize=(5, 4), ncols=2)
plot_image(
    jnp.log(jnp.abs(jnp.fft.fftshift(fourier_image, axes=(0,))) ** 2),
    fig,
    axes[0],
    label="Image log power spectrum",
)
plot_image(
    jnp.fft.fftshift(radial_frequency_grid_in_angstroms * pixel_size, axes=(0,)),
    fig,
    axes[1],
    label="Radial frequency grid",
)
plt.tight_layout()
(100, 100)
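
For readers who want to see what these two ingredients look like without cryojax, here is a rough NumPy equivalent on a synthetic 100 x 100 image with a 4 Å pixel size. It is a sketch under the assumption of a standard real-input FFT layout; cryojax's rfftn and frequency_grid_in_angstroms may use different normalization or axis conventions.

```python
import numpy as np

n, pixel_size = 100, 4.0  # image size in pixels, pixel size in angstroms
rng = np.random.default_rng(0)
image = rng.normal(size=(n, n))  # synthetic stand-in for a particle image

# Real-input FFT: only half of the last axis is stored
fourier = np.fft.rfftn(image)  # shape (100, 51)

# Spatial frequencies in inverse angstroms along each axis
ky = np.fft.fftfreq(n, d=pixel_size)  # full axis, positive and negative
kx = np.fft.rfftfreq(n, d=pixel_size)  # half axis, non-negative only
frequency_grid = np.stack(np.meshgrid(ky, kx, indexing="ij"), axis=-1)

# Magnitude of the wave vector at every Fourier pixel
radial_grid = np.linalg.norm(frequency_grid, axis=-1)
print(fourier.shape, frequency_grid.shape, radial_grid.shape)
```

Since only the first axis holds both positive and negative frequencies, an fftshift over that axis alone is enough to center the zero frequency for plotting.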


We are now ready to compute and plot the radially averaged power spectrum profile! This simply bins the squared Fourier amplitudes according to radial_frequency_grid_in_angstroms.

from cryojax.image import compute_radially_averaged_powerspectrum


fig, ax = plt.subplots(figsize=(4, 4))
n_pixels = relion_particle.parameters.instrument_config.n_pixels
spectrum, frequencies = compute_radially_averaged_powerspectrum(
    fourier_image,
    radial_frequency_grid_in_angstroms,
    pixel_size,
    maximum_frequency=1 / (2 * pixel_size),
)
ax.plot(frequencies, spectrum / n_pixels, color="k")
ax.set(
    xlabel=r"frequency magnitude $[\AA^{-1}]$",
    ylabel="radially averaged power spectrum",
    yscale="log",
)
[Text(0.5, 0, 'frequency magnitude $[\\AA^{-1}]$'),
 Text(0, 0.5, 'radially averaged power spectrum'),
 None]
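
The binning step can be sketched in plain NumPy as follows. This is a simplified illustration on synthetic noise with a hypothetical choice of 25 equal-width bins. It is not the implementation of compute_radially_averaged_powerspectrum, whose binning and normalization may differ.

```python
import numpy as np

n, pixel_size = 100, 4.0  # image size in pixels, pixel size in angstroms
rng = np.random.default_rng(0)
image = rng.normal(size=(n, n))  # synthetic stand-in for a particle image

# Squared Fourier amplitudes and the radial frequency of each Fourier pixel
power = np.abs(np.fft.rfftn(image)) ** 2
ky = np.fft.fftfreq(n, d=pixel_size)
kx = np.fft.rfftfreq(n, d=pixel_size)
radial_grid = np.hypot(ky[:, None], kx[None, :])

# Equal-width radial bins from zero up to the Nyquist frequency
n_bins = 25  # hypothetical choice for this sketch
nyquist = 1 / (2 * pixel_size)
bin_edges = np.linspace(0.0, nyquist, n_bins + 1)

# Assign every Fourier pixel to a bin and drop those at or beyond Nyquist
bin_index = np.digitize(radial_grid.ravel(), bin_edges) - 1
in_range = (bin_index >= 0) & (bin_index < n_bins)

# Average the power within each bin
sums = np.bincount(
    bin_index[in_range], weights=power.ravel()[in_range], minlength=n_bins
)
counts = np.bincount(bin_index[in_range], minlength=n_bins)
spectrum = sums / np.maximum(counts, 1)
frequencies = 0.5 * (bin_edges[:-1] + bin_edges[1:])  # bin centers
```

For white noise the resulting profile is roughly flat; for a real micrograph, the CTF modulates it into the familiar oscillations.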

Here, we see that the Thon rings in our power spectrum are faint since the pixel size of our images is 4 Å, which limits the spectrum to a Nyquist frequency of 1 / (2 × 4 Å) = 0.125 Å⁻¹.
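
As a rough back-of-the-envelope check, one can estimate how many CTF zeros fit below the Nyquist frequency. The sketch below uses the 4 Å pixel size and the defocus printed earlier. It assumes a 300 kV microscope (electron wavelength of about 0.0197 Å), which this tutorial does not state, and it ignores spherical aberration and amplitude contrast, so the numbers are only indicative.

```python
import numpy as np

pixel_size = 4.0  # angstroms
defocus = 10025.484  # angstroms, from the CTF printed earlier
wavelength = 0.0197  # angstroms; ASSUMPTION: 300 kV, not given in the tutorial

nyquist = 1 / (2 * pixel_size)

# With defocus as the only aberration, the CTF behaves like
# sin(pi * wavelength * defocus * k**2), whose zeros lie at
# k_n = sqrt(n / (wavelength * defocus)) for n = 1, 2, ...
n_max = int(np.floor(nyquist**2 * wavelength * defocus))
zeros = np.sqrt(np.arange(1, n_max + 1) / (wavelength * defocus))

print(nyquist)  # 0.125
print(n_max)  # 3
print(zeros)
```

Under these assumptions only about three zeros fall below Nyquist, so few oscillations are available to form visible rings.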