This tutorial demonstrates how to generate a particle_reader in Relion using `cryojax`. This tutorial builds upon tools already shown in other tutorials, so we might skip over some details.
%load_ext autoreload
%autoreload 2
First we will do all the imports required to run the tutorial
# Jax and Equinox imports
from functools import partial
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import PRNGKeyArray
# Plotting imports and functions
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
def plot_image(image, fig, ax, cmap="gray", label=None, **kwargs):
    """Show `image` on `ax` with a colorbar attached to the right of the axes.

    Args:
        image: Array-like passed to `ax.imshow`.
        fig: Figure that owns `ax`; used to draw the colorbar.
        ax: Axes to draw on.
        cmap: Colormap name forwarded to `imshow`.
        label: Optional title for `ax`; no title is set when `None`.
        **kwargs: Extra keyword arguments forwarded to `ax.imshow`.

    Returns:
        The `(fig, ax)` pair that was passed in.
    """
    mappable = ax.imshow(image, cmap=cmap, origin="lower", **kwargs)
    # Carve a narrow axes out of the right-hand side of `ax` for the colorbar
    colorbar_axes = make_axes_locatable(ax).append_axes(
        "right", size="5%", pad=0.05
    )
    fig.colorbar(mappable, cax=colorbar_axes)
    if label is None:
        return fig, ax
    ax.set(title=label)
    return fig, ax
# CryoJAX imports
import cryojax.simulator as cxs
from cryojax.data import (
RelionParticleDataset,
RelionParticleMetadata,
RelionParticleParameters,
write_simulated_image_stack_from_starfile,
write_starfile_with_particle_parameters,
)
from cryojax.image import operators as op
from cryojax.io import read_atoms_from_pdb
from cryojax.rotations import SO3
Generating a starfile¤
We have split this tutorial in two parts. In the first part we will generate a starfile, but we will not generate any particles. In the second part we will show how to generate particles from a starfile, and how to save such particles to the respective MRC Files.
We have decided to keep these two steps separate for a simple reason: flexibility. We want to make it possible for every user to design their own image formation pipeline. In some of our other tutorials we have shown how to generate noiseless images, images with solvent noise, or images with noise coming from a distribution. You can easily adapt this tutorial to those cases, or any other pipeline you build.
Now we will start by generating a starfile. To do this, we will first create a cryojax `RelionParticleParameters`.
In this function we vmap over jax random keys. You can adapt this function to your needs, such as adapting the range of the distributions for the random parameters, or changing whether a parameter is random or not.
from cryojax.utils import get_filter_spec
@partial(eqx.filter_vmap, in_axes=(0, None), out_axes=(eqx.if_array(0), None))
def make_particle_parameters(
    key: PRNGKeyArray, instrument_config: cxs.InstrumentConfig
) -> tuple[RelionParticleParameters, RelionParticleParameters]:
    """Draw random pose and CTF parameters for a single particle.

    The function is wrapped in `eqx.filter_vmap` with `in_axes=(0, None)`, so
    it is called with a batch of keys (mapped over axis 0) and one shared
    `instrument_config`.

    Args:
        key: PRNG key for this particle. It is split before every random
            draw so that no two parameters share a subkey.
        instrument_config: Provides the image `shape` and `pixel_size` used
            to scale the in-plane offset.

    Returns:
        The two pytrees produced by
        `eqx.partition(relion_particle_parameters, filter_spec)`. Per
        `out_axes`, array leaves of the first are stacked along axis 0 and
        the second is returned unbatched.
        NOTE(review): with `inverse=True` the filter spec presumably selects
        every leaf *except* the four listed in the lambda below, so those
        four end up in the unbatched half -- confirm against
        `get_filter_spec`.
    """
    # Generate random parameters
    # Pose
    # ... instantiate rotations
    key, subkey = jax.random.split(key)  # split the key before every random draw
    rotation = SO3.sample_uniform(subkey)

    # ... now in-plane translation: up to +/-20% of the box extent along each axis
    key, subkey = jax.random.split(key)
    ny, nx = instrument_config.shape
    offset_in_angstroms = (
        jax.random.uniform(subkey, (2,), minval=-0.2, maxval=0.2)
        * jnp.asarray((nx, ny))
        * instrument_config.pixel_size
    )
    # ... build the pose
    pose = cxs.EulerAnglePose.from_rotation_and_translation(rotation, offset_in_angstroms)

    # CTF Parameters
    # ... defocus (needs a fresh subkey: reusing the offset subkey would
    # correlate the defocus with the translation)
    key, subkey = jax.random.split(key)
    defocus_in_angstroms = jax.random.uniform(subkey, (), minval=1000, maxval=1500)

    key, subkey = jax.random.split(key)
    astigmatism_in_angstroms = jax.random.uniform(subkey, (), minval=0, maxval=100)

    key, subkey = jax.random.split(key)
    astigmatism_angle = jax.random.uniform(subkey, (), minval=0, maxval=jnp.pi)

    # minval == maxval == 0, so this is always zero; widen the range to
    # randomize the phase shift
    key, subkey = jax.random.split(key)
    phase_shift = jax.random.uniform(subkey, (), minval=0, maxval=0)
    # no more random numbers needed

    # now generate your non-random values
    spherical_aberration_in_mm = 2.7
    amplitude_contrast_ratio = 0.1
    b_factor = 170.0
    ctf_scale_factor = 1.0

    # ... build the CTF
    transfer_theory = cxs.ContrastTransferTheory(
        ctf=cxs.ContrastTransferFunction(
            defocus_in_angstroms=defocus_in_angstroms,
            astigmatism_in_angstroms=astigmatism_in_angstroms,
            astigmatism_angle=astigmatism_angle,
            spherical_aberration_in_mm=spherical_aberration_in_mm,
            amplitude_contrast_ratio=amplitude_contrast_ratio,
            phase_shift=phase_shift,
        ),
        envelope=op.FourierGaussian(b_factor=b_factor, amplitude=ctf_scale_factor),
    )

    relion_particle_parameters = RelionParticleParameters(
        instrument_config=instrument_config,
        pose=pose,
        transfer_theory=transfer_theory,
    )

    # Separate the leaves shared by all particles from the per-particle ones,
    # so that `out_axes` can leave the shared half unbatched
    filter_spec = get_filter_spec(
        relion_particle_parameters,
        lambda x: (
            x.instrument_config.voltage_in_kilovolts,
            x.instrument_config.pixel_size,
            x.transfer_theory.ctf.amplitude_contrast_ratio,
            x.transfer_theory.ctf.spherical_aberration_in_mm,
        ),
        inverse=True,
    )
    return eqx.partition(relion_particle_parameters, filter_spec)
