Read a particle stack
This tutorial demonstrates how to read a particle stack in cryojax. The particle stack read here is stored in RELION STAR file format.
After reading the particle stack, we demonstrate how to compute a power spectrum using cryojax.
# JAX imports
import equinox as eqx
import jax.numpy as jnp
# Plotting imports and functions
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
def plot_image(image, fig, ax, cmap="gray", label=None, **kwargs):
    im = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax)
    if label is not None:
        ax.set(title=label)
    return fig, ax
# CryoJAX imports
from cryojax.data import RelionParticleParameterDataset, RelionParticleStackDataset
First, we will read in the RELION particle stack using the cryojax particle stack loader RelionParticleStackDataset.
What is a RelionParticleStackDataset?
CryoJAX implements an abstraction of a RELION dataset, called a RelionParticleStackDataset. This object wraps a RelionParticleParameterDataset, which reads a RELION STAR file for a particle stack. Upon accessing an entry of the dataset, a ParticleStack is returned. Specifically, the ParticleStack stores the image(s) in the particle stack, as well as the parameters from the STAR file, here represented as a RelionParticleParameters container. These parameters are used to instantiate compatible cryojax objects. For example, the RelionParticleParameters stores cryojax models for the contrast transfer function (the CTF class) and the pose (the EulerAnglePose class).
More generally, a RelionParticleStackDataset is an AbstractDataset. This abstract interface is part of the cryojax public API!
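The indexing behavior of such a dataset can be illustrated with a toy stand-in. The sketch below is hypothetical: ToyParticleStack and ToyParticleStackDataset are not cryojax classes, they hold NumPy arrays instead of cryojax objects, and they keep a single made-up CTF-like field (defocus_in_angstroms). It only assumes what this tutorial shows later: an integer index returns one particle, and a slice returns a stack whose fields all share a leading dimension.

```python
from dataclasses import dataclass
from typing import Union

import numpy as np


@dataclass
class ToyParticleStack:
    """Hypothetical stand-in for a ParticleStack: images plus per-image parameters."""

    images: np.ndarray  # shape (H, W) for one particle, (N, H, W) for a slice
    defocus_in_angstroms: np.ndarray  # shape () for one particle, (N,) for a slice


class ToyParticleStackDataset:
    """Hypothetical dataset mimicking the indexing behavior described above.

    This is NOT the cryojax implementation; it only shows that integer
    indexing returns one particle and slicing returns a stack with a
    leading dimension on every field.
    """

    def __init__(self, images: np.ndarray, defocus_in_angstroms: np.ndarray):
        self._images = images
        self._defocus = defocus_in_angstroms

    def __len__(self) -> int:
        return self._images.shape[0]

    def __getitem__(self, index: Union[int, slice]) -> ToyParticleStack:
        # Indexing every field with the same index keeps images and
        # parameters aligned, whether `index` is an int or a slice.
        return ToyParticleStack(
            images=self._images[index],
            defocus_in_angstroms=self._defocus[index],
        )


rng = np.random.default_rng(0)
dataset = ToyParticleStackDataset(
    images=rng.normal(size=(5, 32, 32)),
    defocus_in_angstroms=np.linspace(10000.0, 14000.0, 5),
)
single = dataset[0]
several = dataset[0:3]
print(single.images.shape, several.images.shape)  # (32, 32) (3, 32, 32)
```

Keeping images and parameters behind one `__getitem__` is what lets a single index retrieve a consistent particle, whether it is one image or a batch.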
# Read in the dataset and plot an image
param_dataset = RelionParticleParameterDataset(
path_to_starfile="./data/ribosome_4ug0_particles.star",
path_to_relion_project="./",
loads_metadata=False, # Set to `True` to load the raw starfile metadata
broadcasts_optics_group=True, # If `True`, optics-group parameters get a leading dimension (useful for vmapping)
loads_envelope=False, # If `False`, assumes b_factor is 0 and CTF amplitude is 1
)
particle_dataset = RelionParticleStackDataset(param_dataset)
# ... get the zeroth entry in the STAR file
relion_particle = particle_dataset[0]
Upon inspecting the zeroth element of the RelionParticleStackDataset, we see that a ParticleStack is returned. We can access its image through the ParticleStack.images attribute. Let's normalize this image, plot it, and check that its mean and standard deviation are zero and one, respectively.
# Get an image, normalize, and plot it
from cryojax.image import normalize_image
fig, ax = plt.subplots(figsize=(4, 4))
observed_image = normalize_image(relion_particle.images)
plot_image(observed_image, fig, ax)
print(
"Image mean:",
observed_image.mean(),
"Image standard deviation:",
observed_image.std(),
)
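The printed statistics correspond to standard z-score normalization: subtract the mean and divide by the standard deviation. A NumPy sketch of that operation follows; normalize_image_sketch is a hypothetical name, and cryojax's normalize_image may differ in details not covered in this tutorial.

```python
import numpy as np


def normalize_image_sketch(image: np.ndarray) -> np.ndarray:
    """Hypothetical z-score normalization: zero mean, unit standard deviation."""
    image = np.asarray(image, dtype=float)
    return (image - image.mean()) / image.std()


rng = np.random.default_rng(1)
raw = 5.0 + 3.0 * rng.normal(size=(64, 64))  # synthetic image with mean ~5, std ~3
normalized = normalize_image_sketch(raw)
print(normalized.mean(), normalized.std())  # approximately 0 and 1
```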
This particular image was simulated with cryojax from the structure of the human 80S ribosome (PDB ID 4UG0). To simulate the image, a scattering potential was computed and written to a voxel grid using the cisTEM simulation tool.
We can also index the dataset with a slice to access multiple particles at once.
# Access multiple images in the stack
relion_particles = particle_dataset[0:3]
print(relion_particles.images.shape)
Now the images attribute has a leading dimension that indexes each particle. We can also inspect the parameters from the STAR file by printing the CTF.
# Inspect the CTF
eqx.tree_pprint(relion_particles.parameters.transfer_theory, short_arrays=False)
Notice that all attributes of the CTF have a leading dimension. For those familiar with the RELION STAR format, even the CTF parameters that are not stored on a per-particle basis (those in the optics group) have a leading dimension! This is because we set broadcasts_optics_group=True, which is convenient when working with jax.vmap transformations.
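Why a shared optics-group value benefits from a leading dimension can be sketched with NumPy broadcasting. The example below is hypothetical: it pairs per-particle defoci with one shared electron wavelength (about 0.0197 Angstroms at 300 kV) and evaluates only the defocus term of the CTF phase, pi * wavelength * defocus * |k|^2, at a single frequency as a toy per-particle function; sign conventions vary, and it is not cryojax's CTF code.

```python
import numpy as np

# Per-particle parameter: one defocus per image (Angstroms).
defocus_in_angstroms = np.array([10000.0, 12000.0, 14000.0])
# Optics-group parameter: one wavelength shared by all images (Angstroms).
wavelength_in_angstroms = 0.0197  # approximately the electron wavelength at 300 kV

# Broadcasting gives the shared value the same leading dimension as the
# per-particle values, so every field can be mapped over axis 0 together.
n_particles = defocus_in_angstroms.shape[0]
wavelength_batched = np.broadcast_to(wavelength_in_angstroms, (n_particles,))


def defocus_phase(defocus, wavelength, k=0.1):
    """Toy per-particle function: defocus term of the CTF phase at |k| (1/Angstrom)."""
    return np.pi * wavelength * defocus * k**2


# Map over axis 0 with a Python loop; jax.vmap(defocus_phase) maps over the same axis.
phases = np.array(
    [defocus_phase(d, w) for d, w in zip(defocus_in_angstroms, wavelength_batched)]
)
print(wavelength_batched.shape, phases.shape)  # (3,) (3,)
```

Because every input now has the same leading dimension, the mapped function never needs to know which parameters came from the optics group and which came from individual particles.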
# Plot multiple images from the particle stack
fig, axes = plt.subplots(figsize=(10, 5), ncols=3)
for idx in range(3):
    plot_image(relion_particles.images[idx], fig, axes[idx])
plt.tight_layout()
Computing the power spectrum of an image is a common analysis in cryo-EM, and it can be done in cryojax!
First, we compute the image in Fourier space and a grid of wave vector magnitudes.
from cryojax.image import rfftn
# Get the particle
relion_particle = particle_dataset[0]
# ... and the image in fourier space
fourier_image = rfftn(relion_particle.images)
# ... and the pixel size and Cartesian frequency grid (in inverse Angstroms)
pixel_size = relion_particle.parameters.instrument_config.pixel_size
print(relion_particle.images.shape)
frequency_grid_in_angstroms = (
relion_particle.parameters.instrument_config.frequency_grid_in_angstroms
)
# ... now, compute a radial coordinate system
radial_frequency_grid_in_angstroms = jnp.linalg.norm(frequency_grid_in_angstroms, axis=-1)
# ... plot the image in fourier space and the radial frequency grid
fig, axes = plt.subplots(figsize=(5, 4), ncols=2)
plot_image(
jnp.log(jnp.abs(jnp.fft.fftshift(fourier_image, axes=(0,))) ** 2),
fig,
axes[0],
label="Image log power spectrum",
)
plot_image(
jnp.fft.fftshift(radial_frequency_grid_in_angstroms * pixel_size, axes=(0,)),
fig,
axes[1],
label="Radial frequency grid",
)
plt.tight_layout()
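For readers who want to see what a frequency grid like this contains, here is a NumPy sketch under stated assumptions: cryojax's rfftn follows the same convention as numpy.fft.rfftn (only the last axis is halved), and frequency_grid_in_angstroms stores (kx, ky) along its last axis in FFT ordering. Both are assumptions about the layout, not taken from the cryojax documentation; the pixel size and random image are placeholder values.

```python
import numpy as np

pixel_size = 2.0  # Angstroms per pixel (placeholder value)
rng = np.random.default_rng(2)
image = rng.normal(size=(64, 64))  # shape (ny, nx)
ny, nx = image.shape

# Real-to-complex FFT: the last axis keeps only the nx // 2 + 1 non-negative frequencies.
fourier_image = np.fft.rfftn(image)

# Frequencies in inverse Angstroms along each axis.
ky = np.fft.fftfreq(ny, d=pixel_size)  # full range, FFT ordering (zero first)
kx = np.fft.rfftfreq(nx, d=pixel_size)  # non-negative only, matching the rfft axis

# Cartesian grid with shape (ny, nx // 2 + 1, 2); the last axis holds (kx, ky).
kx_grid, ky_grid = np.meshgrid(kx, ky, indexing="xy")
frequency_grid_in_angstroms = np.stack([kx_grid, ky_grid], axis=-1)

# Radial frequency magnitude |k| at every Fourier-space pixel.
radial_frequency_grid_in_angstroms = np.linalg.norm(frequency_grid_in_angstroms, axis=-1)

print(fourier_image.shape, radial_frequency_grid_in_angstroms.shape)  # (64, 33) (64, 33)
```

Under this layout, the zero frequency sits in the first row and the halved axis already runs from zero to Nyquist, which is consistent with applying fftshift only along axes=(0,) when plotting.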
We are now ready to compute and plot the radially averaged power spectrum profile! This bins the squared Fourier amplitudes according to radial_frequency_grid_in_angstroms and averages within each bin.
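The binning step can be sketched in NumPy. The function radially_averaged_power_spectrum is hypothetical and uses equal-width rings up to maximum_frequency; cryojax's compute_radially_averaged_powerspectrum may choose its bins differently. White noise is used as input because, for unit-variance white noise, the expected squared Fourier amplitude equals the number of pixels, so dividing by the pixel count (as with spectrum / n_pixels in the next cell, assuming n_pixels is the total number of pixels) gives a flat profile near 1.

```python
import numpy as np


def radially_averaged_power_spectrum(fourier_image, radial_frequency_grid, n_bins, maximum_frequency):
    """Hypothetical sketch: average |F|^2 over rings of constant |k|."""
    power = np.abs(fourier_image) ** 2
    bin_edges = np.linspace(0.0, maximum_frequency, n_bins + 1)
    # Assign each Fourier pixel to a half-open ring [edge_i, edge_{i+1});
    # pixels at or beyond maximum_frequency are dropped.
    bin_index = np.digitize(radial_frequency_grid, bin_edges) - 1
    in_range = (bin_index >= 0) & (bin_index < n_bins)
    sums = np.bincount(bin_index[in_range], weights=power[in_range], minlength=n_bins)
    counts = np.bincount(bin_index[in_range], minlength=n_bins)
    # Average within each ring; empty rings are left at zero.
    spectrum = np.divide(sums, counts, out=np.zeros(n_bins), where=counts > 0)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    return spectrum, bin_centers


rng = np.random.default_rng(3)
n = 128
image = rng.normal(size=(n, n))  # unit-variance white noise
fourier_image = np.fft.rfftn(image)
ky = np.fft.fftfreq(n, d=1.0)  # pixel size of 1, so frequencies are in 1/pixel
kx = np.fft.rfftfreq(n, d=1.0)
radial = np.sqrt(kx[None, :] ** 2 + ky[:, None] ** 2)  # shape (128, 65), matching the rfft
spectrum, frequencies = radially_averaged_power_spectrum(
    fourier_image, radial, n_bins=32, maximum_frequency=0.5
)
print((spectrum / image.size).round(2))  # values scatter around 1 for white noise
```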
from cryojax.image import compute_radially_averaged_powerspectrum
fig, ax = plt.subplots(figsize=(4, 4))
n_pixels = relion_particle.parameters.instrument_config.n_pixels
spectrum, frequencies = compute_radially_averaged_powerspectrum(
fourier_image,
radial_frequency_grid_in_angstroms,
pixel_size,
maximum_frequency=1 / (2 * pixel_size),
)
ax.plot(frequencies, spectrum / n_pixels, color="k")
ax.set(
xlabel=r"frequency magnitude $[\AA^{-1}]$",
ylabel="radially averaged power spectrum",
yscale="log",
)
Here, we see that the Thon rings in our power spectrum are faint since the pixel size of our images is