# Microscope and detector settings shared by every particle
instrument_config = cxs.InstrumentConfig(
    shape=(128, 128),
    pixel_size=1.5,
    voltage_in_kilovolts=300.0,
    pad_scale=1.0,  # no padding
)

# One PRNG key per particle
number_of_images = 100
root_key = jax.random.PRNGKey(0)
keys = jax.random.split(root_key, number_of_images)

# The vmapped builder returns two pytrees: the leaves stacked over particles
# and the leaves that are left unbatched
partitioned_parameters = make_particle_parameters(keys, instrument_config)
particle_parameters_vmap, particle_parameters_novmap = partitioned_parameters
To see the values of any particular parameter, you can do
Lastly, we can easily generate the starfile from the built-in cryojax.data.write_starfile_with_particle_parameters
function.
The `mrc_batch_size` parameter lets you specify how many particles you want per mrcfile.
In this case we will place 50 images per mrcfile, giving us two mrcfiles.
# Merge the two partitioned pytrees back into a single parameter object
particle_parameters = eqx.combine(particle_parameters_vmap, particle_parameters_novmap)

# ... and write the starfile, with 50 particles assigned to each mrcfile
path_to_starfile = "./relion_dataset.star"
write_starfile_with_particle_parameters(
    particle_parameters, path_to_starfile, mrc_batch_size=50
)
Simulating particles based on a starfile and writing them to mrcs¤
Now we will see how to define the functions required for our write_simulated_image_stack_from_starfile
function to work.
First, let's define our potential and let's load the starfile
Simulating noiseless images¤
First, we will generate a stack without noise using one of cryojax
imaging pipelines.
# Load the starfile written above; the simulated mrcs for the noiseless
# stack will be saved under `path_to_mrc_files`
path_to_mrc_files = "./relion_dataset_particles/noiseless"
metadata = RelionParticleMetadata(
    path_to_relion_project=path_to_mrc_files,
    path_to_starfile="./relion_dataset.star",
    get_envelope_function=True,
)
For more information on reading data in RELION, check our Read a particle stack
tutorial.
Now let's define the structure we will use to generate images. In this case, we show how to load a pdb and then turn it into a voxel grid. We also need to define how we will integrate our potential.
# Method used to integrate (project) the potential
potential_integrator = cxs.FourierSliceExtraction(interpolation_order=1)

# Read the atomic model and build an atom-based potential from it
filename = "./data/groel_chainA.pdb"
atom_positions, atom_identities, b_factors = read_atoms_from_pdb(
    filename, assemble=False, get_b_factors=True
)
atomic_potential = cxs.PengAtomicPotential(atom_positions, atom_identities, b_factors)

# Box and voxel size come from the "optics" datablock in the starfile
first_particle_config = metadata[0].instrument_config
box_size = first_particle_config.shape[0]
voxel_size = first_particle_config.pixel_size

# Rasterize the atoms on a cubic real-space grid, then move to the Fourier
# voxel-grid representation
real_voxel_grid = atomic_potential.as_real_voxel_grid(
    shape=(box_size,) * 3, voxel_size=voxel_size
)
potential = cxs.FourierVoxelGridPotential.from_real_voxel_grid(
    real_voxel_grid, voxel_size, pad_scale=2
)
Now we will build a function that generates an imaging pipeline from a relion_particle_metadata
from typing import Any
def build_imaging_pipeline_from_relion_particle_metadata(
    relion_particle_parameters: RelionParticleParameters,
    args: Any,
) -> cxs.ContrastImagingPipeline:
    """Assemble a contrast imaging pipeline for one set of particle parameters.

    Args:
        relion_particle_parameters: Supplies the pose, transfer theory and
            instrument configuration of the particle.
        args: A `(potential, potential_integrator)` pair.

    Returns:
        A `cxs.ContrastImagingPipeline` built on a weak-phase scattering
        theory for a single-structure ensemble.
    """
    potential, potential_integrator = args
    return cxs.ContrastImagingPipeline(
        relion_particle_parameters.instrument_config,
        cxs.WeakPhaseScatteringTheory(
            cxs.SingleStructureEnsemble(potential, relion_particle_parameters.pose),
            potential_integrator,
            relion_particle_parameters.transfer_theory,
        ),
    )
and a function that simply computes images from a relion_particle_metadata. They don't really need to be in separate functions.
def compute_image(
    relion_particle_parameters: RelionParticleParameters,
    args: Any,
):
    """Render the image for one set of particle parameters.

    Args:
        relion_particle_parameters: Parameters of the particle to simulate.
        args: Forwarded unchanged to
            `build_imaging_pipeline_from_relion_particle_metadata`.

    Returns:
        The result of calling `render()` on the imaging pipeline.
    """
    return build_imaging_pipeline_from_relion_particle_metadata(
        relion_particle_parameters, args
    ).render()
# Simulate the first particle of the dataset and preview it
image = compute_image(metadata[0], (potential, potential_integrator))
preview_figure = plt.figure()
preview_axes = plt.gca()
plot_image(image, preview_figure, preview_axes, label="Simulated Image");
Finally, we can use this function to write a particle stack!
# `compute_image` receives this tuple as its `args` parameter
args = (potential, potential_integrator)
write_simulated_image_stack_from_starfile(
    metadata=metadata,
    compute_image=compute_image,
    args=args,
    seed=None,  # no seed is given, and `compute_image` takes no PRNG key
    is_jittable=True,
    batch_size_per_mrc=10,
    compression=None,
    overwrite=True,
)
How does write_simulated_image_stack_from_starfile
work?
The arguments to call this function are:
- `metadata`: needs to be an instance of `RelionParticleMetadata`.
- `compute_image`: function that should take `(RelionParticleParameters, args)` or `(jax.PRNGKey, RelionParticleParameters, args)` as input.
- `args`: the args for the `compute_image` function.
- `is_jittable`: whether or not your function can be jitted with `equinox.filter_jit`. (If you are not sure, run it and set it to False if it does not work.)
- `batch_size_per_mrc`: if your function is jittable we will vmap the computation of the images, which might have a big memory usage. If you get some "Ran out of memory" error, then adjust this batch size. Setting it to "None" will assume no batching needs to be done. Setting it to 1 runs the code in serial mode.
- `seed`: a seed for generating noise (optional).
- `overwrite`: whether or not you want existing files to be overwritten.
- `compression`: mrcfile compression, see the `mrcfile` library docs for more information.
If you add a seed, then compute_image
should take a jax.random.PRNGKey
as its first argument (next example), which is needed to generate a sample of the noise from whatever distribution you define.
Inside our code we will call your compute_image
function to generate an image stack, and then save it as an mrcfile.
For more information on imaging pipelines check the Simulate an image
tutorial.
Simulating images with a noise distribution¤
In this example we will use a cryojax distribution to add the noise, but in practice you can do it however you want.
We need a different dataset, as we will save these images in a different folder
# Load the same starfile again, this time pointing at the folder where the
# mrcs of the noisy stack will be saved
path_to_mrc_files = "./relion_dataset_particles/with_noise"
metadata_for_noise_images = RelionParticleMetadata(
    path_to_relion_project=path_to_mrc_files,
    path_to_starfile="./relion_dataset.star",
    get_envelope_function=True,
)
the other parameters do not change, so there is no need to initialize them again.
We will follow a similar approach as before. First, let's write a function that generates a distribution from a `RelionParticleParameters`:
from cryojax.inference import distributions as dist
def build_distribution_from_relion_particle_parameters(
    relion_particle_parameters: RelionParticleParameters,
    args: Any,
) -> dist.IndependentGaussianFourierModes:
    """Build a Gaussian noise model around the image of one particle.

    Args:
        relion_particle_parameters: Supplies the pose, transfer theory and
            instrument configuration of the particle.
        args: A `(potential, potential_integrator)` pair, forwarded to
            `build_imaging_pipeline_from_relion_particle_metadata`.

    Returns:
        A `dist.IndependentGaussianFourierModes` with a constant variance
        function of 1.0 and the signal scaled by `sqrt(n_pixels)`.
    """
    # Reuse the noiseless pipeline builder so both code paths stay in sync
    imaging_pipeline = build_imaging_pipeline_from_relion_particle_metadata(
        relion_particle_parameters, args
    )
    # Take the pixel count from the particle's own configuration rather than
    # from a notebook-level variable, so the scale always matches the image
    # that is actually rendered
    n_pixels = relion_particle_parameters.instrument_config.n_pixels
    return dist.IndependentGaussianFourierModes(
        imaging_pipeline,
        signal_scale_factor=jnp.sqrt(n_pixels),
        variance_function=op.Constant(1.0),
    )
and a function that generates an image
def compute_image_with_noise(
    key: PRNGKeyArray,
    relion_particle_parameters: RelionParticleParameters,
    args: Any,
):
    """Sample one noisy image for a set of particle parameters.

    Args:
        key: PRNG key passed to the distribution's `sample` method.
        relion_particle_parameters: Parameters of the particle to simulate.
        args: Forwarded unchanged to
            `build_distribution_from_relion_particle_parameters`.

    Returns:
        The result of `distribution.sample(key)`.
    """
    return build_distribution_from_relion_particle_parameters(
        relion_particle_parameters, args
    ).sample(key)
and generate an image from such distribution
# Sample a noisy version of the first particle and preview it
key = jax.random.PRNGKey(0)
first_noisy_particle = metadata_for_noise_images[0]
image_with_noise = compute_image_with_noise(
    key, first_noisy_particle, (potential, potential_integrator)
)
noisy_figure = plt.figure()
noisy_axes = plt.gca()
plot_image(image_with_noise, noisy_figure, noisy_axes, label="Simulated Image with noise");
Lastly, let's generate a stack with noise
# `compute_image_with_noise` receives this tuple as its `args` parameter
args = (potential, potential_integrator)
write_simulated_image_stack_from_starfile(
    metadata_for_noise_images,
    compute_image_with_noise,
    args,
    overwrite=True,
    seed=0,  # a seed is given, so the compute function takes a PRNG key first
)
Now we can load our images using `RelionParticleDataset`s
# Open both simulated stacks through the dataset interface
particle_reader_noiseless = RelionParticleDataset(metadata)
particle_reader_noisy = RelionParticleDataset(metadata_for_noise_images)

# Side-by-side comparison of the first particle of each stack
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for a in ax:
    a.set_axis_off()

noiseless_image = particle_reader_noiseless[0].image_stack
im1 = plot_image(noiseless_image, fig, ax[0], label="image without noise")

noisy_image = particle_reader_noisy[0].image_stack
im2 = plot_image(noisy_image, fig, ax[1], label="image with noise")