Module pyms.py_multislice
Module containing functions for core multislice and PRISM algorithms.
Expand source code
"""Module containing functions for core multislice and PRISM algorithms."""
import matplotlib.pyplot as plt
import numpy as np
import torch
from .Probe import wavev, focused_probe
from .utils.numpy_utils import (
ensure_array,
bandwidth_limit_array,
q_space_array,
fourier_interpolate_2d,
crop,
)
from .utils.torch_utils import (
iscomplex,
# roll_n,
cx_from_numpy,
cx_to_numpy,
complex_matmul,
complex_mul,
fourier_shift_array,
amplitude,
get_device,
ensure_torch_array,
crop_to_bandwidth_limit_torch,
size_of_bandwidth_limited_array,
fourier_interpolate_2d_torch,
crop_window_to_flattened_indices_torch,
crop_window_to_periodic_indices,
crop_torch,
)
def tqdm_handler(showProgress):
    """
    Handle showProgress boolean or string input for the tqdm progress bar.

    Parameters
    ----------
    showProgress : bool or str
        Pass True/False to enable/disable the standard console progress bar,
        or the string 'notebook' (case-insensitive) to get the Jupyter
        notebook progress bar. Any other string enables the standard console
        progress bar, any other object is interpreted by its truthiness.

    Returns
    -------
    tdisable : bool
        Value to pass as the `disable` keyword of the progress bar
    tqdm : class
        The tqdm progress bar class to use
    """
    if isinstance(showProgress, str) and showProgress.lower() == "notebook":
        from tqdm import tqdm_notebook as tqdm

        return False, tqdm

    from tqdm import tqdm

    if isinstance(showProgress, str):
        # Previously any string other than "notebook" left tdisable and tqdm
        # unbound (UnboundLocalError); fall back to the console progress bar.
        return False, tqdm
    # Booleans behave as before; other inputs (e.g. None, 0/1) are
    # interpreted by truthiness instead of raising UnboundLocalError.
    return not bool(showProgress), tqdm
def thickness_to_slices(
    thicknesses, slice_thickness, subslicing=False, subslices=(1.0,)
):
    """
    Convert thickness in Angstroms to number of multislice slices.

    Parameters
    ----------
    thicknesses : float or array_like
        Output thickness(es) in Angstrom. When subslicing is True these are
        assumed to be in ascending order.
    slice_thickness : float
        Thickness in Angstrom of one full slice of the structure
    subslicing : bool, optional
        If False, return for each thickness the number of whole slices
        (rounded up) needed to reach it. If True, return for each thickness
        the indices of the subslices to propagate through since the previous
        thickness in the series.
    subslices : array_like, optional
        Fractional depths at which each slice is subsliced, the last entry
        should be 1.0

    Returns
    -------
    nslices : np.ndarray of int, or list of np.ndarray
        Without subslicing, an integer array with one entry per thickness.
        With subslicing, a list of integer index ranges, one per thickness.
    """
    # np.atleast_1d promotes scalars and 0-d arrays so len() below works
    t = np.atleast_1d(np.asarray(thicknesses))

    if not subslicing:
        # np.int was removed in NumPy 1.24, the builtin int is equivalent
        return np.ceil(t / slice_thickness).astype(int)

    m = len(subslices)
    # Number of whole slices of the structure, counted in subslices
    nslices = (t // slice_thickness).astype(int) * m

    # Fractional depth into the final (partial) slice for each thickness
    remainder = (t % slice_thickness) / slice_thickness
    # Fractional depth at which each subslice starts
    starts = np.concatenate(([0], np.asarray(subslices)[:-1]))
    # Index of the subslice start closest to each remainder (same result as
    # a euclidean cdist between these 1D points, without needing scipy)
    nearest = np.argmin(np.abs(remainder[:, None] - starts[None, :]), axis=1)

    z = [0] + (nslices + nearest).tolist()
    return [np.arange(z[i], z[i + 1]) for i in range(len(z) - 1)]
def make_propagators(
    gridshape,
    gridsize,
    eV,
    subslices=[1.0],
    tilt=[0, 0],
    tilt_units="mrad",
    bandwidth_limit=2 / 3,
):
    """
    Make the Fresnel freespace propagators for a multislice simulation.

    Parameters
    ----------
    gridshape : (2,) array_like
        Pixel dimensions of the 2D grid
    gridsize : (3,) array_like
        Size of the grid in real space (first two dimensions) and thickness of
        the object (third dimension)
    eV : float
        Probe energy in electron volts
    subslices : array_like, optional
        A one dimensional array-like object containing the depths (in fractional
        coordinates) at which the object will be subsliced. The last entry
        should always be 1.0. For example, to slice the object into four equal
        sized slices pass [0.25,0.5,0.75,1.0]
    tilt : array_like, optional
        Allows the user to simulate a (small < 50 mrad) tilt of the specimen,
        by shearing the propagator. Units given by input variable tilt_units.
    tilt_units : string, optional
        Units of specimen tilt, can be 'mrad','pixels' or 'invA'
    bandwidth_limit : float, optional
        Band-width limit passed to bandwidth_limit_array, propagator entries
        beyond this limit are zeroed to prevent wrap-around error

    Returns
    -------
    P : (n,Y,X) np.ndarray
        Fresnel free-space propagators, the first dimension will be of size
        len(`subslices`)
    """
    from .Probe import make_contrast_transfer_function

    # We will use the make_contrast_transfer_function function to generate
    # the propagator, the aperture of this propagator will go out to the maximum
    # possible (the corner of the array) and the function bandwidth_limit_array
    # will provide the band width limiting
    gridmax = np.asarray(gridshape) / np.asarray(gridsize[:2]) / 2
    app = np.hypot(*gridmax)

    # Initialize array. np.complex was removed in NumPy 1.24, the builtin
    # complex maps to the same complex128 dtype.
    prop = np.zeros((len(subslices), *gridshape), dtype=complex)

    for islice, s_ in enumerate(subslices):
        # Propagation distance is the depth difference between this subslice
        # and the previous one (or the top of the slice for the first)
        if islice == 0:
            deltaz = s_ * gridsize[2]
        else:
            deltaz = (s_ - subslices[islice - 1]) * gridsize[2]

        # Calculate propagator
        prop[islice, :, :] = bandwidth_limit_array(
            make_contrast_transfer_function(
                gridshape,
                gridsize[:2],
                eV,
                app,
                df=deltaz,
                app_units="invA",
                optic_axis=tilt,
                tilt_units=tilt_units,
            ),
            limit=bandwidth_limit,
        )

    return prop
def generate_slice_indices(nslices, nsubslices, subslicing=False):
    """
    Generate the slice indices for the multislice routine.

    An explicit sequence (list, tuple, range, ...) or numpy array of slice
    indices is returned unchanged. An integer is expanded into np.arange over
    the number of iterations: `nslices` itself when subslicing is True (it is
    then already counted in subslices), otherwise nslices * nsubslices.
    """
    from collections.abc import Sequence

    if isinstance(nslices, (Sequence, np.ndarray)):
        return nslices

    iteration_count = nslices * nsubslices
    if subslicing:
        iteration_count = nslices
    return np.arange(iteration_count)
def _shifted_transmission_function(T, it, subslice, tiling, r):
    """Return T[it, subslice], randomly shifted by a whole number of tiled unit cells."""
    nopiy, nopix = T.shape[2:4]

    # Note: `and` is required here; the previous `tiling[0] == 1 & tiling[1] == 1`
    # parsed as a chained comparison and was True for e.g. tiling [1, 3].
    if tiling[0] == 1 and tiling[1] == 1:
        return T[it, subslice]

    if nopiy % tiling[0] == 0 and nopix % tiling[1] == 0:
        # Unit cell is an integer number of pixels, shift with torch.roll
        T_ = T[it, subslice]
        if tiling[0] > 1:
            # Shift an integer number of unit cells in y
            T_ = torch.roll(T_, r.randint(0, tiling[0]) * (nopiy // tiling[0]), 0)
        if tiling[1] > 1:
            # Shift an integer number of unit cells in x (a zero shift is
            # allowed, matching the y direction)
            T_ = torch.roll(T_, r.randint(0, tiling[1]) * (nopix // tiling[1]), 1)
        return T_

    # Case of a non-integer pixel shifting of the unit cell
    yshift = r.randint(0, tiling[0]) * (nopiy / tiling[0])
    xshift = r.randint(0, tiling[1]) * (nopix / tiling[1])
    shift = torch.tensor([yshift, xshift])
    # Generate an array to perform Fourier shift of transmission function
    FFT_shift_array = fourier_shift_array(
        [nopiy, nopix], shift, dtype=T.dtype, device=T.device
    )
    # Apply Fourier shift theorem for sub-pixel shift
    return torch.ifft(
        complex_mul(FFT_shift_array, torch.fft(T[it, subslice], signal_ndim=2)),
        signal_ndim=2,
    )


def multislice(
    probes,
    nslices,
    propagators,
    transmission_functions,
    tiling=[1, 1],
    device_type=None,
    seed=None,
    return_numpy=True,
    qspace_in=False,
    qspace_out=False,
    posn=None,
    subslicing=False,
    output_to_bandwidth_limit=True,
    reverse=False,
    transpose=False,
):
    """
    Multislice algorithm for scattering of an electron probe.

    Parameters
    ----------
    probes : (n,Y,X) or (Y,X) complex array_like
        Electron probe wave function(s)
    nslices : int, array_like
        The number of slices (iterations) to perform multislice over. If an
        array-like object is passed it is used directly as the sequence of
        slice indices to iterate over (see generate_slice_indices).
    propagators : (Z,Y,X,2) or (Y,X,2) torch.array
        Fresnel free space operators required for the multislice algorithm
        used to propagate the scattering matrix
    transmission_functions : (Z,nT,Y,X,2)
        The transmission functions describing the electron's interaction
        with the specimen for the multislice algorithm
    tiling : (2,) array_like
        Tiling of a repeat unit cell on simulation grid.
    device_type : torch.device, optional
        torch.device object which will determine which device (CPU or GPU) the
        calculations will run on
    seed : int or array_like, optional
        Seed for the random number generator for frozen phonon configurations.
        None seeds a single generator from OS entropy, an integer seeds a
        single generator for the whole calculation, and an array-like object
        reproducibly seeds every iteration with seed[islice].
    return_numpy : bool, optional
        Calculations are performed on pytorch tensors for speed, however numpy
        arrays are more convenient for processing. This input allows the
        user to control how the output is returned
    qspace_in : bool, optional
        Should be set to True if the input wavefunction is in momentum (q) space
        and False otherwise (this is the default)
    qspace_out : bool, optional
        Should be set to True if the output wavefunction is desired in momentum
        (q) space and False otherwise (this is the default)
    posn : None, optional
        Does nothing, included to match calling signature for STEM function
    subslicing : bool, optional
        Pass subslicing=True to access propagation to sub-slices of the
        unit cell, in this case nslices is taken to be in units of subslices
        to propagate rather than unit cells (i.e. nslices = 3 will propagate
        1.5 unit cells for a subslicing of 2 subslices per unit cell)
    output_to_bandwidth_limit : bool, optional
        Bandwidth-limiting of the arrays is used in multislice to stop
        wrap-around error in reciprocal space, therefore the output of the
        multislice algorithm will be zero beyond some point in reciprocal space
        if this is set to True then these array entries will be cropped out.
        This does have the effect of the output of the function being on a
        different sized grid to the input.
    reverse : bool, optional
        Run inverse multislice (for back propagation of a wavefunction)
    transpose : bool, optional
        Reverse the order of the multislice operations, ie. apply propagator
        first and then transmission function

    Returns
    -------
    psi : (Y,X) or (n,Y,X) complex torch.tensor or np.ndarray
        Exit surface wave functions as a pytorch tensor or numpy array (default)
        depending on whether return_numpy is True or False. If the input `probes`
        is two dimensional then n = 1
    """
    # An array-like seed gives a separate, reproducible generator per slice;
    # otherwise one generator (seeded by the integer, or by OS entropy when
    # seed is None) is used throughout. The global numpy RNG is no longer
    # reseeded as a side effect.
    per_slice_seed = np.ndim(seed) > 0
    if not per_slice_seed:
        r = np.random.RandomState(seed)

    # Initialize device cuda if available, CPU if no cuda is available
    device = get_device(device_type)

    # Complex tensors store real and imaginary components in a trailing
    # dimension of size 2 (pre-1.8 torch.fft/torch.ifft API).
    T = ensure_torch_array(transmission_functions, device=device)
    P = ensure_torch_array(propagators, dtype=T.dtype, device=device)
    psi = ensure_torch_array(probes, dtype=T.dtype, device=device)
    nT, nsubslices = T.shape[:2]

    # Probe needs to start multislice algorithm in real space
    if qspace_in:
        psi = torch.ifft(psi, signal_ndim=2)

    slices = generate_slice_indices(nslices, nsubslices, subslicing)
    nslices_total = len(slices)
    transposed = transpose or reverse

    for i, islice in enumerate(slices):
        if per_slice_seed:
            r = np.random.RandomState(seed[islice])

        subslice = islice % nsubslices

        # Pick random phase grating
        it = r.randint(0, nT)

        # To save memory in the case of equally sliced sample, there is the
        # option of only using one propagator, this statement catches this case.
        P_ = P if P.dim() < 4 else P[subslice]

        # For a tiled unit cell, randomly shift the transmission function to
        # generate "more" pseudo-random transmission functions
        T_ = _shifted_transmission_function(T, it, subslice, tiling, r)

        if transposed:
            # Reverse multislice complex conjugates the transmission and
            # propagation. Both reverse and transpose multislice reverse the
            # order of the transmission and propagation operations; the probe
            # starts and finishes this iteration in real space.
            psi = complex_mul(
                torch.ifft(
                    complex_mul(torch.fft(psi, signal_ndim=2), P_, reverse),
                    signal_ndim=2,
                ),
                T_,
                reverse,
            )
        else:
            # Standard multislice iteration - probe starts in real space and
            # finishes this iteration in reciprocal space
            psi = complex_mul(torch.fft(complex_mul(psi, T_), signal_ndim=2), P_)

        if i == nslices_total - 1:
            # The probe can be cropped to the bandwidth limit, removing array
            # entries in reciprocal space that are zero; the norm option
            # compensates for the change in FFT normalization on the new grid
            lim = 2 / 3 if output_to_bandwidth_limit else 1
            psi = crop_to_bandwidth_limit_torch(
                psi,
                qspace_in=not transposed,
                qspace_out=qspace_out,
                limit=lim,
                norm="conserve_norm",
            )
        elif not transposed:
            # Inverse Fourier transform back to real space for next iteration
            psi = torch.ifft(psi, signal_ndim=2)

    # With no slices the probe is still in real space at this point
    if nslices_total < 1 and qspace_out:
        psi = torch.fft(psi, signal_ndim=2)

    if return_numpy:
        return cx_to_numpy(psi)
    return psi
def STEM_phase_contrast_transfer_function(probe, detector):
    """
    Calculate the STEM phase contrast transfer function.

    For a thin and weakly scattering sample, convolving the object with the
    STEM contrast transfer function gives a good approximation to STEM image
    contrast.

    Parameters
    ----------
    probe : complex, (Y,X) array_like
        The STEM probe in reciprocal space
    detector : real, (Y,X) array_like
        The STEM detector

    Returns
    -------
    PCTF : (Y,X) np.ndarray
        The phase contrast transfer function
    """
    from .utils import convolve

    def reflect(arr):
        # Two successive orthonormal forward FFTs map f(k) -> f(-k)
        return np.fft.fft2(np.fft.fft2(arr, norm="ortho"), norm="ortho")

    probe_intensity = np.sum(np.square(np.abs(probe)))
    weighted_probe = np.conj(probe) * detector

    PCTF = convolve(probe, reflect(weighted_probe)) / probe_intensity
    PCTF -= np.conj(reflect(PCTF))
    return -2 * np.imag(PCTF)
# TODO make detectors binary for use with numpy and pytorch sum routines
def make_detector(gridshape, rsize, eV, betamax, betamin=0, units="mrad"):
    """
    Make a STEM detector with acceptance angle between betamin and betamax.

    Parameters
    ----------
    gridshape : (2,) array_like
        Pixel dimensions of the 2D grid
    rsize : (2,) array_like
        Size of the grid in real space in units of Angstroms
    eV : float
        Probe energy in electron volts (only used when units == 'mrad')
    betamax : float
        Detector outer acceptance semi-angle (exclusive)
    betamin : float, optional
        Detector inner acceptance semi-angle (inclusive)
    units : str, optional
        Units of betamin and betamax, 'mrad' (default) or 'invA'. Any value
        other than 'mrad' is treated as inverse Angstrom.

    Returns
    -------
    D : (Y,X) np.ndarray of int
        A single detector function, 1 where betamin <= |q| < betamax and 0
        elsewhere (shape follows each component of q_space_array output,
        presumably (Y,X)).
    """
    # Get reciprocal space array
    q = q_space_array(gridshape, rsize)

    # If units are mrad convert qspace array from inverse Angstrom to mrad
    if units == "mrad":
        q /= wavev(eV) / 1000

    # Compare squared magnitudes to avoid a square root
    qsq = np.square(q[0]) + np.square(q[1])
    detector = np.logical_and(qsq < betamax ** 2, qsq >= betamin ** 2)

    # Convert logical to integer
    return np.where(detector, 1, 0)
def nyquist_sampling(rsize=None, resolution_limit=None, eV=None, alpha=None):
    """
    Calculate nyquist sampling (typically for minimum sampling of a STEM probe).

    If array size in units of length is passed then return how many probe
    positions are required otherwise just return the sampling. Alternatively
    pass probe accelerating voltage (eV) in electron-volts and probe forming
    aperture (alpha) in mrad and the resolution limit in inverse length will be
    calculated for you.

    Parameters
    ----------
    rsize : float or array_like, optional
        Real space size(s); if given, the number of probe positions needed
        along each dimension is returned instead of the step size
    resolution_limit : float, optional
        Resolution limit in inverse length, used when eV and alpha are None
    eV : float, optional
        Probe energy in electron volts, must be given together with alpha
    alpha : float, optional
        Probe forming aperture in mrad, must be given together with eV

    Returns
    -------
    step_size : float, or nscan : int np.ndarray
        The Nyquist step size, or the number of scan positions if rsize is
        given. None is returned if resolution_limit is passed together with
        eV or alpha (ambiguous input).

    Raises
    ------
    ValueError
        If neither resolution_limit nor both eV and alpha are given.
    """
    if eV is None and alpha is None:
        if resolution_limit is None:
            raise ValueError("Either resolution_limit or both eV and alpha are required")
        step_size = 1 / (4 * resolution_limit)
    elif resolution_limit is None:
        if eV is None or alpha is None:
            raise ValueError("Both eV and alpha are required to compute the sampling")
        step_size = 1 / (4 * wavev(eV) * alpha * 1e-3)
    else:
        # Ambiguous input (resolution_limit together with eV/alpha)
        return None

    if rsize is None:
        return step_size
    # asarray allows list input; np.int was removed in NumPy 1.24
    return np.ceil(np.asarray(rsize) / step_size).astype(int)
def generate_STEM_raster(
    rsize,
    eV,
    alpha,
    tiling=(1, 1),
    ROI=(0.0, 0.0, 1.0, 1.0),
    gridshape=(1, 1),
    invA=False,
):
    """
    Return the probe positions for a nyquist-sampled STEM raster.

    For a real space size rsize return probe positions in units of fraction of
    the array for nyquist sampled STEM raster

    Parameters
    ----------
    rsize : (2,) array_like
        Size of the grid in real space in units of Angstroms
    eV : float
        Probe energy in electron volts
    alpha : float
        Probe forming aperture semi-angle in mrad or inverse Angstrom
        (if invA == True)
    tiling : (2,) array_like
        Tiling of a repeat unit cell on simulation grid, if provided STEM raster
        will only scan a single unit cell.
    ROI : (4,) array_like
        Fraction of the unit cell to be scanned. Should contain [y0,x0,y1,x1]
        where [y0,x0] and [y1,x1] are the bottom left and top right coordinates
        of the region of interest (ROI) expressed as a fraction of the total
        grid (or unit cell).
    gridshape : (2,) array_like, optional
        Grid in whose (pixel) units the raster is generated internally. The
        result is divided by gridshape again, so the returned positions are
        fractions of the array for any gridshape.
    invA : bool
        If True, alpha is taken to be in units of inverse Angstrom, not mrad.
        This also means that the value of eV no longer matters

    Returns
    -------
    probe_posns : (nY,nX,2) np.ndarray
        The probe positions as fractions of the full simulation array.
        NOTE(review): earlier documentation stated pixel units are returned
        when gridshape is the pixel array size, but the final division by
        gridshape always yields fractions; confirm which is intended.
    """
    # Field of view of the region of interest in Angstrom
    FOV = np.asarray([rsize[0] * (ROI[2] - ROI[0]), rsize[1] * (ROI[3] - ROI[1])])

    # Nyquist-sampled number of scan coordinates in each dimension of one
    # unit cell
    if invA:
        nscan = nyquist_sampling(FOV / np.asarray(tiling), resolution_limit=alpha)
    else:
        nscan = nyquist_sampling(FOV / np.asarray(tiling), eV=eV, alpha=alpha)

    def scan_axis(i):
        # Scan coordinates along dimension i (0 = y, 1 = x)
        start = ROI[i] * gridshape[i] / tiling[i]
        stop = ROI[2 + i] * gridshape[i] / tiling[i]
        step = np.diff(ROI[i::2])[0] * gridshape[i] / nscan[i] / tiling[i]
        # Trim to nscan entries, guarding against arange returning an extra
        # point through floating point round-off
        return np.arange(start, stop, step=step)[: nscan[i]] / gridshape[i]

    yy, xx = scan_axis(0), scan_axis(1)
    return np.stack(np.broadcast_arrays(yy[:, None], xx[None, :]), axis=2)
def workout_4DSTEM_datacube_DP_size(FourD_STEM, rsize, gridshape):
    """
    Calculate 4D-STEM datacube diffraction pattern gridsize and resampling function.

    Parameters
    ----------
    FourD_STEM : bool or array_like
        Pass FourD_STEM = True gives 4D STEM output with native simulation grid
        sampling. Alternatively, to save disk space a tuple containing pixel
        size and diffraction space extent of the datacube can be passed in. For
        example ([64,64],[1.2,1.2]) will output diffraction patterns measuring
        64 x 64 pixels and 1.2 x 1.2 inverse Angstroms. A tuple containing only
        the pixel size crops the diffraction patterns to that size.
    rsize : (2,) or (3,) array_like
        Real space size of simulation grid, only the first two (transverse)
        entries are used
    gridshape : (2,) array_like
        Pixel size of gridshape

    Returns
    -------
    gridout : (2,) array_like
        Pixel size of the diffraction pattern output
    resize : function
        A function that takes diffraction patterns from the simulation and
        resamples and crops them to the requested size.
    Ksize : (2,) array_like
        Size of the output diffraction pattern grid in inverse Angstrom
    """
    # Only the transverse dimensions of rsize are relevant; previously a
    # (3,) rsize broke broadcasting against the (2,) gridout.
    rsize2d = np.atleast_1d(rsize)[:2]

    # Check whether a resampling directive has been given
    if isinstance(FourD_STEM, (list, tuple)):
        gridout = FourD_STEM[0]
        if len(FourD_STEM) > 1:
            # Get output grid and diffraction space size of that grid from tuple
            Ksize = FourD_STEM[1]
            # Pixel extent of the requested diffraction space window
            # (np.int was removed in NumPy 1.24)
            diff_pat_crop = np.round(np.asarray(Ksize) * rsize2d).astype(int)

            # Define resampling function to crop and interpolate
            # diffraction patterns
            def resize(array):
                cropped = crop(np.fft.fftshift(array, axes=(-1, -2)), diff_pat_crop)
                return fourier_interpolate_2d(cropped, gridout, norm="conserve_L1")

        else:
            # The size in inverse Angstrom of the grid
            Ksize = np.asarray(gridout) / rsize2d

            # Define resampling function to just crop diffraction patterns
            def resize(array):
                return crop(np.fft.fftshift(array, axes=(-1, -2)), gridout)

    else:
        # If no resampling then the output size is just the band-width
        # limited simulation grid size
        gridout = size_of_bandwidth_limited_array(gridshape)
        # The size in inverse Angstrom of the grid
        Ksize = np.asarray(gridout) / rsize2d

        # Define a resampling function that only crops to the band-width limit
        def resize(array):
            return crop(np.fft.fftshift(array, axes=(-1, -2)), gridout)

    return gridout, resize, Ksize
def second_moment(array):
    """
    Calculate the second moment of 2D array as a fraction of array size.

    Pixel coordinates are taken from np.fft.fftfreq (fractions of the array)
    and distances from the centre of mass are wrapped into [-0.5, 0.5), so
    the measure respects the periodicity of the array. The return value is
    the square root of the mass-weighted mean squared distance from the
    centre of mass.
    """
    freq_y = np.fft.fftfreq(array.shape[0])
    freq_x = np.fft.fftfreq(array.shape[1])
    total = np.sum(array)

    centre_y = np.sum(freq_y[:, None] * array) / total
    centre_x = np.sum(freq_x[None, :] * array) / total

    def wrapped_square(freqs, centre):
        # Squared periodic distance from the centre of mass
        return ((freqs - centre + 0.5) % 1.0 - 0.5) ** 2

    radius_sq = wrapped_square(freq_y, centre_y)[:, None] + wrapped_square(
        freq_x, centre_x
    )[None, :]
    return np.sqrt(np.sum(radius_sq * array) / total)
def generate_probe_spread_plot(
    gridshape,
    structure,
    eV,
    app,
    thickness,
    subslices=[1],
    tiling=[1, 1],
    showcrossection=True,
    df=0,
    probe_posn=[0, 0],
    show=True,
    device=None,
    P=None,
    T=None,
    nslices=None,
):
    """
    Generate probe spread plot to assist with selection of appropriate multislice grid.

    A multislice calculation assumes periodic boundary conditions. To avoid
    artefacts associated with this the pixel grid must be chosen to have
    sufficient size so that the probe does not artificially interfere with
    itself through the periodic boundary (wrap around error). The grid sampling
    must also be sufficient that electrons scattered to high angles are not
    scattered beyond the band-width limit of the array.

    The probe spread plot helps identify whenever these two events are happening.
    If the probe intensity drops below 0.95 (as a fraction of initial intensity)
    then the grid is not sampled finely enough, the pixel size of the array
    (gridshape) needs to be increased for finer sampling of the specimen
    potential. If the probe spread exceeds 0.2 (as a fraction of the array)
    then too much of the probe is spreading to the edges of the array, the
    real space size of the array (usually controlled by the tiling of the unit
    cell) needs to be increased.

    Parameters
    ----------
    gridshape : (2,) array_like
        Pixel dimensions of the 2D grid
    structure : pyms.structure_routines.structure
        The structure of interest
    eV : float
        Probe energy in electron volts
    app : float
        Probe-forming aperture in mrad
    thickness : float
        The maximum thickness of the simulation object in Angstrom
    subslices : array_like, optional
        A one dimensional array-like object containing the depths (in fractional
        coordinates) at which the object will be subsliced. The last entry
        should always be 1.0. For example, to slice the object into four equal
        sized slices pass [0.25,0.5,0.75,1.0]
    tiling : (2,) array_like, optional
        Tiling of a repeat unit cell on simulation grid
    showcrossection : bool
        Pass True to plot the projected cross section of the probe to inspect
        the spread.
    df : float
        Probe defocus in Angstrom
    probe_posn : array_like, optional
        Probe position as a fraction of the unit-cell
    show : bool, optional
        If True call plt.show(block=True) before returning
    device : torch.device, optional
        Device on which the multislice calculation is performed
    P : (n,Y,X) array_like, optional
        Precomputed Fresnel free-space propagators
    T : (n,Y,X) array_like
        Precomputed transmission functions
    nslices : int, optional
        Number of (whole) slices to propagate through, by default the number
        of unit cells needed to reach `thickness`

    Returns
    -------
    fig : matplotlib.figure object
        The figure on which the probe spread is plotted
    """
    from .Premixed_routines import multislice_precursor
    from .utils import fourier_shift, fourier_interpolate_2d

    # Calculate multislice propagator and transmission functions
    if P is None or T is None:
        P, T = multislice_precursor(
            structure,
            gridshape,
            eV,
            subslices=subslices,
            tiling=tiling,
            device=device,
            nT=1,
            showProgress=False,
        )

    # Calculate focused STEM probe and move it to the requested position
    probe = focused_probe(
        gridshape, structure.unitcell[:2] * np.asarray(tiling), eV, app, df=df
    )
    pos = np.asarray(probe_posn) / np.asarray(tiling)
    probe = fourier_shift(probe, pos, pixel_units=False)

    ncols = 1 + showcrossection
    fig, ax = plt.subplots(ncols=ncols, figsize=(ncols * 4, 4), squeeze=False)

    # Total number of slices (not including subslicing of structure)
    if nslices is None:
        nslices = int(np.ceil(thickness / structure.unitcell[2]))
    nsub = len(subslices)
    # Total number of slices (including subslicing of structure)
    maxslices = nslices * nsub

    variances = np.zeros(maxslices)
    intensity = np.zeros(maxslices)
    crossection = np.zeros((maxslices, gridshape[0]))

    # Pixel shift that re-centres the probe position (np.int was removed in
    # NumPy 1.24)
    shift = (pos * np.asarray(gridshape)).astype(int)

    for i in range(maxslices):
        # Propagate through (sub)slice i only. Passing [i] rather than a fixed
        # [1] makes multislice cycle through every subslice in order.
        probe = multislice(
            probe,
            [i],
            P,
            T,
            tiling=tiling,
            subslicing=True,
            output_to_bandwidth_limit=False,
            device_type=device,
        )
        mod = np.roll(np.abs(probe) ** 2, shift=-shift, axis=(-2, -1))

        # Record probe intensity and spread
        intensity[i] = np.sum(mod)
        variances[i] = second_moment(mod)
        if showcrossection:
            crossection[i] = np.sum(mod, axis=-1)

    # Depth of the probe after each (sub)slice
    thicknesses = structure.unitcell[2] * (
        np.broadcast_to(np.arange(nslices)[:, None], (nslices, nsub)).ravel()
        + np.tile(subslices, nslices)
    )
    # Use the depth actually propagated (may differ from `thickness` when it
    # is not a multiple of the unit cell or nslices is given explicitly)
    maxdepth = thicknesses[-1]

    ax[0, 0].set_xlim([0, maxdepth])
    ax[0, 0].set_ylim([0, 1.1])
    ax[0, 0].set_ylabel("$\\sqrt{\\int \\Psi^2 dx}$", color="red")
    ax[0, 0].set_xlabel(r"Depth of propagation ($\AA$)")
    ax[0, 0].tick_params(axis="y", labelcolor="red")
    ax[0, 0].set_title("Probe intensity and spread")

    # Second axes sharing the same x-axis for the probe spread
    ax2 = ax[0, 0].twinx()
    ax2.plot(thicknesses, variances, "b-")
    ax2.tick_params(axis="y", labelcolor="b")
    ax2.plot([0, maxdepth], [0.2, 0.2], "b--")
    ax2.set_ylim([0, 0.5])
    ax2.set_ylabel("$\\sqrt{\\int \\Psi^2 x^2 dx}$", color="blue")

    ax[0, 0].plot(thicknesses, intensity, "r-")
    ax[0, 0].plot([0, maxdepth], [0.95, 0.95], "r--")

    if showcrossection:
        ny = crossection.shape[1]
        ax[0, 1].imshow(
            fourier_interpolate_2d(
                np.fft.fftshift(np.sqrt(crossection), axes=1), [ny, ny]
            ),
            extent=[0, gridshape[0], maxdepth, 0],
            cmap=plt.get_cmap("gnuplot"),
        )
        ax[0, 1].set_ylabel(r"Depth of propagation ($\AA$)")
        ax[0, 1].set_title("Probe depth cross-section")

    fig.tight_layout()
    if show:
        plt.show(block=True)
    return fig
def STEM(
    rsize,
    probe,
    method,
    nslices,
    eV,
    alpha,
    batch_size=1,
    detectors=None,
    FourD_STEM=False,
    datacube=None,
    PACBED=False,
    scan_posn=None,
    dtype=torch.float32,
    device=None,
    tiling=[1, 1],
    seed=None,
    showProgress=True,
    method_args=(),
    method_kwargs={},
    STEM_image=None,
):
    """
    Perform a scanning transmission electron microscopy (STEM) image simulation.

    Will return an array containing conventional STEM images and/or a 4D-STEM
    datacube depending on inputs

    Parameters
    ----------
    rsize : (2,) array_like
        The real space size of the grid in Angstroms
    probe : (Y,X) array_like
        The probe that will be rastered over the object
    method : function
        A function that takes a probe and propagates it to the exit surface of
        the specimen
    nslices : int, array_like
        The number of slices to perform multislice over. A single integer or a
        list of integers (Python or numpy) is treated as a thickness series of
        cumulative slice counts, only the difference between consecutive
        entries is propagated at each step. Non-integer entries (e.g. arrays
        of slice indices) are passed to `method` unchanged.
    eV : float
        Accelerating voltage of the probe, needed to work out probe sampling
        requirements
    alpha : float
        The convergence angle of the probe in mrad, needed to work out probe
        sampling requirements
    batch_size : int, optional
        The multislice algorithm can be performed on multiple probes columns
        at once to parallelize computation, the number of parallel computations
        is set by batch_size.
    detectors : (Ndet, Y, X) array_like, optional
        Diffraction plane detectors to perform conventional STEM imaging. If
        None is passed then no conventional STEM images will be returned.
    FourD_STEM : bool or array_like, optional
        Pass FourD_STEM = True to perform 4D-STEM simulations. To save disk
        space a tuple containing pixel size and diffraction space extent of the
        datacube can be passed in. For example ([64,64],[1.2,1.2]) will output
        diffraction patterns measuring 64 x 64 pixels and 1.2 x 1.2 inverse
        Angstroms.
    datacube : (Ny, Nx, Y, X) array_like, optional
        datacube for 4D-STEM output, if None is passed (default) this will be
        initialized in the function. If a datacube is passed then the result
        will be added by the STEM routine (useful for multiple frozen phonon
        iterations)
    PACBED : bool
        If True the STEM function will calculate a position averaged convergent
        electron diffraction (PACBED) pattern by averaging the diffraction space
        result over all scan positions
    scan_posn : (...,2) array_like, optional
        Array containing the STEM scan positions in fractional coordinates.
        If provided scan_posn.shape[:-1] will give the shape of the STEM image.
    dtype : torch.dtype, optional
        Datatype of the simulation arrays, by default 32-bit floating point
    device : torch.device, optional
        torch.device object which will determine which device (CPU or GPU) the
        calculations will run on
    tiling : (2,) array_like, optional
        Tiling of a repeat unit cell on simulation grid, STEM raster will only
        scan a single unit cell.
    seed : array_like or int, optional
        Seed for the random number generator for frozen phonon configurations
    showProgress : str or bool, optional
        Pass False to disable progress readout, pass 'notebook' to get correct
        progress bar behaviour inside a jupyter notebook
    method_args : list, optional
        Arguments for the method function used to propagate probes to the exit
        surface
    method_kwargs : Dict, optional
        Keyword arguments for the method function used to propagate probes to
        the exit surface
    STEM_image : (Ndet,Ny,Nx) array_like, optional
        Array that will contain the conventional STEM images, if not passed
        will be initialized within the function. If it is passed then the result
        will be accumulated within the function, which is useful for multiple
        frozen phonon iterations.

    Returns
    -------
    Result : dict
        A dictionary with the keys "STEM images", "datacube" and "PACBED" which
        contain the conventional STEM images, the 4D-STEM datacube and the PACBED
        pattern respectively. If any of these simulations were not performed
        then the relevant entry will just contain None
    """
    from .utils.torch_utils import detect

    tdisable, tqdm = tqdm_handler(showProgress)

    # A single slice count is a thickness series of length one
    if isinstance(nslices, (int, np.integer)):
        nslices = [nslices]

    # Get number of thicknesses in the series
    nthick = len(nslices)

    # Integer entries are cumulative slice counts, convert them to the number
    # of additional slices needed at each step. np.integer covers the numpy
    # integers produced by thickness_to_slices, which the previous
    # isinstance(..., int) check missed (over-propagating the probes).
    if isinstance(nslices[0], (int, np.integer)):
        nslices_ = np.diff(nslices, prepend=0)
    else:
        nslices_ = nslices

    if device is None:
        device = get_device(device)

    # Get shape of grid (torch tensors carry a trailing real/imaginary axis)
    if torch.is_tensor(probe):
        gridshape = probe.shape[-3:-1]
    else:
        gridshape = probe.shape[-2:]

    # Generate scan positions in fractional units if not supplied
    if scan_posn is None:
        scan_posn = generate_STEM_raster(rsize[:2], eV, alpha, tiling)

    # Number of scan positions
    scan_shape = scan_posn.shape[:-1]
    nscantot = np.prod(scan_shape)
    scan_posn = scan_posn.reshape((nscantot, 2))

    # Ensure scan_posn is a pytorch tensor with same device and datatype as
    # other arrays
    scan_posn = torch.as_tensor(scan_posn).to(device).type(dtype)

    # Assume real space probe is passed in so perform Fourier transform in
    # anticipation of application of Fourier shift theorem
    probe_ = ensure_torch_array(probe, dtype=dtype, device=device)
    probe_ = torch.fft(probe_, signal_ndim=2)

    # Work out whether to perform conventional STEM or not
    conventional_STEM = detectors is not None
    if conventional_STEM:
        ndet = detectors.shape[0]
        # Initialize (or reshape a provided) array to accumulate STEM images
        if STEM_image is None:
            STEM_image = np.zeros((ndet, nthick, nscantot))
        else:
            STEM_image = STEM_image.reshape((ndet, nthick, nscantot))
        # Also move detectors to pytorch if necessary
        D = ensure_torch_array(detectors, device=device, dtype=dtype)
    else:
        STEM_image = None

    # Initialize array in which to store resulting 4D-STEM datacube if required
    if FourD_STEM:
        # Get diffraction pattern gridsize in pixels from input and function
        # to resample the simulation output to store in the datacube
        gridout, resize, _ = workout_4DSTEM_datacube_DP_size(
            FourD_STEM, rsize, gridshape
        )
        if datacube is None:
            datacube = np.zeros((nthick, *scan_shape, *gridout))

    if PACBED:
        PACBED_pattern = torch.zeros((nthick, *gridshape), dtype=dtype, device=device)
    else:
        PACBED_pattern = None

    # This algorithm allows for "batches" of probe to be sent through the
    # multislice algorithm to achieve some speed up at the cost of storing more
    # probes in memory
    if seed is None and batch_size > 1:
        # NOTE(review): this seed is not forwarded to `method` below (pass it
        # through method_kwargs if per-probe reproducibility is required).
        seed = np.random.randint(0, 2 ** 31 - 1)

    for i in tqdm(
        range(int(np.ceil(nscantot / batch_size))),
        disable=tdisable,
        desc="Probe positions",
    ):
        # Indices of the scan positions in this batch
        # (np.int was removed in NumPy 1.24)
        scan_index = np.arange(
            i * batch_size, min((i + 1) * batch_size, nscantot), dtype=int
        )

        # The shift operator array will be of size batch_size x Y x X
        probes = fourier_shift_array(
            gridshape,
            torch.as_tensor(scan_posn[scan_index]),
            dtype=dtype,
            device=device,
            units="fractional",
        )

        # Apply shift to original probe
        probes = complex_mul(probe_.view(1, *probe_.size()), probes)

        # Thickness series - each step only propagates the additional slices
        for it, t in enumerate(nslices_):
            # Evaluate exit surface wave function from input probes
            probes = method(
                probes, t, *method_args, posn=scan_posn[scan_index], **method_kwargs
            )

            # Calculate amplitude of probes, a real output is assumed to be the
            # amplitude of the exit surface wave function. Also correct
            # normalization to be in units of fractional intensity
            cmplxout = iscomplex(probes)
            if cmplxout:
                amp = amplitude(probes) / np.prod(probes.size()[-3:-1])
            else:
                amp = probes / np.prod(probes.size()[-2:])

            # Calculate STEM images: broadcast detector and probe arrays to
            # ndet x batch_size x Y x X and reduce final two dimensions
            if conventional_STEM:
                STEM_image[:ndet, it, scan_index] += detect(D, amp).cpu().numpy()

            # Store datacube
            if FourD_STEM:
                DPS = resize(amp.cpu().numpy())
                ind = np.unravel_index(scan_index, scan_shape)
                for idp, DP in enumerate(DPS):
                    datacube[it][ind[0][idp], ind[1][idp]] += DP

            if PACBED:
                PACBED_pattern[it] += torch.sum(amp, axis=0) / nscantot

            # In some cases the amplitude will be returned by the function,
            # in this case multiple thickness values should not be used! This
            # break command helps prevent misuse.
            if not cmplxout:
                break

    if conventional_STEM:
        STEM_image = np.squeeze(STEM_image.reshape(ndet, nthick, *scan_shape))
    if PACBED:
        PACBED_pattern = np.fft.fftshift(PACBED_pattern.cpu().numpy(), axes=(-2, -1))
    return {"STEM images": STEM_image, "datacube": datacube, "PACBED": PACBED_pattern}
def unit_cell_shift(array, axis, shift, tiles):
    """
    Shift an array an integer number of unit cells.

    For an array consisting of a number of repeat units (unit cells) given by
    tiles, cyclically shift the array by an integer number of unit cells
    along one of its two grid dimensions. The grid dimensions are taken to be
    dimensions -3 (y) and -2 (x), matching the (..., Y, X, 2) arrays with a
    trailing real/imaginary dimension used elsewhere in this module.

    Parameters
    ----------
    array : (..., Y, X, 2) torch.Tensor
        Array to shift.
    axis : int
        0 to shift along the y (dimension -3) direction, 1 to shift along the
        x (dimension -2) direction.
    shift : int
        Number of unit cells to shift by. As with torch.roll, a positive shift
        moves content towards higher indices, wrapping around periodically.
    tiles : int or (2,) array_like
        Number of repeat units along each grid direction. A scalar is taken
        to be the number of repeat units along the selected axis. The unit
        cell size in pixels is the grid size integer-divided by this number.

    Returns
    -------
    torch.Tensor
        Shifted copy of array.

    Raises
    ------
    ValueError
        If axis is not 0 or 1.
    """
    if axis not in (0, 1):
        raise ValueError("axis must be 0 (y) or 1 (x), got {0}".format(axis))
    dim = axis - 3
    npix = array.shape[dim]
    ntiles = tiles if np.ndim(tiles) == 0 else tiles[axis]
    # Convert the shift in unit cells into a shift in pixels
    pixel_shift = int(shift) * (npix // int(ntiles))
    # out[i] = array[(i - pixel_shift) mod npix], i.e. a periodic roll
    indices = torch.remainder(
        torch.arange(npix, device=array.device) - pixel_shift, npix
    )
    return torch.index_select(array, dim, indices)
def max_grid_resolution(gridshape, rsize, bandwidthlimit=2 / 3, eV=None):
    """
    Return the maximum resolution supported by a multislice grid.

    The limit along each direction is the Nyquist frequency of the pixel
    sampling (gridshape / rsize / 2) scaled by the band-width limit. The
    smaller of the y and x limits is returned. Without eV the result is in
    inverse Angstrom. If the probe accelerating voltage eV is given, the
    result is converted to an angle in mrad.
    """
    per_axis_limits = (
        gridshape[axis] / rsize[axis] / 2 * bandwidthlimit for axis in (0, 1)
    )
    resolution = min(per_axis_limits)
    if eV is not None:
        # Convert inverse Angstrom to mrad using the electron wavenumber
        resolution = resolution / wavev(eV) * 1e3
    return resolution
class scattering_matrix:
    """
    Scattering matrix object for calculations using the PRISM algorithm.

    The scattering matrix is held in the attribute ``S``. Each row of ``S``
    corresponds to one plane-wave component ("beam") listed in
    ``self.beams``. These are the beams inside the probe-forming aperture
    whose pixel coordinates are multiples of the PRISM interpolation factor.
    Each row is initialized to its plane wave and then propagated through the
    specimen with the multislice algorithm. Rows are stored in real space on
    the ``stored_gridshape`` grid, or, if ``Fourier_space_output`` is True, as
    a vector of the Fourier components indexed by ``bw_mapping``.

    NOTE(review): this class uses the legacy ``torch.fft``/``torch.ifft``
    functions (``signal_ndim`` argument, trailing real/imaginary dimension of
    size 2). Newer PyTorch releases replace these with the ``torch.fft``
    module and native complex tensors.
    """

    def __init__(
        self,
        rsize,
        propagators,
        transmission_functions,
        nslice,
        eV,
        alpha,
        GPU_streaming=False,
        batch_size=30,
        device=None,
        PRISM_factor=[1, 1],
        tiling=[1, 1],
        device_type=None,
        seed=None,
        showProgress=True,
        bandwidth_limit=2 / 3,
        Fourier_space_output=False,
        subslicing=False,
        transposed=False,
        stored_gridshape=None,
    ):
        """
        Initialize with a set of propagators and transmission functions.

        Parameters
        ----------
        rsize : (2,) array_like
            Real space size of the simulation grid in Angstrom
        propagators : (N,Y,X,2) torch.array
            Fresnel free space operators required for the multislice algorithm
            used to propagate the scattering matrix
        transmission_functions : (N,Y,X,2)
            The transmission functions describing the electron's interaction
            with the specimen for the multislice algorithm. The grid shape,
            datatype and (by default) device of the scattering matrix are
            taken from this array.
        nslice : int
            The number of slices of the specimen to propagate the scattering
            matrix to
        eV : float
            Electron probe energy in electron-volts
        alpha : float
            Maximum input angle for the scattering matrix in mrad, should match
            the probe forming aperture used in experiment
        GPU_streaming : bool, optional
            If True, the scattering matrix will be stored off GPU RAM and
            streamed to GPU RAM as necessary, does nothing if the calculation
            is CPU only
        batch_size : int, optional
            The multislice algorithm can be performed on multiple scattering
            matrix columns at once to parallelize computation, this number is
            set by batch_size.
        device : torch.device, optional
            torch.device object which will determine which device (CPU or GPU)
            the calculations will run on. By default this will be determined
            by what device the transmission functions are stored on. Ignored
            (CPU storage is used) when GPU_streaming is True.
        PRISM_factor : int (2,) array_like
            The PRISM "interpolation factor" this is the amount by which the
            scattering matrices are cropped in real space to speed up
            calculations see Ophus, Colin. "A fast image simulation algorithm
            for scanning transmission electron microscopy." Advanced structural
            and chemical imaging 3.1 (2017): 13 for details on this.
        tiling : (2,) array_like, optional
            Tiling of a repeat unit cell on the simulation grid, passed to the
            multislice routine and used to generate the STEM raster.
        device_type : optional
            Not used, retained for backward compatibility.
        seed : int or array_like of int, optional
            Seeds for the random number generator used for the frozen phonon
            approximation, one per slice. A single integer is expanded
            deterministically into one seed per slice. If None, random seeds
            are drawn so that every column of the scattering matrix sees the
            same frozen phonon configuration.
        showProgress : str or bool, optional
            Pass False to disable progress readout, pass 'notebook' to get correct
            progress bar behaviour inside a jupyter notebook
        bandwidth_limit : float, optional
            Band-width limiting of the transmission function and propagators to
            prevent wrap-around error in the multislice algorithm, 2/3 by
            default
        Fourier_space_output : bool, optional
            If True the scattering matrix output will be stored in reciprocal
            space, default is False
        subslicing : bool, optional
            Pass subslicing=True to access propagation to sub-slices of the
            unit cell, in this case nslices is taken to be in units of subslices
            to propagate rather than unit cells (i.e. nslices = 3 will propagate
            1.5 unit cells for a subslicing of 2 subslices per unit cell)
        transposed : bool, optional
            Make a "transposed" scattering matrix - see Brown et al. (2019)
            Physical Review Research paper for a discussion of this and its
            applications
        stored_gridshape : (2,) array_like
            Size of the stored grid, can be chosen to be smaller than the
            multislice grid to speed up computation of a smaller diffraction
            space view than that implied by the multislice at no cost to
            computational accuracy.
        """
        # Get size of grid
        gridshape = transmission_functions.shape[-3:-1]

        # Datatype (precision) is inferred from transmission functions
        self.dtype = transmission_functions.dtype

        # Device (CPU or GPU) is also inferred from transmission functions
        # unless given explicitly. With GPU streaming the scattering matrix is
        # kept in CPU memory and only pieces of it are moved to the GPU.
        self.device = device
        if GPU_streaming:
            self.device = torch.device("cpu")
        elif self.device is None:
            self.device = transmission_functions.device

        self.gridshape, self.rsize, self.eV = [np.asarray(gridshape), rsize, eV]

        # Get alpha in units of inverse Angstrom
        self.alpha_ = wavev(eV) * alpha * 1e-3

        self.PRISM_factor, self.tiling = [PRISM_factor, tiling]
        self.doPRISM = np.any(np.asarray(PRISM_factor) > 1)

        # Make a list of beams in the scattering matrix
        # Take beams inside the aperture and every nth beam where n is the
        # PRISM "interpolation" factor
        q = q_space_array(gridshape, rsize)
        inside_aperture = np.less_equal(q[0] ** 2 + q[1] ** 2, self.alpha_ ** 2)
        mody, modx = [
            np.mod(np.fft.fftfreq(x, 1 / x).astype(int), p) == 0
            for x, p in zip(gridshape, self.PRISM_factor)
        ]
        self.beams = np.nonzero(
            np.logical_and(
                np.logical_and(inside_aperture, mody[:, None]), modx[None, :]
            )
        )
        # Express beam pixel indices as signed frequencies in [-N/2, N/2)
        self.beams = [(x + y // 2) % y - y // 2 for x, y in zip(self.beams, gridshape)]
        self.nbeams = len(self.beams[0])

        # For a scattering matrix stored in real space there is the option
        # of storing it on a much smaller pixel grid than the grid used for
        # multislice. This is handy when, for example, a large grid
        # is required for a converged multislice calculation but only
        # the bright-field region of diffraction (small angle region)
        # is of interest. Be careful using this in conjunction with
        # multiple calls of the propagation method for the scattering matrix,
        # as information outside the angular range of the stored grid is lost.
        self.crop_output = stored_gridshape is not None
        if self.crop_output:
            self.stored_gridshape = stored_gridshape
            # Store a mapping of the pixels of the multislice grid that fit
            # within the (rectangular) stored grid to a one-dimensional vector
            self.bw_mapping = np.argwhere(
                np.logical_and(
                    (
                        np.abs(np.fft.fftfreq(gridshape[0], d=1 / gridshape[0]))
                        < self.stored_gridshape[0] // 2
                    )[:, np.newaxis],
                    (
                        np.abs(np.fft.fftfreq(gridshape[1], d=1 / gridshape[1]))
                        < self.stored_gridshape[1] // 2
                    )[np.newaxis, :],
                )
            )
        else:
            self.stored_gridshape = size_of_bandwidth_limited_array(
                transmission_functions.shape[-3:-1]
            )
            # We will only store output of the scattering matrix up to the band
            # width limit of the calculation, since this is a circular band-width
            # limit on a square grid we have to get somewhat fancy and store a
            # mapping of the pixels within the bandwidth limit to a
            # one-dimensional vector
            self.bw_mapping = np.argwhere(
                (np.fft.fftfreq(gridshape[0]) ** 2)[:, np.newaxis]
                + (np.fft.fftfreq(gridshape[1]) ** 2)[np.newaxis, :]
                < (bandwidth_limit / 2) ** 2
            )
        self.nbout = self.bw_mapping.shape[0]
        # Express the mapped pixel indices as signed frequencies in [-N/2, N/2)
        self.bw_mapping = (
            self.bw_mapping + self.gridshape // 2
        ) % self.gridshape - self.gridshape // 2

        self.Fourier_space_output = Fourier_space_output
        self.nsubslices = transmission_functions.shape[1]
        slices = generate_slice_indices(nslice, self.nsubslices, subslicing=subslicing)
        self.GPU_streaming = GPU_streaming
        self.transposed = transposed

        # One seed per slice ensures that each column in the scattering matrix
        # sees the same frozen phonon configuration
        if seed is None:
            seed = np.random.randint(
                0, 2 ** 31 - 1, size=len(slices), dtype=np.uint32
            )
        elif np.ndim(seed) == 0:
            # A single integer seed is expanded deterministically into one seed
            # per slice (Propagate requires a sequence of seeds)
            seed = np.random.RandomState(seed).randint(
                0, 2 ** 31 - 1, size=len(slices), dtype=np.uint32
            )
        self.seed = seed

        # This switch tells the propagate function to initialize the Smatrix
        # to plane waves
        self.initialized = False

        # Propagate wave functions of scattering matrix
        self.current_slice = 0
        self.show_Progress = showProgress
        self.Propagate(
            nslice,
            propagators,
            transmission_functions,
            subslicing=subslicing,
            showProgress=self.show_Progress,
            batch_size=batch_size,
        )

    def Propagate(
        self,
        nslice,
        propagators,
        transmission_functions,
        subslicing=False,
        batch_size=3,
        showProgress=True,
        transpose=False,
    ):
        """
        Propagate a scattering matrix to slice nslice of the specimen.

        If nslice precedes the slice the scattering matrix is currently at,
        the scattering matrix is propagated backwards through the specimen.

        Parameters
        ----------
        nslice : int
            The slice in the specimen to propagate the scattering matrix to
        propagators : (N,Y,X,2) torch.array
            Fresnel free space operators required for the multislice algorithm
            used to propagate the scattering matrix
        transmission_functions : (N,Y,X,2)
            The transmission functions describing the electron's interaction
            with the specimen for the multislice algorithm
        batch_size : int, optional
            The multislice algorithm can be performed on multiple scattering
            matrix columns at once to parallelize computation, this number is
            set by batch_size.
        subslicing : bool, optional
            Pass subslicing=True to access propagation to sub-slices of the
            unit cell, in this case nslices is taken to be in units of subslices
            to propagate rather than unit cells (i.e. nslices = 3 will propagate
            1.5 unit cells for a subslicing of 2 subslices per unit cell)
        showProgress : str or bool, optional
            Pass False to disable progress readout, pass 'notebook' to get correct
            progress bar behaviour inside a jupyter notebook
        transpose : bool, optional
            Not used, retained for backward compatibility. Whether the
            scattering matrix is "transposed" is set by the ``transposed``
            argument given at construction (``self.transposed``).
        """
        tdisable, tqdm = tqdm_handler(showProgress)
        from .Probe import plane_wave_illumination

        # Initialize scattering matrix if necessary
        if not self.initialized:
            if self.Fourier_space_output:
                self.S = torch.zeros(
                    self.nbeams, self.nbout, 2, dtype=self.dtype, device=self.device
                )
            else:
                self.S = torch.zeros(
                    self.nbeams,
                    *self.stored_gridshape,
                    2,
                    dtype=self.dtype,
                    device=self.device
                )
            # Adjust intensity for correct normalization of S matrix rows
            # taking into account the PRISM factor that needs to be applied
            # when the Smatrix is evaluated (only 1/product(PRISM_factor)
            # beams are taken and only 1/product(PRISM_factor) intensity
            # is cropped out in real space)
            PRISM_norm = torch.prod(torch.tensor(self.PRISM_factor, dtype=self.dtype))
            for ibeam in range(self.nbeams):
                # Initialize S-matrix to plane-waves
                psi = cx_from_numpy(
                    plane_wave_illumination(
                        self.gridshape,
                        self.rsize[:2],
                        self.eV,
                        tilt=[self.beams[0][ibeam], self.beams[1][ibeam]],
                        tilt_units="pixels",
                        qspace=True,
                    )
                )
                psi *= PRISM_norm
                if self.Fourier_space_output:
                    self.S[ibeam] = psi[self.bw_mapping[:, 0], self.bw_mapping[:, 1], :]
                else:
                    self.S[ibeam] = fourier_interpolate_2d_torch(
                        psi,
                        self.stored_gridshape,
                        qspace_in=True,
                        qspace_out=False,
                        norm="conserve_norm",
                    )
            self.initialized = True

        # Make nslice_ which always accounts of subslices of the structure
        nslice_ = nslice if subslicing else nslice * self.nsubslices

        # Work out direction of propagation through specimen: backwards if
        # the target slice precedes the current one, forwards otherwise
        direction = -1 if nslice_ < self.current_slice else 1

        if nslice_ > len(self.seed):
            # Add new seeds to determine random translations for frozen-phonon
            # multislice (required for reversability of multislice) if required
            self.seed = np.concatenate(
                [
                    self.seed,
                    np.random.randint(0, 2 ** 31 - 1, size=nslice_ - len(self.seed)),
                ]
            )

        # Now generate list of slices that the multislice algorithm will run
        # through. Propagating backwards from slice m to slice n revisits the
        # slices m-1, ..., n.
        slices = np.arange(self.current_slice, nslice_, direction)
        if direction < 0:
            slices += direction

        # For a transposed scattering matrix the order of the slices
        # in multislice should be reversed
        if self.transposed:
            slices = slices[::-1]

        # If streaming of Smatrix columns to the GPU is being used, ensure
        # that propagators and transmission functions for the multislice are
        # already on the GPU
        if self.GPU_streaming:
            propagators = ensure_torch_array(propagators).cuda()
            transmission_functions = ensure_torch_array(transmission_functions).cuda()

        self.current_slice = nslice_
        if len(slices) < 1:
            return

        # Loop over the different plane wave components (or columns) of the
        # scattering matrix
        for i in tqdm(
            range(int(np.ceil(self.nbeams / batch_size))),
            disable=tdisable,
            desc="Calculating S-matrix",
        ):
            beams = np.arange(
                i * batch_size, min((i + 1) * batch_size, self.nbeams), dtype=int
            )

            if self.Fourier_space_output:
                # Expand S-matrix input to full grid for multislice propagation
                psi = torch.zeros(
                    batch_size, *self.gridshape, 2, dtype=self.dtype, device=self.device
                )
                psi[
                    : beams.shape[0], self.bw_mapping[:, 0], self.bw_mapping[:, 1], :
                ] = self.S[beams]
            else:
                # Fourier interpolate stored real space S-matrix column onto
                # multislice grid
                psi = fourier_interpolate_2d_torch(
                    self.S[beams], self.gridshape, norm="conserve_norm"
                )

            if self.GPU_streaming:
                psi = ensure_torch_array(psi, dtype=self.dtype).to("cuda")

            output = multislice(
                psi[: beams.shape[0]],
                slices,
                propagators,
                transmission_functions,
                self.tiling,
                self.device,
                self.seed,
                return_numpy=False,
                qspace_in=self.Fourier_space_output,
                qspace_out=self.Fourier_space_output,
                transpose=self.transposed,
                output_to_bandwidth_limit=False,
                reverse=direction < 0,
            )

            if self.GPU_streaming:
                output = output.to(self.device)

            if self.Fourier_space_output:
                self.S[beams] = output[
                    :, self.bw_mapping[:, 0], self.bw_mapping[:, 1], :
                ] * np.sqrt(np.prod(self.stored_gridshape) / np.prod(self.gridshape))
            else:
                self.S[beams] = fourier_interpolate_2d_torch(
                    output, self.stored_gridshape, norm="conserve_norm"
                )

    def PRISM_crop_window(self, win=None, device=None):
        """
        Calculate the 1D index ranges of the STEM (PRISM) crop window.

        Returns a list of two torch.arange tensors, one for the y and one for
        the x direction. Each runs from -stored_gridshape[i] // (2 * win[i])
        (inclusive) to stored_gridshape[i] // (2 * win[i]) (exclusive). By
        default win is the PRISM interpolation factor.
        """
        device = get_device(device)
        if win is None:
            win = self.PRISM_factor
        crop_ = [
            torch.arange(
                -self.stored_gridshape[i] // (2 * win[i]),
                self.stored_gridshape[i] // (2 * win[i]),
                device=device,
            )
            for i in range(2)
        ]
        return crop_

    def __call__(self, probes, nslices, posn=None, Smat=None, scan_transform=None):
        """
        Calculate exit-surface waves function using the scattering matrix.

        Parameters
        ----------
        probes : (N,Y,X,2) or (Y,X,2) torch.array
            Input wave functions to calculate exit surface wave functions from
            must be in Diffraction space
        nslices :
            Does nothing, only there to match call signature for STEM routine
        posn : array_like (N,2), optional
            Positions of the probes in fractional coordinates of the grid,
            used to place the PRISM real space cropping window. Defaults to
            the origin for every probe.
        Smat : array_like (Nbeams,Y,X,2), optional
            Scattering matrix (or a segment of it, e.g. when streaming to the
            GPU) to use. By default this is self.S.
        scan_transform : callable, optional
            Function applied to posn (as a torch tensor) before it is used,
            for example to map positions on the full scattering matrix onto a
            segment of it.

        Returns
        -------
        output : (N,Y,X,2) torch.array
            Exit surface wave functions in diffraction (Fourier) space
        """
        if Smat is None:
            Smat = self.S
        Sshape = [int(x) for x in Smat.shape]
        # All working arrays live on the same device as the scattering matrix
        device = Smat.device
        crop_ = self.PRISM_crop_window(device=device)

        # Ensure posn and probes are pytorch arrays
        probes = ensure_torch_array(probes, dtype=self.dtype, device=device)

        # Ensure probes tensors correspond to the shape N x Y x X x 2
        # If they have the shape Y x X x 2 then reshape to 1 x Y x X x 2
        if probes.ndim < 4:
            probes = probes.view(1, *probes.shape)

        # Get number of probes
        nprobes = probes.shape[0]

        if posn is None:
            posn = torch.zeros(nprobes, 2, device=device, dtype=self.dtype)
        else:
            posn = torch.as_tensor(posn, device=device, dtype=self.dtype).view(
                (nprobes, 2)
            )

        if scan_transform is not None:
            posn = scan_transform(posn)

        # TODO decide whether to remove the Fourier_space_output option
        if self.Fourier_space_output:
            # A note on normalization: an individual probe enters the STEM routine
            # with sum_squared intensity of 1, but the STEM routine applies an
            # FFT so the sum_squared intensity is now equal to # pixels
            # For a correct matrix multiplication we must now divide by sqrt(# pixels)
            probe_vec = complex_matmul(
                probes[:, self.beams[0], self.beams[1]], Smat
            ) / np.sqrt(np.prod(self.gridshape))

            # Now reshape output from vectors to square arrays
            probes = torch.zeros(
                nprobes, *self.stored_gridshape, 2, dtype=self.dtype, device=device
            )
            probes[:, self.bw_mapping[:, 0], self.bw_mapping[:, 1], :] = probe_vec

            # Apply PRISM cropping in real space if appropriate
            if self.doPRISM:
                shape = probes.size()
                probes = torch.ifft(probes, signal_ndim=2).flatten(-3, -2)
                for k in range(nprobes):
                    # Calculate windows in vertical and horizontal directions
                    window = crop_window_to_flattened_indices_torch(
                        [
                            (crop_[i] + posn[k, i] * self.stored_gridshape[i])
                            % self.stored_gridshape[i]
                            for i in range(2)
                        ],
                        self.stored_gridshape,
                    )
                    # Zero everything outside the crop window
                    probe = probes[k].clone()
                    probes[k] = 0
                    probes[k, window, :] = probe[window, :]

                probes = probes.reshape(shape)

                # Transform probe back to Fourier space
                return torch.fft(probes, signal_ndim=2)
            return probes

        # Real space storage: flatten the grid dimensions of the S-matrix so
        # the non-PRISM case is a single matrix multiplication per probe
        Smatshape = Smat.shape
        flattened_shape = [Smatshape[0], Smatshape[-3] * Smatshape[-2], 2]
        output = torch.zeros(
            nprobes, *Smatshape[-3:-1], 2, dtype=self.dtype, device=device
        )

        # For evaluating the probes in real space we only want to perform the matrix
        # multiplication and summation within the real space PRISM cropping region
        stride = [x // y for x, y in zip(self.stored_gridshape, self.PRISM_factor)]
        halfstride = [x // 2 for x in stride]

        # out is a view into output, so in-place updates fill the result for
        # this probe only
        for probe, pos, out in zip(probes, posn, output):
            # Plane-wave coefficients of this probe for each S-matrix beam
            coeffs = probe[self.beams[0], self.beams[1]]
            if self.doPRISM:
                start = [
                    int(torch.round(pos[i] * Sshape[-3 + i])) - halfstride[i]
                    for i in range(2)
                ]
                windows = crop_window_to_periodic_indices(
                    [start[0], stride[0], start[1], stride[1]], Sshape[-3:-1]
                )
                p = coeffs.view(self.nbeams, 1, 1, 2)
                for wind in windows:
                    outview = out.narrow(-3, wind[0][0], wind[0][1]).narrow(
                        -2, wind[1][0], wind[1][1]
                    )
                    sview = Smat.narrow(-3, wind[0][0], wind[0][1]).narrow(
                        -2, wind[1][0], wind[1][1]
                    )
                    outview += torch.sum(complex_mul(p, sview), axis=0)
            else:
                out += complex_matmul(coeffs, Smat.view(flattened_shape)).view(
                    Smatshape[1:]
                )

        output /= np.sqrt(np.prod(probes.size()[-3:-1]))
        output = crop_torch(output, self.stored_gridshape)
        return torch.fft(output, signal_ndim=2)

    def STEM_with_GPU_streaming(
        self,
        detectors=None,
        FourD_STEM=None,
        datacube=None,
        STEM_image=None,
        nstreams=None,
        df=0,
        aberrations=[],
        ROI=[0.0, 0.0, 1.0, 1.0],
        device=None,
        scan_posns=None,
        showProgress=True,
    ):
        """
        Perform STEM with scattering matrix streamed between RAM and GPU memory.

        This allows much larger fields of view to be calculated with relatively
        modest graphics card memory. The STEM raster is segmented into spatially
        close clusters and the probe positions in these clusters are processed
        sequentially, with the relevant part of the scattering matrix streamed
        from CPU to GPU memory.

        Parameters
        ----------
        detectors : (Ndet, Y, X) array_like, optional
            Diffraction plane detectors to perform conventional STEM imaging. If
            None is passed then no conventional STEM images will be returned.
        FourD_STEM : bool or array_like, optional
            Pass FourD_STEM = True to perform 4D-STEM simulations. To save disk
            space a tuple containing pixel size and diffraction space extent of the
            datacube can be passed in. For example ([64,64],[1.2,1.2]) will output
            diffraction patterns measuring 64 x 64 pixels and 1.2 x 1.2 inverse
            Angstroms.
        datacube : (Ny, Nx, Y, X) array_like, optional
            datacube for 4D-STEM output, if None is passed (default) this will be
            initialized in the function. If a datacube is passed then the result
            will be added by the STEM routine (useful for multiple frozen phonon
            iterations)
        STEM_image : (Ndet,Ny,Nx) array_like, optional
            Array that will contain the conventional STEM images, if not passed
            will be initialized within the function. If it is passed then the result
            will be accumulated within the function, which is useful for multiple
            frozen phonon iterations.
        nstreams : int, optional
            Number of streams (separate transfers from CPU to GPU memory). If
            None this will just be set to the product of the PRISM interpolation
            factor
        df : float, optional
            Defocus in Angstrom
        aberrations : list, optional
            A list containing a set of the class aberration, pass an empty list for
            an unaberrated contrast transfer function.
        ROI : (4,) array_like
            Fraction of the unit cell to be scanned. Should contain [y0,x0,y1,x1]
            where [x0,y0] and [x1,y1] are the bottom left and top right coordinates
            of the region of interest (ROI) expressed as a fraction of the total
            grid (or unit cell).
        device : torch.device, optional
            torch.device object which will determine which device (CPU or GPU) the
            calculations will run on.
        scan_posns : (...,2) array_like, optional
            Array containing the STEM scan positions in fractional coordinates.
            If provided scan_posns.shape[:-1] will give the shape of the STEM
            image.
        showProgress : str or bool, optional
            Pass False to disable progress readout, pass 'notebook' to get correct
            progress bar behaviour inside a jupyter notebook

        Returns
        -------
        Result : dict
            Dictionary with keys "STEM images" and "datacube", containing the
            (squeezed) conventional STEM images and the 4D-STEM datacube. An
            entry is None if that calculation was not performed.
        """
        device = get_device(device)
        tdisable, tqdm = tqdm_handler(showProgress)

        # Get indices of PRISM cropping window
        crop_ = [x.cpu().numpy() for x in self.PRISM_crop_window()]

        # Make the STEM probe
        probe = focused_probe(
            self.gridshape,
            self.rsize[:2],
            self.eV,
            self.alpha_,
            df=df,
            aberrations=aberrations,
            app_units="invA",
        )
        probe = cx_from_numpy(probe, device=device, dtype=self.dtype)

        # Make scan positions if none already provided
        if scan_posns is None:
            scan_posns = generate_STEM_raster(
                self.rsize, self.eV, self.alpha_, tiling=self.tiling, ROI=ROI, invA=True
            )

        # Get scan (and STEM image) array shape and total number of scan positions
        scan_shape = scan_posns.shape[:-1]
        nscan = int(np.prod(scan_shape))

        # Flatten scan positions to simplify iteration later on.
        scan_posns = scan_posns.reshape((nscan, 2))

        # Calculate default 4D-STEM diffraction pattern sampling
        if FourD_STEM is True:
            GS = self.stored_gridshape
            FourD_STEM = [GS, np.asarray(GS) / np.asarray(self.rsize[:2])]

        # Allocate diffraction pattern and STEM images if not already provided
        if FourD_STEM:
            gridout = workout_4DSTEM_datacube_DP_size(
                FourD_STEM, self.rsize, self.gridshape
            )[0]
        if (datacube is None) and FourD_STEM:
            datacube = np.zeros((*scan_shape, *gridout))
        if not FourD_STEM:
            datacube = None

        # If detectors are provided then we are doing conventional STEM
        doConventionalSTEM = detectors is not None

        # Initialize STEM images if not provided
        if doConventionalSTEM:
            ndet = detectors.shape[0]
            if STEM_image is None:
                STEM_image = np.zeros((ndet, nscan))
            else:
                STEM_image = STEM_image.reshape((ndet, nscan))
        else:
            STEM_image = None

        if nstreams is None:
            # If the number of separate streams is not suggested by the
            # user, make this equal to the product of the PRISM factor
            nstreams = int(np.prod(self.PRISM_factor))

        # Divide up the scan positions into clusters based on Euclidean
        # distance
        from sklearn.cluster import Birch

        if nscan > 1:
            model = Birch(threshold=0.01, n_clusters=nstreams)
            yhat = model.fit_predict(scan_posns)
            clusters = np.unique(yhat)
        else:
            # A single probe position forms its own cluster. yhat must be an
            # array so that "yhat == cluster" below is evaluated elementwise.
            yhat = np.zeros(nscan, dtype=int)
            clusters = np.zeros(1, dtype=int)

        # Now do STEM with each of the scan position clusters, streaming
        # only the necessary bits of the scattering matrix to the GPU.
        Datacube_segment = None
        STEM_image_segment = None
        FlatS = self.S.reshape((self.nbeams, np.prod(self.stored_gridshape), 2))

        # Loop over probe positions clusters. This would be a good candidate
        # for multi-GPU work.
        for cluster in tqdm(clusters, desc="Probe position clusters", disable=tdisable):
            # Get map of probe positions in cluster
            points = np.nonzero(yhat == cluster)[0]
            npoints = len(points)

            # The STEM routine accumulates into the arrays passed to it and the
            # segments are then added onto the full outputs, so they must start
            # from zero (seeding them with existing values would count those
            # values twice)
            if doConventionalSTEM:
                STEM_image_segment = np.zeros((ndet, npoints))
            if FourD_STEM:
                Datacube_segment = np.zeros((1, npoints, 1, *gridout))

            pix_posn = scan_posns[points] * np.asarray(self.stored_gridshape)

            # Work out bounds of the rectangular region of the scattering
            # matrix to stream to the GPU
            ymin, ymax = [
                int(np.floor(np.amin(pix_posn[:, 0]) + crop_[0][0])),
                int(np.ceil(np.amax(pix_posn[:, 0]) + crop_[0][-1])),
            ]
            xmin, xmax = [
                int(np.floor(np.amin(pix_posn[:, 1]) + crop_[1][0])),
                int(np.ceil(np.amax(pix_posn[:, 1]) + crop_[1][-1])),
            ]
            size = np.asarray([(ymax - ymin), (xmax - xmin)])

            # Get indices of region of scattering matrix to stream to GPU
            window = [np.arange(a, b) for a, b in zip([ymin, xmin], [ymax, xmax])]
            indices = crop_window_to_flattened_indices_torch(
                window, self.stored_gridshape
            )

            # Get segment of the scattering matrix to stream to GPU
            segmentshape = [len(x) for x in window]
            SegmentS = FlatS[:, indices, :].reshape((self.nbeams, *segmentshape, 2))

            # Define a function that will map probe positions for the global
            # scattering matrix to their correct place on the smaller scattering
            # matrix streamed to the GPU (used within this iteration only).
            gshape = torch.as_tensor(self.stored_gridshape).to(device).type(self.dtype)
            Origin = torch.as_tensor([ymin, xmin]).to(device).type(self.dtype)
            segment_size = torch.as_tensor(size).to(device).type(self.dtype)

            def scan_transform(posn):
                return (posn * gshape - Origin) / segment_size

            # Keyword arguments to be passed to the __call__ function by the
            # STEM routine
            kwargs = {"Smat": SegmentS.to(device), "scan_transform": scan_transform}

            # Calculate STEM images (accumulated in place into the segments)
            STEM(
                self.rsize,
                probe,
                self.__call__,
                [1],
                self.eV,
                self.alpha_,
                detectors=detectors,
                FourD_STEM=FourD_STEM,
                datacube=Datacube_segment,
                scan_posn=scan_posns[points].reshape((npoints, 1, 2)),
                STEM_image=STEM_image_segment,
                method_kwargs=kwargs,
                showProgress=False,
                device=device,
            )

            if doConventionalSTEM:
                STEM_image[:, points] += STEM_image_segment
            if FourD_STEM:
                for point, Dp in zip(points, Datacube_segment[0]):
                    y, x = np.unravel_index(point, scan_shape)
                    datacube[y, x] += Dp[0]

        # Unflatten STEM image scan dimensions, use numpy squeeze to
        # remove superfluous dimensions (ones with length 1)
        if doConventionalSTEM:
            STEM_image = np.squeeze(STEM_image.reshape(ndet, *scan_shape))

        # Return STEM images and datacube as a dictionary. If either of these
        # objects were not calculated the dictionary will contain None for those
        # entries.
        return {"STEM images": STEM_image, "datacube": datacube}
def phase_from_com(com, reg=1e-10, rsize=[1, 1]):
    """
    Integrate 4D-STEM centre of mass (DPC) measurements to calculate object phase.

    Parameters
    ----------
    com : (2, Ny, Nx) array_like
        Centre of mass measurements. com[0] holds the y and com[1] the x
        centre of mass at each pixel of the (Ny, Nx) image.
    reg : float, optional
        Small regularization constant added to the denominator of the
        Fourier space integration to avoid division by (near) zero.
    rsize : (2,) array_like, optional
        Real space size of the image, used to set the Fourier coordinates.

    Returns
    -------
    phase : (Ny, Nx) ndarray
        Real valued phase with zero mean (the Fourier origin is set to zero).
    """
    shape = com.shape[1:]
    ny, nx = shape
    pixel_size = np.asarray(rsize) / np.asarray(shape)

    # Fourier coordinates: full range in y, non-negative half (rfft) in x
    ky = np.fft.fftfreq(ny, d=pixel_size[0])[:, None]
    kx = np.fft.rfftfreq(nx, d=pixel_size[1])[None, :]

    com_y_q = np.fft.rfft2(com[0], s=shape)
    com_x_q = np.fft.rfft2(com[1], s=shape)

    # Fourier space solution of grad(phase) = com
    num = ky * com_y_q + kx * com_x_q
    den = 1j * (kx ** 2 + ky ** 2) + reg

    # The origin of Fourier space carries no gradient information; zero it and
    # avoid dividing by zero there
    num[0, 0] = 0
    den[0, 0] = 1

    return np.fft.irfft2(num / den, s=shape)
Functions
def STEM(rsize, probe, method, nslices, eV, alpha, batch_size=1, detectors=None, FourD_STEM=False, datacube=None, PACBED=False, scan_posn=None, dtype=torch.float32, device=None, tiling=[1, 1], seed=None, showProgress=True, method_args=(), method_kwargs={}, STEM_image=None)
-
Perform a scanning transmission electron microscopy (STEM) image simulation.
Will return an array containing conventional STEM images and/or a 4D-STEM datacube depending on inputs
Parameters
rsize
:(2,) array_like
- The real space size of the grid in Angstroms
probe
:(Y,X) array_like
- The probe that will be rastered over the object
method
:function
- A function that takes a probe and propagates it to the exit surface of the specimen
nslices
:int, array_like
- The number of slices to perform multislice over
eV
:float
- Accelerating voltage of the probe, needed to work out probe sampling requirements
alpha
:float
- The convergence angle of the probe in mrad, needed to work out probe sampling requirements
batch_size
:int
, optional- The multislice algorithm can be performed on multiple probes at once to parallelize computation; the number of probes computed in parallel is set by batch_size.
detectors
:(Ndet, Y, X) array_like
, optional- Diffraction plane detectors to perform conventional STEM imaging. If None is passed then no conventional STEM images will be returned.
FourD_STEM
:bool
or array_like
, optional- Pass FourD_STEM = True to perform 4D-STEM simulations. To save disk space a tuple containing the pixel size and diffraction space extent of the datacube can be passed in. For example ([64,64],[1.2,1.2]) will output diffraction patterns measuring 64 x 64 pixels and 1.2 x 1.2 inverse Angstroms.
datacube
:(Ny, Nx, Y, X) array_like
, optional- datacube for 4D-STEM output, if None is passed (default) this will be initialized in the function. If a datacube is passed then the result will be added by the STEM routine (useful for multiple frozen phonon iterations)
PACBED
:bool
, optional- If True the STEM function will calculate a position averaged convergent beam electron diffraction (PACBED) pattern by averaging the diffraction space result over all scan positions.
scan_posn
:(…,2) array_like
, optional- Array containing the STEM scan positions in fractional coordinates. If provided, scan_posn.shape[:-1] will give the shape of the STEM image.
dtype
:torch.dtype
, optional- Datatype of the simulation arrays, by default 32-bit floating point
device
:torch.device
, optional- torch.device object which will determine which device (CPU or GPU) the calculations will run on
tiling
:(2,) array_like
, optional- Tiling of a repeat unit cell on simulation grid, STEM raster will only scan a single unit cell.
seed
:array_like
or int
, optional- Seed for the random number generator for frozen phonon configurations
showProgress
:str
or bool
, optional- Pass False to disable progress readout, pass 'notebook' to get correct progress bar behaviour inside a jupyter notebook
method_args
:list
, optional- Arguments for the method function used to propagate probes to the exit surface
method_kwargs
:Dict
, optional- Keyword arguments for the method function used to propagate probes to the exit surface
STEM_image
:(Ndet,Ny,Nx) array_like
, optional- Array that will contain the conventional STEM images, if not passed will be initialized within the function. If it is passed then the result will be accumulated within the function, which is useful for multiple frozen phonon iterations.
Returns
Result
:dict
- A dictionary with the keys "STEM images", "datacube" and "PACBED" which contain the conventional STEM images, the 4D-STEM datacube and the PACBED pattern respectively. If any of these simulations were not performed then the relevant entry will just contain None
Expand source code
def STEM(
    rsize,
    probe,
    method,
    nslices,
    eV,
    alpha,
    batch_size=1,
    detectors=None,
    FourD_STEM=False,
    datacube=None,
    PACBED=False,
    scan_posn=None,
    dtype=torch.float32,
    device=None,
    tiling=[1, 1],
    seed=None,
    showProgress=True,
    method_args=(),
    method_kwargs={},
    STEM_image=None,
):
    """
    Perform a scanning transmission electron microscopy (STEM) image simulation.

    Will return an array containing conventional STEM images and/or a 4D-STEM
    datacube depending on inputs

    Parameters
    ----------
    rsize : (2,) array_like
        The real space size of the grid in Angstroms
    probe : (Y,X) array_like
        The probe that will be rastered over the object
    method : function
        A function that takes a probe and propagates it to the exit surface of
        the specimen
    nslices : int, array_like
        The number of slices to perform multislice over. A single integer or a
        sequence of (Python or NumPy) integers is treated as a cumulative
        thickness series: the difference between successive entries is passed
        to `method` at each step. Any other sequence (e.g. a list of
        slice-index arrays) is passed to `method` element by element unchanged.
    eV : float
        Accelerating voltage of the probe, needed to work out probe sampling
        requirements
    alpha : float
        The convergence angle of the probe in mrad, needed to work out probe
        sampling requirements
    batch_size : int, optional
        The multislice algorithm can be performed on multiple probe columns at
        once to parallelize computation, the number of parallel computations
        is set by batch_size.
    detectors : (Ndet, Y, X) array_like, optional
        Diffraction plane detectors to perform conventional STEM imaging. If
        None is passed then no conventional STEM images will be returned.
    FourD_STEM : bool or array_like, optional
        Pass FourD_STEM = True to perform 4D-STEM simulations. To save disk
        space a tuple containing pixel size and diffraction space extent of the
        datacube can be passed in. For example ([64,64],[1.2,1.2]) will output
        diffraction patterns measuring 64 x 64 pixels and 1.2 x 1.2 inverse
        Angstroms.
    datacube : (Ny, Nx, Y, X) array_like, optional
        datacube for 4D-STEM output, if None is passed (default) this will be
        initialized in the function. If a datacube is passed then the result
        will be added by the STEM routine (useful for multiple frozen phonon
        iterations)
    PACBED : bool
        If True the STEM function will calculate a position averaged convergent
        electron diffraction (PACBED) pattern by averaging the diffraction
        space result over all scan positions
    scan_posn : (...,2) array_like, optional
        Array containing the STEM scan positions in fractional coordinates. If
        provided scan_posn.shape[:-1] will give the shape of the STEM image.
    dtype : torch.dtype, optional
        Datatype of the simulation arrays, by default 32-bit floating point
    device : torch.device, optional
        torch.device object which will determine which device (CPU or GPU) the
        calculations will run on
    tiling : (2,) array_like, optional
        Tiling of a repeat unit cell on simulation grid, STEM raster will only
        scan a single unit cell.
    seed : array_like or int, optional
        Seed for the random number generator for frozen phonon configurations
    showProgress : str or bool, optional
        Pass False to disable progress readout, pass 'notebook' to get correct
        progress bar behaviour inside a jupyter notebook
    method_args : list, optional
        Arguments for the method function used to propagate probes to the exit
        surface
    method_kwargs : Dict, optional
        Keyword arguments for the method function used to propagate probes to
        the exit surface
    STEM_image : (Ndet,Ny,Nx) array_like, optional
        Array that will contain the conventional STEM images, if not passed
        will be initialized within the function. If it is passed then the
        result will be accumulated within the function, which is useful for
        multiple frozen phonon iterations.

    Returns
    -------
    Result : dict
        A dictionary with the keys "STEM images", "datacube" and "PACBED" which
        contain the conventional STEM images, the 4D-STEM datacube and the
        PACBED pattern respectively. If any of these simulations were not
        performed then the relevant entry will just contain None
    """
    from .utils.torch_utils import detect

    tdisable, tqdm = tqdm_handler(showProgress)

    # A single integer thickness is treated as a thickness series of length one
    # (previously len(nslices) raised a TypeError for the documented int input)
    if isinstance(nslices, (int, np.integer)):
        nslices = [nslices]

    # Get number of thicknesses in the series
    nthick = len(nslices)

    # An integer series is cumulative, so take the difference between
    # sequential entries to get the number of slices for each step. NumPy
    # integers (e.g. np.ceil(...).astype(int) output) must be recognised too,
    # otherwise the series would silently be applied as increments.
    if isinstance(nslices[0], (int, np.integer)):
        nslices_ = np.diff(nslices, prepend=0)
    else:
        nslices_ = nslices

    if device is None:
        device = get_device(device)

    # Get shape of grid (torch probes carry a trailing real/imaginary axis)
    if torch.is_tensor(probe):
        gridshape = probe.shape[-3:-1]
    else:
        gridshape = probe.shape[-2:]

    # Generate scan positions in units of pixels if not supplied
    if scan_posn is None:
        scan_posn = generate_STEM_raster(rsize[:2], eV, alpha, tiling)

    # Number of scan positions
    scan_shape = scan_posn.shape[:-1]
    nscantot = np.prod(scan_shape)
    scan_posn = scan_posn.reshape((nscantot, 2))

    # Ensure scan_posn is a pytorch tensor with same device and datatype as
    # other arrays
    scan_posn = torch.as_tensor(scan_posn).to(device).type(dtype)

    # Assume real space probe is passed in so perform Fourier transform in
    # anticipation of application of Fourier shift theorem
    probe_ = ensure_torch_array(probe, dtype=dtype, device=device)
    probe_ = torch.fft(probe_, signal_ndim=2)

    # Work out whether to perform conventional STEM or not
    conventional_STEM = detectors is not None

    if conventional_STEM:
        # Get number of detectors
        ndet = detectors.shape[0]

        # Initialize array in which to store resulting STEM images
        if STEM_image is None:
            STEM_image = np.zeros((ndet, nthick, nscantot))
        else:
            STEM_image = STEM_image.reshape((ndet, nthick, nscantot))

        # Also move detectors to pytorch if necessary
        D = ensure_torch_array(detectors, device=device, dtype=dtype)
    else:
        STEM_image = None

    # Initialize array in which to store resulting 4D-STEM datacube if required
    if FourD_STEM:
        # Get diffraction pattern gridsize in pixels from input and function
        # to resample the simulation output to store in the datacube
        gridout, resize, _ = workout_4DSTEM_datacube_DP_size(
            FourD_STEM, rsize, gridshape
        )

        # Check whether a datacube is already provided or not
        if datacube is None:
            datacube = np.zeros((nthick, *scan_shape, *gridout))

    if PACBED:
        # Accumulate in the simulation dtype rather than torch's default
        PACBED_pattern = torch.zeros((nthick, *gridshape), device=device, dtype=dtype)
    else:
        PACBED_pattern = None

    # This algorithm allows for "batches" of probe to be sent through the
    # multislice algorithm to achieve some speed up at the cost of storing more
    # probes in memory
    if seed is None and batch_size > 1:
        # If no seed passed to random number generator then make one to pass
        # to the multislice algorithm.
        # NOTE(review): this seed is not forwarded to `method` below; callers
        # wanting a fixed frozen phonon configuration must pass it through
        # method_kwargs themselves.
        seed = np.random.randint(0, 2 ** 31 - 1)

    nbatches = int(np.ceil(nscantot / batch_size))
    for i in tqdm(range(nbatches), disable=tdisable, desc="Probe positions"):
        # Indices of the scan positions handled in this batch
        # (builtin int: the np.int alias was removed in NumPy 1.24)
        scan_index = np.arange(
            i * batch_size, min((i + 1) * batch_size, nscantot), dtype=int
        )

        # The shift operator array will be of size batch_size x Y x X
        probes = fourier_shift_array(
            gridshape,
            torch.as_tensor(scan_posn[scan_index]),
            dtype=dtype,
            device=device,
            units="fractional",
        )

        # Apply shift to original probe
        probes = complex_mul(probe_.view(1, *probe_.size()), probes)

        # Thickness series - each step propagates the difference between
        # sequential thicknesses
        for it, t in enumerate(nslices_):

            # Evaluate exit surface wave function from input probes
            probes = method(
                probes, t, *method_args, posn=scan_posn[scan_index], **method_kwargs
            )

            # Calculate amplitude of probes, a real output is assumed to be the
            # amplitude of the exit surface wave function. Also correct
            # normalization to be in units of fractional intensity
            cmplxout = iscomplex(probes)
            if cmplxout:
                amp = amplitude(probes) / np.prod(probes.size()[-3:-1])
            else:
                amp = probes / np.prod(probes.size()[-2:])

            # Calculate STEM images
            if conventional_STEM:
                # broadcast detector and probe arrays to
                # ndet x batch_size x Y x X and reduce final two dimensions
                STEM_image[:ndet, it, scan_index] += detect(D, amp).cpu().numpy()

            # Store datacube
            if FourD_STEM:
                DPS = resize(amp.cpu().numpy())
                ind = np.unravel_index(scan_index, scan_shape)
                for idp, DP in enumerate(DPS):
                    datacube[it][ind[0][idp], ind[1][idp]] += DP

            if PACBED:
                PACBED_pattern[it] += torch.sum(amp, axis=0) / nscantot

            # In some cases the amplitude will be returned by the function;
            # in this case multiple thickness values should not be used! This
            # break command helps prevent misuse.
            if not cmplxout:
                break

    if conventional_STEM:
        STEM_image = np.squeeze(STEM_image.reshape(ndet, nthick, *scan_shape))

    if PACBED:
        PACBED_pattern = np.fft.fftshift(PACBED_pattern.cpu().numpy(), axes=(-2, -1))

    return {"STEM images": STEM_image, "datacube": datacube, "PACBED": PACBED_pattern}
def STEM_phase_contrast_transfer_function(probe, detector)
-
Calculate the STEM phase contrast transfer function.
For a thin and weakly scattering sample, convolution with the STEM contrast transfer function gives a good approximation to STEM image contrast.
Parameters
probe
:complex, (Y,X) array_like
- The STEM probe in reciprocal space
detector
:real, (Y,X) array_like
- The STEM detector
Returns
PCTF
:(Y,X) np.ndarray
- The phase contrast transfer function
Expand source code
def STEM_phase_contrast_transfer_function(probe, detector):
    """
    Calculate the STEM phase contrast transfer function.

    For a thin and weakly scattering sample, convolution with the STEM contrast
    transfer function gives a good approximation to STEM image contrast.

    Parameters
    ----------
    probe : complex, (Y,X) array_like
        The STEM probe in reciprocal space
    detector : real, (Y,X) array_like
        The STEM detector

    Returns
    -------
    PCTF : (Y,X) np.ndarray
        The phase contrast transfer function
    """
    from .utils import convolve

    def reflect(array):
        # Two successive orthonormal forward FFTs map f(k) onto f(-k)
        return np.fft.fft2(np.fft.fft2(array, norm="ortho"), norm="ortho")

    # Total probe intensity, used to normalise the transfer function
    probe_intensity = np.sum(np.square(np.abs(probe)))

    transfer = convolve(probe, reflect(np.conj(probe) * detector)) / probe_intensity

    # Subtract the conjugate of the reflected function before taking -2 Im(...)
    transfer -= np.conj(reflect(transfer))
    return -2 * np.imag(transfer)
def generate_STEM_raster(rsize, eV, alpha, tiling=[1, 1], ROI=[0.0, 0.0, 1.0, 1.0], gridshape=[1, 1], invA=False)
-
Return the probe positions for a nyquist-sampled STEM raster.
For a real space size rsize return probe positions in units of fraction of the array for nyquist sampled STEM raster
Parameters
rsize
:(2,) array_like
- Size of the grid in real space in units of Angstroms
eV
:float
- Probe energy in electron volts
alpha
:float
- Probe forming aperture semi-angle in mrad or inverse Angstrom (if invA == True)
gridshape
:(2,) array_like
, optional- Pixel dimensions of the 2D grid, by default [1,1] so probe positions will be returned as a fraction of the array size
tiling
:(2,) array_like
- Tiling of a repeat unit cell on simulation grid, if provided STEM raster will only scan a single unit cell.
ROI
:(4,) array_like
- Fraction of the unit cell to be scanned. Should contain [y0,x0,y1,x1] where [x0,y0] and [x1,y1] are the bottom left and top right coordinates of the region of interest (ROI) expressed as a fraction of the total grid (or unit cell).
invA
:bool
- If True, alpha is taken to be in units of inverse Angstrom, not mrad. This also means that the value of eV no longer matters
Returns
probe_posns
:(nY,nX,2) np.ndarray
- The probe positions in fractions of the array if gridshape is [1,1] and in pixel units if gridshape is the size of the pixel array.
Expand source code
def generate_STEM_raster(
    rsize,
    eV,
    alpha,
    tiling=[1, 1],
    ROI=[0.0, 0.0, 1.0, 1.0],
    gridshape=[1, 1],
    invA=False,
):
    """
    Return the probe positions for a nyquist-sampled STEM raster.

    For a real space size rsize return probe positions in units of fraction of
    the array for nyquist sampled STEM raster

    Parameters
    ----------
    rsize : (2,) array_like
        Size of the grid in real space in units of Angstroms
    eV : float
        Probe energy in electron volts
    alpha : float
        Probe forming aperture semi-angle in mrad or inverse Angstrom (if
        invA == True)
    gridshape : (2,) array_like, optional
        Pixel dimensions of the 2D grid, by default [1,1] so probe positions
        will be returned as a fraction of the array size
    tiling : (2,) array_like
        Tiling of a repeat unit cell on simulation grid, if provided STEM
        raster will only scan a single unit cell.
    ROI : (4,) array_like
        Fraction of the unit cell to be scanned. Should contain [y0,x0,y1,x1]
        where [x0,y0] and [x1,y1] are the bottom left and top right coordinates
        of the region of interest (ROI) expressed as a fraction of the total
        grid (or unit cell).
    invA : bool
        If True, alpha is taken to be in units of inverse Angstrom, not mrad.
        This also means that the value of eV no longer matters

    Returns
    -------
    probe_posns : (nY,nX,2) np.ndarray
        The probe positions in fractions of the array if gridshape is [1,1] and
        in pixel units if gridshape is the size of the pixel array.
    """
    # Real-space extent (Angstrom) of the region of interest
    FOV = np.asarray([rsize[0] * (ROI[2] - ROI[0]), rsize[1] * (ROI[3] - ROI[1])])

    # Only a single repeat unit is scanned, so sample per tile
    per_tile = FOV / np.asarray(tiling)

    # Number of scan coordinates along each dimension
    if invA:
        nscan = nyquist_sampling(per_tile, resolution_limit=alpha)
    else:
        nscan = nyquist_sampling(per_tile, eV=eV, alpha=alpha)

    coordinates = []
    for axis in range(2):
        start = ROI[axis] * gridshape[axis] / tiling[axis]
        stop = ROI[axis + 2] * gridshape[axis] / tiling[axis]
        step = np.diff(ROI[axis::2])[0] * gridshape[axis] / nscan[axis] / tiling[axis]
        # Floating point arange may overshoot by one sample, so truncate
        samples = np.arange(start, stop, step=step)[: nscan[axis]]
        coordinates.append(samples / gridshape[axis])

    yy, xx = coordinates
    return np.stack(np.broadcast_arrays(yy[:, None], xx[None, :]), axis=2)
def generate_probe_spread_plot(gridshape, structure, eV, app, thickness, subslices=[1], tiling=[1, 1], showcrossection=True, df=0, probe_posn=[0, 0], show=True, device=None, P=None, T=None, nslices=None)
-
Generate probe spread plot to assist with selection of appropriate multislice grid.
A multislice calculation assumes periodic boundary conditions. To avoid artefacts associated with this the pixel grid must be chosen to have sufficient size so that the probe does not artificially interfere with itself through the periodic boundary (wrap around error). The grid sampling must also be sufficient that electrons scattered to high angles are not scattered beyond the band-width limit of the array.
The probe spread plot helps identify whenever these two events are happening. If the probe intensity drops below 0.95 (as a fraction of initial intensity) then the grid is not sampled finely enough, the pixel size of the array (gridshape) needs to be increased for finer sampling of the specimen potential. If the probe spread exceeds 0.2 (as a fraction of the array) then too much of the probe is spreading to the edges of the array, the real space size of the array (usually controlled by the tiling of the unit cell) needs to be increased.
Parameters
gridshape
:(2,) array_like
- Pixel dimensions of the 2D grid
structure
:structure
- The structure of interest
eV
:float
- Probe energy in electron volts
app
:float
- Probe-forming aperture in mrad
thickness
:float
- The maximum thickness of the simulation object in Angstrom
subslices
:array_like
, optional- A one dimensional array-like object containing the depths (in fractional coordinates) at which the object will be subsliced. The last entry should always be 1.0. For example, to slice the object into four equal sized slices pass [0.25,0.5,0.75,1.0]
tiling
:(2,) array_like
, optional- Tiling of a repeat unit cell on simulation grid
showcrossection
:bool
- Pass True to plot the projected cross section of the probe to inspect the spread.
df
:float
- Probe defocus in Angstrom
probe_posn
:array_like
, optional- Probe position as a fraction of the unit-cell
P
:(n,Y,X) array_like
, optional- Precomputed Fresnel free-space propagators
T
:(n,Y,X) array_like
- Precomputed transmission functions
Returns
fig
:matplotlib.figure object
- The figure on which the probe spread is plotted
Expand source code
def generate_probe_spread_plot(
    gridshape,
    structure,
    eV,
    app,
    thickness,
    subslices=[1],
    tiling=[1, 1],
    showcrossection=True,
    df=0,
    probe_posn=[0, 0],
    show=True,
    device=None,
    P=None,
    T=None,
    nslices=None,
):
    """
    Generate probe spread plot to assist with selection of appropriate multislice grid.

    A multislice calculation assumes periodic boundary conditions. To avoid
    artefacts associated with this the pixel grid must be chosen to have
    sufficient size so that the probe does not artificially interfere with
    itself through the periodic boundary (wrap around error). The grid sampling
    must also be sufficient that electrons scattered to high angles are not
    scattered beyond the band-width limit of the array.

    The probe spread plot helps identify whenever these two events are
    happening. If the probe intensity drops below 0.95 (as a fraction of
    initial intensity) then the grid is not sampled finely enough, the pixel
    size of the array (gridshape) needs to be increased for finer sampling of
    the specimen potential. If the probe spread exceeds 0.2 (as a fraction of
    the array) then too much of the probe is spreading to the edges of the
    array, the real space size of the array (usually controlled by the tiling
    of the unit cell) needs to be increased.

    Parameters
    ----------
    gridshape : (2,) array_like
        Pixel dimensions of the 2D grid
    structure : pyms.structure_routines.structure
        The structure of interest
    eV : float
        Probe energy in electron volts
    app : float
        Probe-forming aperture in mrad
    thickness : float
        The maximum thickness of the simulation object in Angstrom
    subslices : array_like, optional
        A one dimensional array-like object containing the depths (in
        fractional coordinates) at which the object will be subsliced. The last
        entry should always be 1.0. For example, to slice the object into four
        equal sized slices pass [0.25,0.5,0.75,1.0]
    tiling : (2,) array_like, optional
        Tiling of a repeat unit cell on simulation grid
    showcrossection : bool
        Pass True to plot the projected cross section of the probe to inspect
        the spread.
    df : float
        Probe defocus in Angstrom
    probe_posn : array_like, optional
        Probe position as a fraction of the unit-cell
    show : bool, optional
        If True (default) call plt.show(block=True) before returning
    device : torch.device, optional
        Device passed to the precursor calculation and to multislice
    P : (n,Y,X) array_like, optional
        Precomputed Fresnel free-space propagators
    T : (n,Y,X) array_like
        Precomputed transmission functions
    nslices : int, optional
        Number of unit-cell slices to propagate through, by default
        ceil(thickness / structure.unitcell[2])

    Returns
    -------
    fig : matplotlib.figure object
        The figure on which the probe spread is plotted
    """
    # Calculate multislice propagator and transmission functions
    from .Premixed_routines import multislice_precursor
    from .utils import fourier_shift, fourier_interpolate_2d

    if P is None or T is None:
        P, T = multislice_precursor(
            structure,
            gridshape,
            eV,
            subslices=subslices,
            tiling=tiling,
            device=device,
            nT=1,
            showProgress=False,
        )

    # Calculate focused STEM probe and move it to the requested position
    probe = focused_probe(
        gridshape, structure.unitcell[:2] * np.asarray(tiling), eV, app, df=df
    )
    pos = np.asarray(probe_posn) / np.asarray(tiling)
    probe = fourier_shift(probe, pos, pixel_units=False)

    ncols = 1 + showcrossection
    fig, ax = plt.subplots(ncols=ncols, figsize=(ncols * 4, 4), squeeze=False)

    # Total number of slices (not including subslicing of structure)
    if nslices is None:
        nslices = int(np.ceil(thickness / structure.unitcell[2]))

    # Total number of slices (including subslicing of structure)
    maxslices = nslices * len(subslices)

    variances = np.zeros(maxslices)
    intensity = np.zeros(maxslices)
    crossection = np.zeros((maxslices, gridshape[0]))

    # Array must be shifted to center probe position
    # (builtin int: the np.int alias was removed in NumPy 1.24)
    shift = (pos * np.asarray(gridshape)).astype(int)

    for i in range(maxslices):
        # NOTE(review): every call passes the explicit slice-index list [1];
        # confirm against multislice's subslicing logic that this steps through
        # successive subslices rather than repeating the same one.
        probe = multislice(
            probe,
            [1],
            P,
            T,
            tiling=tiling,
            subslicing=True,
            output_to_bandwidth_limit=False,
            device_type=device,
        )
        mod = np.roll(np.abs(probe) ** 2, shift=-shift, axis=(-2, -1))

        # Record probe intensity and spread
        intensity[i] = np.sum(mod)
        variances[i] = second_moment(mod)
        if showcrossection:
            crossection[i] = np.sum(mod, axis=-1)

    # Depth (Angstrom) reached after each subslice
    thicknesses = structure.unitcell[2] * (
        np.broadcast_to(np.arange(nslices)[:, None], (nslices, len(subslices))).ravel()
        + np.tile(subslices, nslices)
    )

    ax[0, 0].set_xlim([0, thicknesses[-1]])
    ax[0, 0].set_ylim([0, 1.1])
    ax[0, 0].set_ylabel("$\\sqrt{\\int \\Psi^2 dx}$", color="red")
    ax[0, 0].set_xlabel(r"Depth of propagation ($\AA$)")
    ax[0, 0].tick_params(axis="y", labelcolor="red")
    ax[0, 0].set_title("Probe intensity and spread")

    # Second y-axis sharing the depth axis, for the probe spread
    ax2 = ax[0, 0].twinx()
    ax2.plot(thicknesses, variances, "b-")
    ax2.tick_params(axis="y", labelcolor="b")
    ax2.plot([0, thickness], [0.2, 0.2], "b--")
    ax2.set_ylim([0, 0.5])
    ax2.set_ylabel("$\\sqrt{\\int \\Psi^2 x^2 dx}$", color="blue")

    ax[0, 0].plot(thicknesses, intensity, "r-")
    ax[0, 0].plot([0, thicknesses[-1]], [0.95, 0.95], "r--")

    if showcrossection:
        ny = crossection.shape[1]
        ax[0, 1].imshow(
            fourier_interpolate_2d(
                np.fft.fftshift(np.sqrt(crossection), axes=1), [ny, ny]
            ),
            extent=[0, gridshape[0], thickness, 0],
            cmap=plt.get_cmap("gnuplot"),
        )
        ax[0, 1].set_ylabel(r"Depth of propagation ($\AA$)")
        ax[0, 1].set_title("Probe depth cross-section")

    fig.tight_layout()
    if show:
        plt.show(block=True)
    return fig
def generate_slice_indices(nslices, nsubslices, subslicing=False)
-
Generate the slice indices for the multislice routine.
Expand source code
def generate_slice_indices(nslices, nsubslices, subslicing=False):
    """
    Generate the slice indices for the multislice routine.

    Parameters
    ----------
    nslices : int or sequence or np.ndarray
        Either an explicit collection of slice indices (returned unchanged) or
        a count of slices to iterate over.
    nsubslices : int
        Number of subslices per slice, used to expand a count when subslicing
        is False.
    subslicing : bool, optional
        If True a count is already expressed in subslices.

    Returns
    -------
    indices : sequence or np.ndarray
        The indices to iterate through.
    """
    from collections.abc import Sequence

    # An explicit collection of indices is used verbatim
    if isinstance(nslices, (Sequence, np.ndarray)):
        return nslices

    # A scalar count: expand whole slices into their subslices unless the
    # count is already given in subslices
    total = nslices if subslicing else nslices * nsubslices
    return np.arange(total)
def make_detector(gridshape, rsize, eV, betamax, betamin=0, units='mrad')
-
Make a STEM detector with acceptance angle between betamin and betamax.
Parameters
gridshape
:(2,) array_like
- Pixel dimensions of the 2D grid
rsize
:(2,) array_like
- Size of the grid in real space in units of Angstroms
eV
:float
- Probe energy in electron volts
betamax
:float
- Detector outer acceptance semi-angle
betamin
:float
, optional- Detector inner acceptance semi-angle
units
:str
, optional- Units of betamin and betamax, mrad or invA are both acceptable
Returns
D
:(Y,X) np.ndarray
- The detector function: 1 inside the acceptance range betamin <= |q| < betamax and 0 elsewhere
Expand source code
def make_detector(gridshape, rsize, eV, betamax, betamin=0, units="mrad"):
    """
    Make a STEM detector with acceptance angle between betamin and betamax.

    Parameters
    ----------
    gridshape : (2,) array_like
        Pixel dimensions of the 2D grid
    rsize : (2,) array_like
        Size of the grid in real space in units of Angstroms
    eV : float
        Probe energy in electron volts
    betamax : float
        Detector outer acceptance semi-angle
    betamin : float, optional
        Detector inner acceptance semi-angle
    units : str, optional
        Units of betamin and betamax, "mrad" (default) or "invA". Any value
        other than "mrad" is treated as inverse Angstrom.

    Returns
    -------
    D : (Y,X) np.ndarray
        Integer detector function: 1 where betamin <= |q| < betamax and 0
        elsewhere
    """
    # Reciprocal space coordinates in inverse Angstrom
    q = q_space_array(gridshape, rsize)

    if units == "mrad":
        # Dividing by wavev(eV) / 1000 expresses the coordinates in mrad
        q /= wavev(eV) / 1000

    # Squared modulus of the reciprocal space coordinate
    qsq = np.square(q[0]) + np.square(q[1])

    # Annulus with inclusive inner and exclusive outer edge
    accepted = (qsq >= betamin ** 2) & (qsq < betamax ** 2)
    return np.where(accepted, 1, 0)
def make_propagators(gridshape, gridsize, eV, subslices=[1.0], tilt=[0, 0], tilt_units='mrad', bandwidth_limit=0.6666666666666666)
-
Make the Fresnel freespace propagators for a multislice simulation.
Parameters
gridshape
:(2,) array_like
- Pixel dimensions of the 2D grid
gridsize
:(3,) array_like
- Size of the grid in real space (first two dimensions) and thickness of the object (third dimension)
eV
:float
- Probe energy in electron volts
subslices
:array_like
, optional- A one dimensional array-like object containing the depths (in fractional coordinates) at which the object will be subsliced. The last entry should always be 1.0. For example, to slice the object into four equal sized slices pass [0.25,0.5,0.75,1.0]
tilt
:array_like
, optional- Allows the user to simulate a (small < 50 mrad) tilt of the specimen, by shearing the propagator. Units given by input variable tilt_units.
tilt_units
:string
, optional- Units of specimen tilt, can be 'mrad','pixels' or 'invA'
Returns
P
:(n,Y,X)
- Fresnel free-space propagators, the first dimension will be of size len(subslices)
Expand source code
def make_propagators(
    gridshape,
    gridsize,
    eV,
    subslices=[1.0],
    tilt=[0, 0],
    tilt_units="mrad",
    bandwidth_limit=2 / 3,
):
    """
    Make the Fresnel freespace propagators for a multislice simulation.

    Parameters
    ----------
    gridshape : (2,) array_like
        Pixel dimensions of the 2D grid
    gridsize : (3,) array_like
        Size of the grid in real space (first two dimensions) and thickness of
        the object (third dimension)
    eV : float
        Probe energy in electron volts
    subslices : array_like, optional
        A one dimensional array-like object containing the depths (in
        fractional coordinates) at which the object will be subsliced. The last
        entry should always be 1.0. For example, to slice the object into four
        equal sized slices pass [0.25,0.5,0.75,1.0]
    tilt : array_like, optional
        Allows the user to simulate a (small < 50 mrad) tilt of the specimen,
        by shearing the propagator. Units given by input variable tilt_units.
    tilt_units : string, optional
        Units of specimen tilt, can be 'mrad','pixels' or 'invA'
    bandwidth_limit : float, optional
        Passed as `limit` to bandwidth_limit_array to band-width limit each
        propagator, by default 2/3

    Returns
    -------
    P : (n,Y,X) np.ndarray
        Complex Fresnel free-space propagators, the first dimension will be of
        size len(`subslices`)
    """
    from .Probe import make_contrast_transfer_function

    # We will use the make_contrast_transfer_function function to generate
    # the propagator, the aperture of this propagator will go out to the
    # maximum possible and the function bandwidth_limit_array will provide the
    # band width limiting
    gridmax = np.asarray(gridshape) / np.asarray(gridsize[:2]) / 2
    app = np.hypot(*gridmax)

    # Initialize array (builtin complex: the np.complex alias was removed in
    # NumPy 1.24)
    prop = np.zeros((len(subslices), *gridshape), dtype=complex)

    for islice, s_ in enumerate(subslices):
        # Thickness of this subslice: distance from the previous subslice depth
        if islice == 0:
            deltaz = s_ * gridsize[2]
        else:
            deltaz = (s_ - subslices[islice - 1]) * gridsize[2]

        # Calculate propagator
        prop[islice, :, :] = bandwidth_limit_array(
            make_contrast_transfer_function(
                gridshape,
                gridsize[:2],
                eV,
                app,
                df=deltaz,
                app_units="invA",
                optic_axis=tilt,
                tilt_units=tilt_units,
            ),
            limit=bandwidth_limit,
        )

    return prop
def max_grid_resolution(gridshape, rsize, bandwidthlimit=0.6666666666666666, eV=None)
-
For a given pixel sampling, return maximum multislice grid resolution.
For a given grid pixel size (gridshape) and real space size (rsize) return maximum resolution permitted by the multislice grid. If the probe accelerating voltage is passed in as eV resolution will be given in units of mrad, otherwise resolution will be given in units of inverse Angstrom.
Expand source code
def max_grid_resolution(gridshape, rsize, bandwidthlimit=2 / 3, eV=None):
    """
    For a given pixel sampling, return maximum multislice grid resolution.

    For a given grid pixel size (gridshape) and real space size (rsize) return
    maximum resolution permitted by the multislice grid. If the probe
    accelerating voltage is passed in as eV resolution will be given in units
    of mrad, otherwise resolution will be given in units of inverse Angstrom.
    """
    # Nyquist frequency N / (2 L) of each in-plane axis, scaled by the
    # bandwidth limit; the coarser axis sets the overall limit
    per_axis = [gridshape[ax] / rsize[ax] / 2 * bandwidthlimit for ax in (0, 1)]
    max_res = min(per_axis)

    if eV is not None:
        # Express the limit in mrad instead of inverse Angstrom
        max_res = max_res / wavev(eV) * 1e3
    return max_res
def multislice(probes, nslices, propagators, transmission_functions, tiling=[1, 1], device_type=None, seed=None, return_numpy=True, qspace_in=False, qspace_out=False, posn=None, subslicing=False, output_to_bandwidth_limit=True, reverse=False, transpose=False)
-
Multislice algorithm for scattering of an electron probe.
Parameters
probes
:(n,Y,X)
or (Y,X) complex array_like
- Electron probe wave function(s)
nslices
:int, array_like
- The number of slices (iterations) to perform multislice over, if an
propagators
:(Z,Y,X,2)
or (Y,X,2) torch.array
- Fresnel free space operators required for the multislice algorithm used to propagate the scattering matrix
transmission_functions
:(Z,nT,Y,X,2)
- The transmission functions describing the electron's interaction with the specimen for the multislice algorithm
tiling
:(2,) array_like
- Tiling of a repeat unit cell on simulation grid.
device_type
:torch.device
, optional- torch.device object which will determine which device (CPU or GPU) the calculations will run on
seed
:int
, optional- Seed for the random number generator for frozen phonon configurations
return_numpy
:bool
, optional- Calculations are performed on pytorch tensors for speed, however numpy arrays are more convenient for processing. This input allows the user to control how the output is returned
qspace_in
:bool
, optional- Should be set to True if the input wavefunction is in momentum (q) space and False otherwise (this is the default)
qspace_out
:bool
, optional- Should be set to True if the output wavefunction is desired in momentum (q) space and False otherwise (this is the default)
posn
:None
, optional- Does nothing, included to match calling signature for STEM function
subslicing
:bool
, optional- Pass subslicing=True to access propagation to sub-slices of the unit cell, in this case nslices is taken to be in units of subslices to propagate rather than unit cells (i.e. nslices = 3 will propagate 1.5 unit cells for a subslicing of 2 subslices per unit cell)
output_to_bandwidth_limit
:bool
, optional- Bandwidth-limiting of the arrays is used in multislice to stop wrap-around error in reciprocal space, therefore the output of the multislice algorithm will be zero beyond some point in reciprocal space. If this is set to True then these array entries will be cropped out. This does have the effect of the output of the function being on a different sized grid to the input.
reverse
:bool
, optional- Run inverse multislice (for back propagation of a wavefunction)
transpose
:bool
, optional- Reverse the order of the multislice operations, ie. apply propagator first and then transmission function
Returns
psi
:(Y,X)
or (n,Y,X) complex torch.tensor
or np.ndarray
- Exit surface wave functions as a pytorch tensor or numpy array (default)
depending on whether return_numpy is True or False. If the input
probes
is two dimensional then n = 1
Expand source code
def multislice(
    probes,
    nslices,
    propagators,
    transmission_functions,
    tiling=[1, 1],
    device_type=None,
    seed=None,
    return_numpy=True,
    qspace_in=False,
    qspace_out=False,
    posn=None,
    subslicing=False,
    output_to_bandwidth_limit=True,
    reverse=False,
    transpose=False,
):
    """
    Multislice algorithm for scattering of an electron probe.

    Parameters
    ----------
    probes : (n,Y,X) or (Y,X) complex array_like
        Electron probe wave function(s)
    nslices : int, array_like
        The number of slices (iterations) to perform multislice over. The
        value is passed to generate_slice_indices; an array_like is presumably
        interpreted as an explicit sequence of slice indices (this is how
        scattering_matrix.Propagate calls this function).
    propagators : (Z,Y,X,2) or (Y,X,2) torch.array
        Fresnel free space operators required for the multislice algorithm
        used to propagate the scattering matrix. A single (Y,X,2) propagator
        is used for every slice.
    transmission_functions : (nT,Z,Y,X,2)
        The transmission functions describing the electron's interaction with
        the specimen for the multislice algorithm. The first dimension indexes
        frozen phonon configurations, the second the (sub)slices.
    tiling : (2,) array_like
        Tiling of a repeat unit cell on simulation grid. Where the tiling is
        greater than one along an axis the transmission function is shifted by
        a random whole number of unit cells along that axis for each slice
        (using a Fourier shift if the grid is not an integer multiple of the
        tiling).
    device_type : torch.device, optional
        torch.device object which will determine which device (CPU or GPU) the
        calculations will run on
    seed : int or array_like, optional
        Seed for the random number generator for frozen phonon configurations.
        If None, a generator seeded from system entropy is used. If an
        integer, a single generator seeded with that integer is used for all
        slices. If array_like, seed[islice] reseeds the generator for each
        slice so that every slice is reproducible independently of the others.
    return_numpy : bool, optional
        Calculations are performed on pytorch tensors for speed, however numpy
        arrays are more convenient for processing. This input allows the user
        to control how the output is returned
    qspace_in : bool, optional
        Should be set to True if the input wavefunction is in momentum (q)
        space and False otherwise (this is the default)
    qspace_out : bool, optional
        Should be set to True if the output wavefunction is desired in
        momentum (q) space and False otherwise (this is the default)
    posn : None, optional
        Does nothing, included to match calling signature for STEM function
    subslicing : bool, optional
        Pass subslicing=True to access propagation to sub-slices of the unit
        cell, in this case nslices is taken to be in units of subslices to
        propagate rather than unit cells (i.e. nslices = 3 will propagate 1.5
        unit cells for a subslicing of 2 subslices per unit cell)
    output_to_bandwidth_limit : bool, optional
        Bandwidth-limiting of the arrays is used in multislice to stop
        wrap-around error in reciprocal space, so the output of the multislice
        algorithm will be zero beyond some point in reciprocal space. If this
        is set to True these array entries will be cropped out, which means
        the output will be on a different sized grid to the input.
    reverse : bool, optional
        Run inverse multislice (for back propagation of a wavefunction)
    transpose : bool, optional
        Reverse the order of the multislice operations, i.e. apply propagator
        first and then transmission function

    Returns
    -------
    psi : (Y,X) or (n,Y,X) complex torch.tensor or np.ndarray
        Exit surface wave functions as a pytorch tensor or numpy array
        (default) depending on whether return_numpy is True or False. If the
        input `probes` is two dimensional then n = 1
    """
    # Random number generation for frozen phonon configurations and random
    # unit-cell shifts:
    #   None    -> one generator seeded from system entropy
    #   integer -> one generator seeded with that integer
    #   array   -> seed[islice] reseeds the generator for every slice, which
    #              makes each slice reproducible (needed e.g. so that every
    #              scattering matrix column sees the same configuration)
    # The global numpy RNG is deliberately left untouched.
    per_slice_seed = seed is not None and np.ndim(seed) > 0
    if not per_slice_seed:
        r = np.random.RandomState(seed)

    # Initialize device cuda if available, CPU if no cuda is available
    device = get_device(device_type)

    # Since pytorch doesn't have a complex data type we need to add an extra
    # dimension of size 2 to each tensor that will store real and imaginary
    # components.
    T = ensure_torch_array(transmission_functions, device=device)
    P = ensure_torch_array(propagators, dtype=T.dtype, device=device)
    psi = ensure_torch_array(probes, dtype=T.dtype, device=device)
    nT, nsubslices, nopiy, nopix = T.shape[:4]

    # Probe needs to start multislice algorithm in real space
    if qspace_in:
        psi = torch.ifft(psi, signal_ndim=2)

    slices = generate_slice_indices(nslices, nsubslices, subslicing)

    for i, islice in enumerate(slices):
        if per_slice_seed:
            r = np.random.RandomState(seed[islice])

        subslice = islice % nsubslices

        # Pick random phase grating (frozen phonon configuration)
        it = r.randint(0, nT)

        # To save memory in the case of an equally sliced sample there is the
        # option of passing only one propagator, this statement catches it.
        P_ = P if P.dim() < 4 else P[subslice]

        # If the transmission function is from a tiled unit cell then it is
        # randomly shifted by whole unit cells to generate "more"
        # pseudo-random transmission functions. Note `and` (not `&`): with
        # `&` the comparison chained as tiling[0] == (1 & tiling[1]) == 1.
        if tiling[0] == 1 and tiling[1] == 1:
            T_ = T[it, subslice]
        elif nopiy % tiling[0] == 0 and nopix % tiling[1] == 0:
            T_ = T[it, subslice]
            if tiling[0] > 1:
                # Shift an integer number of pixels in y
                yshift = r.randint(0, tiling[0]) * (nopiy // tiling[0])
                T_ = torch.roll(T_, yshift, 0)
            if tiling[1] > 1:
                # Shift an integer number of pixels in x; the draw starts at 0
                # like the y shift so the unshifted cell is also possible.
                xshift = r.randint(0, tiling[1]) * (nopix // tiling[1])
                T_ = torch.roll(T_, xshift, 1)
        else:
            # Case of a non-integer pixel shifting of the unit cell
            yshift = r.randint(0, tiling[0]) * (nopiy / tiling[0])
            xshift = r.randint(0, tiling[1]) * (nopix / tiling[1])
            shift = torch.tensor([yshift, xshift])
            # Generate an array to perform Fourier shift of transmission
            # function
            FFT_shift_array = fourier_shift_array(
                [nopiy, nopix], shift, dtype=T.dtype, device=T.device
            )
            # Apply Fourier shift theorem for sub-pixel shift
            T_ = torch.ifft(
                complex_mul(FFT_shift_array, torch.fft(T[it, subslice], signal_ndim=2)),
                signal_ndim=2,
            )

        # Perform multislice iteration
        if transpose or reverse:
            # Reverse multislice complex conjugates the transmission and
            # propagation (complex_mul is passed `reverse` as its third
            # argument for this). Both reverse and transpose multislice
            # reverse the order of the transmission and propagation
            # operations. The probe starts and finishes this iteration in
            # real space.
            psi = complex_mul(
                torch.ifft(
                    complex_mul(torch.fft(psi, signal_ndim=2), P_, reverse),
                    signal_ndim=2,
                ),
                T_,
                reverse,
            )
        else:
            # Standard multislice iteration - probe starts in real space and
            # finishes this iteration in reciprocal space
            psi = complex_mul(torch.fft(complex_mul(psi, T_), signal_ndim=2), P_)

        if i == len(slices) - 1:
            # The probe can be cropped to the bandwidth limit, this removes
            # superfluous array entries in reciprocal space that are zero.
            # "conserve_norm" compensates for the change in normalization of
            # the subsequent inverse FFT on a smaller grid.
            lim = 2 / 3 if output_to_bandwidth_limit else 1
            psi = crop_to_bandwidth_limit_torch(
                psi,
                qspace_in=not (transpose or reverse),
                qspace_out=qspace_out,
                limit=lim,
                norm="conserve_norm",
            )
        elif not (transpose or reverse):
            # Inverse Fourier transform back to real space for next iteration
            psi = torch.ifft(psi, signal_ndim=2)

    # With no slices to propagate psi is still in real space at this point
    if len(slices) < 1 and qspace_out:
        psi = torch.fft(psi, signal_ndim=2)

    if return_numpy:
        return cx_to_numpy(psi)
    return psi
def nyquist_sampling(rsize=None, resolution_limit=None, eV=None, alpha=None)
-
Calculate nyquist sampling (typically for minimum sampling of a STEM probe).
If the array size in units of length is passed, return how many probe positions are required; otherwise just return the sampling. Alternatively, pass the probe accelerating voltage (eV) in electron-volts and the probe-forming aperture (alpha) in mrad, and the resolution limit in inverse length will be calculated for you.
Expand source code
def nyquist_sampling(rsize=None, resolution_limit=None, eV=None, alpha=None):
    """
    Calculate nyquist sampling (typically for minimum sampling of a STEM probe).

    If the array size in units of length is passed, return how many probe
    positions are required; otherwise just return the sampling. Alternatively,
    pass the probe accelerating voltage (eV) in electron-volts and the
    probe-forming aperture (alpha) in mrad, and the resolution limit in inverse
    length will be calculated for you.

    Parameters
    ----------
    rsize : float or array_like, optional
        Size of the region to be sampled, in units of length.
    resolution_limit : float, optional
        Resolution limit in inverse length. Used when eV and alpha are both
        None.
    eV : float, optional
        Probe accelerating voltage in electron-volts (requires alpha).
    alpha : float, optional
        Probe-forming aperture semi-angle in mrad (requires eV).

    Returns
    -------
    step_size : float
        Sampling step in units of length, returned if rsize is None.
    npositions : int or ndarray of int
        ceil(rsize / step_size), returned if rsize is given.
    None
        Returned if both resolution_limit and (eV, alpha) are supplied.
    """
    if eV is None and alpha is None:
        step_size = 1 / (4 * resolution_limit)
    elif resolution_limit is None:
        # Resolution limit in inverse length from wavenumber and aperture
        step_size = 1 / (4 * wavev(eV) * alpha * 1e-3)
    else:
        return None

    if rsize is None:
        return step_size

    # np.asarray lets rsize be a list/tuple; the builtin int replaces np.int,
    # which was only an alias for it and was removed in NumPy 1.24.
    return np.ceil(np.asarray(rsize) / step_size).astype(int)
def phase_from_com(com, reg=1e-10, rsize=[1, 1])
-
Integrate 4D-STEM centre of mass (DPC) measurements to calculate object phase.
Assumes a three dimensional array com, with the final two dimensions corresponding to the image and the first dimension of the array corresponding to the y and x centre of mass respectively.
Expand source code
def phase_from_com(com, reg=1e-10, rsize=[1, 1]):
    """
    Integrate 4D-STEM centre of mass (DPC) measurements to calculate object phase.

    Parameters
    ----------
    com : (2, Y, X) array_like
        Centre of mass images; com[0] is the y component and com[1] the x
        component.
    reg : float, optional
        Small constant added to the Fourier-space denominator to regularize
        the inversion.
    rsize : (2,) array_like, optional
        Real space size of the image, used to build the Fourier coordinates.

    Returns
    -------
    phase : (Y, X) ndarray
        Real-valued phase reconstruction; the zero-frequency (mean) term is
        set to zero.
    """
    shape = tuple(com.shape[1:])
    ny, nx = shape
    pixel = np.asarray(rsize) / np.asarray([ny, nx])

    # Fourier coordinates; the x axis uses the half-spectrum of rfft2
    ky = np.fft.fftfreq(ny, d=pixel[0])[:, None]
    kx = np.fft.rfftfreq(nx, d=pixel[1])[None, :]

    com_y_ft = np.fft.rfft2(com[0], s=shape)
    com_x_ft = np.fft.rfft2(com[1], s=shape)

    # Solve k . F[com] = i |k|^2 F[phase] for F[phase]
    numerator = ky * com_y_ft + kx * com_x_ft
    denominator = 1j * (ky ** 2 + kx ** 2) + reg

    # The origin of Fourier space would otherwise divide by (nearly) zero
    numerator[0, 0] = 0
    denominator[0, 0] = 1

    return np.fft.irfft2(numerator / denominator, s=shape)
def second_moment(array)
-
Calculate the second moment of 2D array as a fraction of array size.
Expand source code
def second_moment(array):
    """
    Calculate the second moment of 2D array as a fraction of array size.

    Coordinates along each axis are fftfreq values (fractions of the array
    size), and distances from the centre of mass are wrapped periodically
    into [-0.5, 0.5). The returned value is the square root of the
    array-weighted mean squared distance from the centre of mass.
    """
    yfreq = np.fft.fftfreq(array.shape[0])
    xfreq = np.fft.fftfreq(array.shape[1])
    total = np.sum(array)

    # Centre of mass along each axis
    ycentre = np.sum(yfreq[:, None] * array) / total
    xcentre = np.sum(xfreq[None, :] * array) / total

    def wrapped_square(coords, centre):
        # Periodic (minimum image) squared distance from the centre
        return ((coords - centre + 0.5) % 1.0 - 0.5) ** 2

    dist2 = wrapped_square(yfreq, ycentre)[:, None] + wrapped_square(xfreq, xcentre)[None, :]
    return np.sqrt(np.sum(dist2 * array) / total)
def thickness_to_slices(thicknesses, slice_thickness, subslicing=False, subslices=[1.0])
-
Convert thickness in Angstroms to number of multislice slices.
Expand source code
def thickness_to_slices(
    thicknesses, slice_thickness, subslicing=False, subslices=[1.0]
):
    """
    Convert thickness in Angstroms to number of multislice slices.

    Parameters
    ----------
    thicknesses : float or array_like
        Specimen thickness(es) in Angstrom at which output is desired.
    slice_thickness : float
        Thickness in Angstrom of one (unit cell) slice of the structure.
    subslicing : bool, optional
        If True, return the (sub)slice indices to propagate through between
        successive thicknesses instead of a number of slices.
    subslices : array_like, optional
        Presumably the fractional positions (in units of slice_thickness) at
        which each subslice ends, e.g. [0.5, 1.0] for two equal subslices.

    Returns
    -------
    nslices : ndarray of int
        If subslicing is False: ceil(thicknesses / slice_thickness).
    slice_ranges : list of ndarray
        If subslicing is True: for each thickness, the subslice indices from
        the previous thickness (starting at 0) up to this one. Assumes the
        thicknesses are increasing, otherwise the ranges are empty.
    """
    t = np.asarray(ensure_array(thicknesses))

    if not subslicing:
        # Builtin int replaces np.int (removed in NumPy 1.24)
        return np.ceil(t / slice_thickness).astype(int)

    from scipy.spatial.distance import cdist

    m = len(subslices)
    # Number of whole unit cells below each thickness, counted in subslices
    nslices = (t // slice_thickness).astype(int) * m

    # Fractional position of each thickness within its final unit cell
    remainder = (t % slice_thickness) / slice_thickness
    n = len(remainder)

    # Fractional start position of every subslice within the unit cell.
    # NOTE(review): 1.0 (start of the next unit cell) is not a candidate, so a
    # remainder close to 1 rounds down to the last subslice start - confirm
    # this is intended.
    starts = np.concatenate(([0], subslices[:-1]))
    dist = cdist(remainder.reshape((n, 1)), starts.reshape((m, 1)))

    # Cumulative subslice count reached at each thickness
    z = [0] + (nslices + np.argmin(dist, axis=1)).tolist()
    return [np.arange(z[i], z[i + 1]) for i in range(len(z) - 1)]
def tqdm_handler(showProgress)
-
Handle showProgress boolean or string input for the tqdm progress bar.
Expand source code
def tqdm_handler(showProgress):
    """
    Handle showProgress boolean or string input for the tqdm progress bar.

    Parameters
    ----------
    showProgress : str or bool
        "notebook" (case insensitive) selects the Jupyter notebook progress
        bar. Any other value is interpreted by its truthiness: True (or any
        other non-empty string) enables the standard progress bar, while False
        (or None) disables it.

    Returns
    -------
    tdisable : bool
        Value to pass as the ``disable`` argument of the progress bar.
    tqdm : class
        The tqdm progress bar class to use.
    """
    if isinstance(showProgress, str) and showProgress.lower() == "notebook":
        try:
            # Current location of the notebook progress bar
            from tqdm.notebook import tqdm
        except ImportError:
            # Older tqdm releases only provide the (now deprecated) alias
            from tqdm import tqdm_notebook as tqdm
        return False, tqdm

    from tqdm import tqdm

    # Handle every other input by truthiness, so e.g. None or an unrecognized
    # string no longer leaves tdisable/tqdm unbound (UnboundLocalError).
    return not showProgress, tqdm
def unit_cell_shift(array, axis, shift, tiles)
-
Shift an array an integer number of unit cells.
For an array consisting of a number of repeat units given by tiles, shift that array an integer number of unit cells.
Expand source code
def unit_cell_shift(array, axis, shift, tiles): """ Shift an array an integer number of unit cell. For an array consisting of a number of repeat units given by tiles shift than array an integer number of unit cells. """ indices = torch.remainder(torch.arange(array.shape[-3 + axis]) - shift) if axis == 0: return array[indices, :, :] if axis == 1: return array[:, indices, :]
def workout_4DSTEM_datacube_DP_size(FourD_STEM, rsize, gridshape)
-
Calculate 4D-STEM datacube diffraction pattern gridsize and resampling function.
Parameters
FourD_STEM
:bool
or array_like
- Passing FourD_STEM = True gives 4D STEM output with native simulation grid sampling. Alternatively, to save disk space, a tuple containing the pixel size and diffraction space extent of the datacube can be passed in. For example ([64,64],[1.2,1.2]) will output diffraction patterns measuring 64 x 64 pixels and 1.2 x 1.2 inverse Angstroms.
rsize
:(2,) array_like
- Real space size of simulation grid
gridshape
:(2,) array_like
- Pixel dimensions of the simulation grid
Returns
gridout
:(2,) array_like
- Pixel size of the diffraction pattern output
resize
:function
- A function that takes diffraction patterns from the simulation and resamples and crops them to the requested size.
Ksize
:(2,) array_like
- Size of the output diffraction pattern in inverse Angstrom
Expand source code
def workout_4DSTEM_datacube_DP_size(FourD_STEM, rsize, gridshape):
    """
    Calculate 4D-STEM datacube diffraction pattern gridsize and resampling function.

    Parameters
    ----------
    FourD_STEM : bool or array_like
        Passing FourD_STEM = True gives 4D STEM output with native simulation
        grid sampling. Alternatively, to save disk space, a tuple containing
        the pixel size and diffraction space extent of the datacube can be
        passed in. For example ([64,64],[1.2,1.2]) will output diffraction
        patterns measuring 64 x 64 pixels and 1.2 x 1.2 inverse Angstroms. A
        tuple with only a pixel size, e.g. ([64,64],), crops the diffraction
        patterns to that many pixels without resampling.
    rsize : (2,) or (3,) array_like
        Real space size of simulation grid; only the first two entries are
        used.
    gridshape : (2,) array_like
        Pixel dimensions of the simulation grid

    Returns
    -------
    gridout : (2,) array_like
        Pixel size of the diffraction pattern output
    resize : function
        A function that takes diffraction patterns from the simulation and
        resamples and crops them to the requested size.
    Ksize : (2,) array_like
        Size of the output diffraction pattern in inverse Angstrom
    """
    # Only the transverse (y, x) real space dimensions are relevant here;
    # rsize may also carry a third (thickness) entry.
    rsize_ = np.asarray(rsize[:2])

    # Check whether a resampling directive has been given
    if isinstance(FourD_STEM, (list, tuple)):
        gridout = FourD_STEM[0]

        if len(FourD_STEM) > 1:
            # Get output grid and diffraction space size of that grid from tuple
            Ksize = FourD_STEM[1]
            # Number of simulation pixels spanning the requested extent
            # (builtin int replaces np.int, removed in NumPy 1.24)
            diff_pat_crop = np.round(np.asarray(Ksize) * rsize_).astype(int)

            # Define resampling function to crop and interpolate
            # diffraction patterns
            def resize(array):
                cropped = crop(np.fft.fftshift(array, axes=(-1, -2)), diff_pat_crop)
                return fourier_interpolate_2d(cropped, gridout, norm="conserve_L1")

        else:
            # The size in inverse Angstrom of the grid
            Ksize = np.asarray(gridout) / rsize_

            # Define resampling function to just crop diffraction patterns
            def resize(array):
                return crop(np.fft.fftshift(array, axes=(-1, -2)), gridout)

    else:
        # If no resampling then the output size is just the bandwidth
        # limited simulation grid size
        gridout = size_of_bandwidth_limited_array(gridshape)
        # The size in inverse Angstrom of the grid
        Ksize = np.asarray(gridout) / rsize_

        # Define a resampling function that only crops to the bandwidth limit
        def resize(array):
            return crop(np.fft.fftshift(array, axes=(-1, -2)), gridout)

    return gridout, resize, Ksize
Classes
class scattering_matrix (rsize, propagators, transmission_functions, nslice, eV, alpha, GPU_streaming=False, batch_size=30, device=None, PRISM_factor=[1, 1], tiling=[1, 1], device_type=None, seed=None, showProgress=True, bandwidth_limit=0.6666666666666666, Fourier_space_output=False, subslicing=False, transposed=False, stored_gridshape=None)
-
Scattering matrix object for calculations using the PRISM algorithm.
Initialize with a set of propagators and transmission functions.
Parameters
rsize
:(2,) array_like
- Real space size of the simulation grid in Angstrom
propagators
:(N,Y,X,2) torch.array
- Fresnel free space operators required for the multislice algorithm used to propagate the scattering matrix
transmission_functions
:(N,Y,X,2)
- The transmission functions describing the electron's interaction with the specimen for the multislice algorithm
nslice
:int
- The number of slices of the specimen to propagate the scattering matrix to
eV
:float
- Electron probe energy in electron-volts
alpha
:float
- Maximum input angle for the scattering matrix, should match the probe forming aperture used in experiment
GPU_streaming
:bool
, optional- If True, the scattering matrix will be stored off GPU RAM and streamed to GPU RAM as necessary, does nothing if the calculation is CPU only
batch_size
:int
, optional- The multislice algorithm can be performed on multiple scattering matrix columns at once to parallelize computation, this number is set by batch_size.
device
:torch.device
, optional- torch.device object which will determine which device (CPU or GPU) the calculations will run on. By default this will be determined by what device the transmission functions are stored on.
PRISM_factor
:int (2,) array_like
- The PRISM "interpolation factor": the amount by which the scattering matrices are cropped in real space to speed up calculations. See Ophus, Colin. "A fast image simulation algorithm for scanning transmission electron microscopy." Advanced structural and chemical imaging 3.1 (2017): 13 for details on this.
seed
:int32
, optional- A seed to control seeding of the frozen phonon approximation
showProgress
:str
or bool
, optional- Pass False to disable progress readout, pass 'notebook' to get correct progress bar behaviour inside a jupyter notebook
bandwidth_limit
:float
, optional- Band-width limiting of the transmission function and propagators to prevent wrap-around error in the multislice algorithm, 2/3 by default
Fourier_space_output
:bool
, optional- If True the scattering matrix output will be stored in reciprocal space, default is False
subslicing
:bool
, optional- Pass subslicing=True to access propagation to sub-slices of the unit cell, in this case nslices is taken to be in units of subslices to propagate rather than unit cells (i.e. nslices = 3 will propagate 1.5 unit cells for a subslicing of 2 subslices per unit cell)
transposed
:bool
, optional- Make a "transposed" scattering matrix - see Brown et al. (2019) Physical Review Research paper for a discussion of this and its applications
stored_gridshape
:(2,) array_like
- Size of the stored grid, can be chosen to be smaller than the multislice grid to speed up computation of a smaller diffraction space view than that implied by the multislice at no cost to computational accuracy.
Expand source code
class scattering_matrix: """Scattering matrix object for calculations using the PRISM algorithm.""" def __init__( self, rsize, propagators, transmission_functions, nslice, eV, alpha, GPU_streaming=False, batch_size=30, device=None, PRISM_factor=[1, 1], tiling=[1, 1], device_type=None, seed=None, showProgress=True, bandwidth_limit=2 / 3, Fourier_space_output=False, subslicing=False, transposed=False, stored_gridshape=None, ): """ Initialize with a set of propagators and transmission functions. Parameters ---------- rsize : (2,) array_like Real space size of the simulation grid in Angstrom propagators : (N,Y,X,2) torch.array Fresnel free space operators required for the multislice algorithm used to propagate the scattering matrix transmission_functions : (N,Y,X,2) The transmission functions describing the electron's interaction with the specimen for the multislice algorithm nslice : int The number of slices of the specimen to propagate the scattering matrix to eV : float Electron probe energy in electron-volts alpha : float Maximum input angle for the scattering matrix, should match the probe forming aperture used in experiment GPU_streaming : bool, optional If True, the scattering matrix will be stored off GPU RAM and streamed to GPU RAM as necessary, does nothing if the calculation is CPU only batch_size : int, optional The multislice algorithm can be performed on multiple scattering matrix columns at once to parrallelize computation, this number is set by batch_size. device : torch.device, optional torch.device object which will determine which device (CPU or GPU) the calculations will run on. By default this will be determined by what device the transmission functions are stored on. PRISM_factor : int (2,) array_like The PRISM "interpolation factor" this is the amount by which the scattering matrices are cropped in real space to speed up calculations see Ophus, Colin. "A fast image simulation algorithm for scanning transmission electron microscopy." 
Advanced structural and chemical imaging 3.1 (2017): 13 for details on this. seed : int32, optional A seed to control seeding of the frozen phonon approximation showProgress : str or bool, optional Pass False to disable progress readout, pass 'notebook' to get correct progress bar behaviour inside a jupyter notebook bandwidth_limit : float, optional Band-width limiting of the transmission function and propagators to prevent wrap-around error in the multislice algorithm, 2/3 by default Fourier_space_output : bool, optional If True the scattering matrix output will be stored in reciprocal space, default is False subslicing : bool, optional Pass subslicing=True to access propagation to sub-slices of the unit cell, in this case nslices is taken to be in units of subslices to propagate rather than unit cells (i.e. nslices = 3 will propagate 1.5 unit cells for a subslicing of 2 subslices per unit cell) transposed : bool, optional Make a "transposed" scattering matrix - see Brown et al. (2019) Physical Review Research paper for a discussion of this and its applications stored_gridshape : (2,) array_like Size of the stored grid, can be chosen to be smaller than the multislice grid to speed up computation of a smaller diffraction space view than that implied by the multislice at no cost to computational accuracy. 
""" # Get size of grid gridshape = transmission_functions.shape[-3:-1] # Datatype (precision) is inferred from transmission functions self.dtype = transmission_functions.dtype # Device (CPU or GPU) is also inferred from transmission functions self.device = device if GPU_streaming: self.device = torch.device("cpu") elif self.device is None: self.device = transmission_functions.device # Get alpha in units of inverse Angstrom self.alpha_ = wavev(eV) * alpha * 1e-3 self.PRISM_factor = PRISM_factor self.doPRISM = np.any(np.asarray(PRISM_factor) > 1) # Make a list of beams in the scattering matrix # Take beams inside the aperture and every nth beam where n is the # PRISM "interpolation" factor q = q_space_array(gridshape, rsize) inside_aperture = np.less_equal(q[0] ** 2 + q[1] ** 2, self.alpha_ ** 2) mody, modx = [ np.mod(np.fft.fftfreq(x, 1 / x).astype(np.int), p) == 0 for x, p in zip(gridshape, self.PRISM_factor) ] self.beams = np.nonzero( np.logical_and( np.logical_and(inside_aperture, mody[:, None]), modx[None, :] ) ) self.beams = [(x + y // 2) % y - y // 2 for x, y in zip(self.beams, gridshape)] self.nbeams = len(self.beams[0]) # For a scattering matrix stored in real space there is the option # of storing it on a much smaller pixel grid than the grid used for # multislice. This is handy when, for example, a large grid # is required for a converged multislice calculation but only # the bright-field region of diffraction (small angle region) # is of interest. Be careful using this in conjunction with # multiple calls of the propagation method for the scattering matrix, # as information outside the angular range of the stored grid is lost. 
self.crop_output = not (stored_gridshape is None) if self.crop_output: self.stored_gridshape = stored_gridshape # We will only store output of the scattering matrix up to the band # width limit of the calculation, since this is a circular band-width # limit on a square grid we have to get somewhat fancy and store a mapping # of the pixels within the bandwidth limit to a one-dimensional vector self.bw_mapping = np.argwhere( np.logical_and( ( np.abs(np.fft.fftfreq(gridshape[0], d=1 / gridshape[0])) < self.stored_gridshape[0] // 2 )[:, np.newaxis], ( np.abs(np.fft.fftfreq(gridshape[1], d=1 / gridshape[1])) < self.stored_gridshape[1] // 2 )[np.newaxis, :], ) ) else: self.stored_gridshape = size_of_bandwidth_limited_array( transmission_functions.shape[-3:-1] ) # We will only store output of the scattering matrix up to the band # width limit of the calculation, since this is a circular band-width # limit on a square grid we have to get somewhat fancy and store a mapping # of the pixels within the bandwidth limit to a one-dimensional vector self.bw_mapping = np.argwhere( (np.fft.fftfreq(gridshape[0]) ** 2)[:, np.newaxis] + (np.fft.fftfreq(gridshape[1]) ** 2)[np.newaxis, :] < (bandwidth_limit / 2) ** 2 ) self.nbout = self.bw_mapping.shape[0] self.gridshape, self.rsize, self.eV = [np.asarray(gridshape), rsize, eV] self.bw_mapping = ( self.bw_mapping + self.gridshape // 2 ) % self.gridshape - self.gridshape // 2 self.PRISM_factor, self.tiling = [PRISM_factor, tiling] self.doPRISM = np.any([self.PRISM_factor[i] > 1 for i in [0, 1]]) self.Fourier_space_output = Fourier_space_output self.nsubslices = transmission_functions.shape[1] slices = generate_slice_indices(nslice, self.nsubslices, subslicing=subslicing) self.GPU_streaming = GPU_streaming self.transposed = transposed self.seed = seed if self.seed is None: # If no seed passed to random number generator then make one to pass to # the multislice algorithm. 
This ensure that each column in the scattering # matrix sees the same frozen phonon configuration self.seed = np.random.randint( 0, 2 ** 31 - 1, size=len(slices), dtype=np.uint32 ) # This switch tells the propagate function to initialize the Smatrix # to plane waves self.initialized = False # Propagate wave functions of scattering matrix self.current_slice = 0 self.show_Progress = showProgress self.Propagate( nslice, propagators, transmission_functions, subslicing=subslicing, showProgress=self.show_Progress, batch_size=batch_size, ) def Propagate( self, nslice, propagators, transmission_functions, subslicing=False, batch_size=3, showProgress=True, transpose=False, ): """ Propagate a scattering matrix to slice nslice of the specimen. Parameters ---------- nslice : int The slice in the specimen to propagate the scattering matrix to propagators : (N,Y,X,2) torch.array Fresnel free space operators required for the multislice algorithm used to propagate the scattering matrix transmission_functions : (N,Y,X,2) The transmission functions describing the electron's interaction with the specimen for the multislice algorithm batch_size : int, optional The multislice algorithm can be performed on multiple scattering matrix columns at once to parrallelize computation, this number is set by batch_size. subslicing : bool, optional Pass subslicing=True to access propagation to sub-slices of the unit cell, in this case nslices is taken to be in units of subslices to propagate rather than unit cells (i.e. nslices = 3 will propagate 1.5 unit cells for a subslicing of 2 subslices per unit cell) showProgress : str or bool, optional Pass False to disable progress readout, pass 'notebook' to get correct progress bar behaviour inside a jupyter notebook transpose : bool, optional Make a "transposed" scattering matrix - see Brown et al. 
(2019) Physical Review Research paper for a discussion of this and its applications """ tdisable, tqdm = tqdm_handler(showProgress) from .Probe import plane_wave_illumination # Initialize scattering matrix if necessary if not self.initialized: if self.Fourier_space_output: self.S = torch.zeros( self.nbeams, self.nbout, 2, dtype=self.dtype, device=self.device ) else: self.S = torch.zeros( self.nbeams, *self.stored_gridshape, 2, dtype=self.dtype, device=self.device ) for ibeam in range(self.nbeams): # Initialize S-matrix to plane-waves psi = cx_from_numpy( plane_wave_illumination( self.gridshape, self.rsize[:2], self.eV, tilt=[self.beams[0][ibeam], self.beams[1][ibeam]], tilt_units="pixels", qspace=True, ) ) # Adjust intensity for correct normalization of S matrix rows # taking into account the PRISM factor that needs to be applied # when the Smatrix is evaluated (only 1/product(PRISM_factor) # beams are taken and only 1/product(PRISM_factor) intensity # is cropped out in real space) psi *= torch.prod(torch.tensor(self.PRISM_factor, dtype=self.dtype)) if self.Fourier_space_output: self.S[ibeam] = psi[self.bw_mapping[:, 0], self.bw_mapping[:, 1], :] else: self.S[ibeam] = fourier_interpolate_2d_torch( psi, self.stored_gridshape, qspace_in=True, qspace_out=False, norm="conserve_norm", ) self.initialized = True # Make nslice_ which always accounts of subslices of the structure if subslicing: nslice_ = nslice else: nslice_ = nslice * self.nsubslices # Work out direction of propagation through specimen if nslice_ != self.current_slice: direction = np.sign(nslice_ - self.current_slice) else: direction = 1 if direction == 0: direction = 1 if nslice_ > len(self.seed): # Add new seeds to determine random translations for frozen-phonon # multislice (required for reversability of multislice) if required self.seed = np.concatenate( [ self.seed, np.random.randint(0, 2 ** 31 - 1, size=nslice_ - len(self.seed)), ] ) # Now generate list of slices that the multislice algorithm will 
run through slices = np.arange(self.current_slice, nslice_, direction) if direction < 0: slices += direction # For a transposed scattering matrix the order of the slices # in multislice should be reversed if self.transposed: slices = slices[::-1] # If streaming of Smatrix columns to the GPU is being used, ensure # that propagators and transmission functions for the multislice are # already on the GPU if self.GPU_streaming: propagators = ensure_torch_array(propagators).cuda() transmission_functions = ensure_torch_array(transmission_functions).cuda() self.current_slice = nslice_ if len(slices) < 1: return # Loop over the different plane wave components (or columns) of the # scattering matrix for i in tqdm( range(int(np.ceil(self.nbeams / batch_size))), disable=tdisable, desc="Calculating S-matrix", ): # Initialize array that will be used as input to the multislice routine psi = torch.zeros( batch_size, *self.gridshape, 2, dtype=self.dtype, device=self.device ) beams = np.arange( i * batch_size, min((i + 1) * batch_size, self.nbeams), dtype=np.int ) if self.Fourier_space_output: # Expand S-matrix input to full grid for multislice propagation psi[ : beams.shape[0], self.bw_mapping[:, 0], self.bw_mapping[:, 1], : ] = self.S[beams] else: # Fourier interpolate stored real space S-matrix column onto # multislice grid psi = fourier_interpolate_2d_torch( self.S[beams], self.gridshape, norm="conserve_norm" ) if self.GPU_streaming: psi = ensure_torch_array(psi, dtype=self.dtype).to("cuda") output = multislice( psi[: beams.shape[0]], slices, propagators, transmission_functions, self.tiling, self.device, self.seed, return_numpy=False, qspace_in=self.Fourier_space_output, qspace_out=self.Fourier_space_output, transpose=self.transposed, output_to_bandwidth_limit=False, reverse=direction < 0, ) if self.GPU_streaming: output = output.to(self.device) if self.Fourier_space_output: self.S[beams] = output[ :, self.bw_mapping[:, 0], self.bw_mapping[:, 1], : ] * 
np.sqrt(np.prod(self.stored_gridshape) / np.prod(self.gridshape)) else: output = fourier_interpolate_2d_torch( output, self.stored_gridshape, norm="conserve_norm" ) self.S[beams] = output def PRISM_crop_window(self, win=None, device=None): """Calculate 2D array indices of STEM crop window.""" device = get_device(device) if win is None: win = self.PRISM_factor crop_ = [ torch.arange( -self.stored_gridshape[i] // (2 * win[i]), self.stored_gridshape[i] // (2 * win[i]), device=device, ) for i in range(2) ] return crop_ def __call__(self, probes, nslices, posn=None, Smat=None, scan_transform=None): """ Calculate exit-surface waves function using the scattering matrix. Parameters ---------- probes : (N,Y,X,2) torch.array Input wave functions to calculate exit surface wave functions from must be in Diffraction space nslices : Does nothing, only there to match call signature for STEM routine posn : array_like (N,2) Positions of S : array_like (Nbeams,Y,X,2) Scattering matrix object Returns ------- output : (N,Y,X,2) torch.array Exit surface wave functions """ from copy import deepcopy if Smat is None: Smat = self.S Sshape = [int(x) for x in Smat.shape] device = Smat.device crop_ = self.PRISM_crop_window(device=device) # Ensure posn and probes are pytorch arrays probes = ensure_torch_array(probes, dtype=self.dtype, device=device) # Ensure probes tensors correspond to the shape N x Y x X x 2 # If they have the shape Y x X x 2 then reshape to 1 x Y x X x 2 if probes.ndim < 4: probes = probes.view(1, *probes.shape) # Get number of probes nprobes = probes.shape[0] if posn is None: posn = torch.zeros(nprobes, 2, device=device, dtype=self.dtype) else: posn = torch.as_tensor(posn, device=device, dtype=self.dtype).view( (nprobes, 2) ) if scan_transform is not None: posn = scan_transform(posn) # TODO decide whether to remove the Fourier_space_output option if self.Fourier_space_output: # A note on normalization: an individual probe enters the STEM routine # with sum_squared 
intensity of 1, but the STEM routine applies an # FFT so the sum_squared intensity is now equal to # pixels # For a correct matrix multiplication we must now divide by sqrt(# pixels) probe_vec = complex_matmul( probes[:, self.beams[0], self.beams[1]], Smat ) / np.sqrt(np.prod(self.gridshape)) # Now reshape output from vectors to square arrays probes = torch.zeros( nprobes, *self.stored_gridshape, 2, dtype=self.dtype, device=self.device ) probes[:, self.bw_mapping[:, 0], self.bw_mapping[:, 1], :] = probe_vec # Apply PRISM cropping in real space if appropriate if self.doPRISM: shape = probes.size() probes = torch.ifft(probes, signal_ndim=2).flatten(-3, -2) for k in range(nprobes): # Calculate windows in vertical and horizontal directions window = crop_window_to_flattened_indices_torch( [ (crop_[i] + posn[k, i] * self.stored_gridshape[i]) % self.stored_gridshape[i] for i in range(2) ], self.stored_gridshape, ) probe = deepcopy(probes[k]) probes[k] = 0 probes[k, window, :] = probe[window, :] probes = probes.reshape(shape) # Transform probe back to Fourier space return torch.fft(probes, signal_ndim=2) return probes else: # Flatten the array dimensions Smatshape = Smat.shape flattened_shape = [Smatshape[0], Smatshape[-3] * Smatshape[-2], 2] N = probes.shape[0] output = torch.zeros( N, *Smat.shape[-3:-1], 2, dtype=self.dtype, device=Smat.device ) # For evaluating the probes in real space we only want to perform the matrix # multiplication and summation within the real space PRISM cropping region stride = [x // y for x, y in zip(self.stored_gridshape, self.PRISM_factor)] halfstride = [x // 2 for x in stride] # for k in range(probes.size(0)): for probe, pos, out in zip(probes, posn, output): if self.doPRISM: start = [ int(torch.round(pos[i] * Sshape[-3 + i])) - halfstride[i] for i in range(2) ] windows = crop_window_to_periodic_indices( [start[0], stride[0], start[1], stride[1]], Sshape[-3:-1] ) for wind in windows: outview = out.narrow(-3, wind[0][0], wind[0][1]).narrow( 
-2, wind[1][0], wind[1][1] ) sview = Smat.narrow(-3, wind[0][0], wind[0][1]).narrow( -2, wind[1][0], wind[1][1] ) p = probe[self.beams[0], self.beams[1]].view( self.nbeams, 1, 1, 2 ) outview += torch.sum(complex_mul(p, sview), axis=0) else: output += complex_matmul( probe[self.beams[0], self.beams[1]], Smat.view(flattened_shape) ).view(Smatshape[1:]) output /= np.sqrt(np.prod(probes.size()[-3:-1])) output = crop_torch( output.reshape(probes.size(0), *Smat.size()[-3:]), self.stored_gridshape ) return torch.fft(output, signal_ndim=2) def STEM_with_GPU_streaming( self, detectors=None, FourD_STEM=None, datacube=None, STEM_image=None, nstreams=None, df=0, aberrations=[], ROI=[0.0, 0.0, 1.0, 1.0], device=None, scan_posns=None, showProgress=True, ): """ Perform STEM with scattering matrix streamed between RAM and GPU memory. This allows much larger fields of view to be calculated with relatively modest graphics card memory. The STEM raster is segmented into spatially close clusters and the probe positions in these clusters are processed sequentially, with the relevant part of the scattering matrix streamed from CPU to GPU memory. Parameters ---------- self : scattering_matrix The scattering matrix object. detectors : (Ndet, Y, X) array_like, optional Diffraction plane detectors to perform conventional STEM imaging. If None is passed then no conventional STEM images will be returned. fourD_STEM : bool or array_like, optional Pass fourD_STEM = True to perform 4D-STEM simulations. To save disk space a tuple containing pixel size and diffraction space extent of the datacube can be passed in. For example ([64,64],[1.2,1.2]) will output diffraction patterns measuring 64 x 64 pixels and 1.2 x 1.2 inverse Angstroms. datacube : (Ny, Nx, Y, X) array_like, optional datacube for 4D-STEM output, if None is passed (default) this will be initialized in the function. 
If a datacube is passed then the result will be added by the STEM routine (useful for multiple frozen phonon iterations) STEM_image : (Ndet,Ny,Nx) array_like, optional Array that will contain the conventional STEM images, if not passed will be initialized within the function. If it is passed then the result will be accumulated within the function, which is useful for multiple frozen phonon iterations. nstreams : int, optional Number of streams (seperate transfers from CPU to GPU memory). If None this will just be set to the product of the PRISM interpolation factor df : float, optional Defocus in Angstrom aberrations : list, optional A list containing a set of the class aberration, pass an empty list for an unaberrated contrast transfer function. ROI : (4,) array_like Fraction of the unit cell to be scanned. Should contain [y0,x0,y1,x1] where [x0,y0] and [x1,y1] are the bottom left and top right coordinates of the region of interest (ROI) expressed as a fraction of the total grid (or unit cell). device : torch.device, optional torch.device object which will determine which device (CPU or GPU) the calculations will run on. scan_posn : (...,2) array_like, optional Array containing the STEM scan positions in fractional coordinates. If provided scan_posn.shape[:-1] will give the shape of the STEM image. 
showProgress : str or bool, optional Pass False to disable progress readout, pass 'notebook' to get correct progress bar behaviour inside a jupyter notebook """ device = get_device(device) tdisable, tqdm = tqdm_handler(showProgress) # Get indices of PRISM cropping window crop_ = [x.cpu().numpy() for x in self.PRISM_crop_window()] # Make the STEM probe probe = focused_probe( self.gridshape, self.rsize[:2], self.eV, self.alpha_, df=df, aberrations=aberrations, app_units="invA", ) probe = cx_from_numpy(probe, device=device, dtype=self.dtype) # Make scan positions if none already provided if scan_posns is None: scan_posns = generate_STEM_raster( self.rsize, self.eV, self.alpha_, tiling=self.tiling, ROI=ROI, invA=True ) # Get scan (and STEM image) array shape and total number of scan positions scan_shape = scan_posns.shape[:-1] nscan = np.product(scan_shape) # Flatten scan positions to simplify iteration later on. scan_posns = scan_posns.reshape((nscan, 2)) # Calculate default 4D-STEM diffraction pattern sampling if FourD_STEM is True: GS = self.stored_gridshape FourD_STEM = [GS, GS / self.rsize[:2]] # Allocate diffraction pattern and STEM images if not already provided if FourD_STEM: gridout = workout_4DSTEM_datacube_DP_size( FourD_STEM, self.rsize, self.gridshape )[0] if (datacube is None) and FourD_STEM: datacube = np.zeros((*scan_shape, *gridout)) if not FourD_STEM: datacube = None # If detectors are provided then we are doing conventional STEM doConventionalSTEM = detectors is not None # Initialize STEM images if not provided if doConventionalSTEM: ndet = detectors.shape[0] if STEM_image is None: STEM_image = np.zeros((ndet, nscan)) else: STEM_image = STEM_image.reshape((ndet, nscan)) else: STEM_image = None if nstreams is None: # If the number of seperate streams is not suggested by the # user, make this equal to the product of the PRISM factor nstreams = int(np.product(self.PRISM_factor)) # Divide up the scan positions into clusters based on Euclidean # distance 
from sklearn.cluster import Birch if nscan > 1: model = Birch(threshold=0.01, n_clusters=nstreams) yhat = model.fit_predict(scan_posns) clusters = np.unique(yhat) else: yhat, clusters = [[0], [0]] # Now do STEM with each of the scan position clusters, streaming # only the necessary bits of the scattering matrix to the GPU. Datacube_segment = None STEM_image_segment = None FlatS = self.S.reshape((self.nbeams, np.prod(self.stored_gridshape), 2)) # Loop over probe positions clusters. This would be a good candidate # for multi-GPU work. for cluster in tqdm(clusters, desc="Probe position clusters", disable=tdisable): # Get map of probe positions in cluster points = np.nonzero(yhat == cluster)[0] npoints = len(points) # Get segments of images to update if doConventionalSTEM: STEM_image_segment = STEM_image[:, points] if FourD_STEM: Datacube_segment = np.zeros((1, npoints, 1, *gridout)) pix_posn = scan_posns[points] * np.asarray(self.stored_gridshape) # Work out bounds of the rectangular region of the scattering # matrix to stream to the GPU ymin, ymax = [ int(np.floor(np.amin(pix_posn[:, 0]) + crop_[0][0])), int(np.ceil(np.amax(pix_posn[:, 0]) + crop_[0][-1])), ] xmin, xmax = [ int(np.floor(np.amin(pix_posn[:, 1]) + crop_[1][0])), int(np.ceil(np.amax(pix_posn[:, 1]) + crop_[1][-1])), ] size = np.asarray([(ymax - ymin), (xmax - xmin)]) # Get indices of region of scattering matrix to stream to GPU window = [np.arange(a, b) for a, b in zip([ymin, xmin], [ymax, xmax])] indices = crop_window_to_flattened_indices_torch( window, self.stored_gridshape ) # Get segment of the scattering matrix to stream to GPU segmentshape = [len(x) for x in window] SegmentS = FlatS[:, indices, :].reshape((self.nbeams, *segmentshape, 2)) # Define a function that will map probe positions for the global # scattering matrix to their correct place on the smaller scattering # matrix streamed to the GPU. 
gshape = torch.as_tensor(self.stored_gridshape).to(device).type(self.dtype) Origin = torch.as_tensor([ymin, xmin]).to(device).type(self.dtype) segment_size = torch.as_tensor(size).to(device).type(self.dtype) def scan_transform(posn): return (posn * gshape - Origin) / segment_size # Keyword arguments to be passed to the __call__ function by the # STEM routine kwargs = {"Smat": SegmentS.to(device), "scan_transform": scan_transform} # Calculate STEM images STEM( self.rsize, probe, self.__call__, [1], self.eV, self.alpha_, detectors=detectors, FourD_STEM=FourD_STEM, datacube=Datacube_segment, scan_posn=scan_posns[points].reshape((npoints, 1, 2)), STEM_image=STEM_image_segment, method_kwargs=kwargs, showProgress=False, device=device, ) if doConventionalSTEM: STEM_image[:, points] += STEM_image_segment if FourD_STEM: for point, Dp in zip(points, Datacube_segment[0]): y, x = np.unravel_index(point, scan_shape) datacube[y, x] += Dp[0] # Unflatten 4D-STEM datacube scan dimensions, use numpy squeeze to # remove superfluous dimensions (ones with length 1) # if FourD_STEM: # datacube = datacube.reshape(*scan_shape, *datacube.shape[-2:]) if doConventionalSTEM: STEM_image = np.squeeze(STEM_image.reshape(ndet, *scan_shape)) # Return STEM images and datacube as a dictionary. If either of these # objects were not calculated the dictionary will contain None for those # entries. return {"STEM images": STEM_image, "datacube": datacube}
Methods
def PRISM_crop_window(self, win=None, device=None)
-
Calculate 2D array indices of STEM crop window.
Expand source code
def PRISM_crop_window(self, win=None, device=None):
    """
    Calculate 2D array indices of STEM crop window.

    Parameters
    ----------
    win : (2,) array_like, optional
        Cropping factors for the vertical and horizontal directions. When
        None (default) ``self.PRISM_factor`` is used.
    device : torch.device, optional
        Device on which the index tensors are created (resolved through
        ``get_device``).

    Returns
    -------
    list of two torch.Tensor
        For each axis, the offsets ``torch.arange(-N // (2 * w), N // (2 * w))``
        where N is the corresponding entry of ``self.stored_gridshape`` and
        w the corresponding cropping factor.
    """
    device = get_device(device)
    factors = self.PRISM_factor if win is None else win

    def _axis_offsets(axis):
        divisor = 2 * factors[axis]
        npix = self.stored_gridshape[axis]
        # Unary minus binds before floor division, so the lower bound is
        # floor(-N / (2 * w)) rather than -floor(N / (2 * w)).
        return torch.arange(-npix // divisor, npix // divisor, device=device)

    return [_axis_offsets(axis) for axis in (0, 1)]
def Propagate(self, nslice, propagators, transmission_functions, subslicing=False, batch_size=3, showProgress=True, transpose=False)
-
Propagate a scattering matrix to slice nslice of the specimen.
Parameters
nslice
:int
- The slice in the specimen to propagate the scattering matrix to
propagators
:(N,Y,X,2) torch.array
- Fresnel free space operators required for the multislice algorithm used to propagate the scattering matrix
transmission_functions
:(N,Y,X,2)
- The transmission functions describing the electron's interaction with the specimen for the multislice algorithm
batch_size
:int
, optional- The multislice algorithm can be performed on multiple scattering matrix columns at once to parallelize computation; this number is set by batch_size.
subslicing
:bool
, optional- Pass subslicing=True to access propagation to sub-slices of the unit cell; in this case nslice is taken to be in units of subslices rather than unit cells (i.e. nslice = 3 will propagate 1.5 unit cells for a subslicing of 2 subslices per unit cell)
showProgress
:str
or bool
, optional- Pass False to disable progress readout, pass 'notebook' to get correct progress bar behaviour inside a jupyter notebook
transpose
:bool
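, optional- Make a "transposed" scattering matrix - see Brown et al. (2019) Physical Review Research paper for a discussion of this and its applications. Note that the method does not currently read this argument; the slice ordering and multislice call use the scattering matrix's transposed attribute instead.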
Expand source code
def Propagate(
    self,
    nslice,
    propagators,
    transmission_functions,
    subslicing=False,
    batch_size=3,
    showProgress=True,
    transpose=False,
):
    """
    Propagate a scattering matrix to slice nslice of the specimen.

    Parameters
    ----------
    nslice : int
        The slice in the specimen to propagate the scattering matrix to
    propagators : (N,Y,X,2) torch.array
        Fresnel free space operators required for the multislice algorithm
        used to propagate the scattering matrix
    transmission_functions : (N,Y,X,2)
        The transmission functions describing the electron's interaction
        with the specimen for the multislice algorithm
    subslicing : bool, optional
        Pass subslicing=True to access propagation to sub-slices of the
        unit cell; in this case nslice is taken to be in units of subslices
        rather than unit cells (i.e. nslice = 3 will propagate 1.5 unit cells
        for a subslicing of 2 subslices per unit cell)
    batch_size : int, optional
        The multislice algorithm can be performed on multiple scattering
        matrix columns at once to parallelize computation; this number is set
        by batch_size.
    showProgress : str or bool, optional
        Pass False to disable progress readout, pass 'notebook' to get correct
        progress bar behaviour inside a jupyter notebook
    transpose : bool, optional
        Intended to make a "transposed" scattering matrix - see Brown et al.
        (2019) Physical Review Research paper for a discussion of this and its
        applications. Note that this argument is currently not read: the slice
        ordering and the multislice call use the ``transposed`` attribute of
        the scattering matrix instead.

    Returns
    -------
    None
        ``self.S`` is updated in place and ``self.current_slice`` is set to the
        target slice.
    """
    tdisable, tqdm = tqdm_handler(showProgress)
    from .Probe import plane_wave_illumination

    # Initialize scattering matrix if necessary
    if not self.initialized:
        if self.Fourier_space_output:
            self.S = torch.zeros(
                self.nbeams, self.nbout, 2, dtype=self.dtype, device=self.device
            )
        else:
            self.S = torch.zeros(
                self.nbeams,
                *self.stored_gridshape,
                2,
                dtype=self.dtype,
                device=self.device,
            )

        for ibeam in range(self.nbeams):
            # Initialize S-matrix to plane-waves
            psi = cx_from_numpy(
                plane_wave_illumination(
                    self.gridshape,
                    self.rsize[:2],
                    self.eV,
                    tilt=[self.beams[0][ibeam], self.beams[1][ibeam]],
                    tilt_units="pixels",
                    qspace=True,
                )
            )

            # Adjust intensity for correct normalization of S matrix rows
            # taking into account the PRISM factor that needs to be applied
            # when the Smatrix is evaluated (only 1/product(PRISM_factor)
            # beams are taken and only 1/product(PRISM_factor) intensity
            # is cropped out in real space)
            psi *= torch.prod(torch.tensor(self.PRISM_factor, dtype=self.dtype))

            if self.Fourier_space_output:
                self.S[ibeam] = psi[self.bw_mapping[:, 0], self.bw_mapping[:, 1], :]
            else:
                self.S[ibeam] = fourier_interpolate_2d_torch(
                    psi,
                    self.stored_gridshape,
                    qspace_in=True,
                    qspace_out=False,
                    norm="conserve_norm",
                )

        self.initialized = True

    # Make nslice_ which always accounts for subslices of the structure
    nslice_ = nslice if subslicing else nslice * self.nsubslices

    # Propagate backwards only if the target lies before the current slice,
    # otherwise (including target == current) the direction is forwards.
    direction = -1 if nslice_ < self.current_slice else 1

    if nslice_ > len(self.seed):
        # Add new seeds to determine random translations for frozen-phonon
        # multislice (required for reversability of multislice) if required
        self.seed = np.concatenate(
            [
                self.seed,
                np.random.randint(0, 2 ** 31 - 1, size=nslice_ - len(self.seed)),
            ]
        )

    # Now generate list of slices that the multislice algorithm will run
    # through. When going backwards the slice indices are shifted by one so
    # that the slices already traversed are undone.
    slices = np.arange(self.current_slice, nslice_, direction)
    if direction < 0:
        slices += direction

    # For a transposed scattering matrix the order of the slices
    # in multislice should be reversed
    if self.transposed:
        slices = slices[::-1]

    # If streaming of Smatrix columns to the GPU is being used, ensure
    # that propagators and transmission functions for the multislice are
    # already on the GPU
    if self.GPU_streaming:
        propagators = ensure_torch_array(propagators).cuda()
        transmission_functions = ensure_torch_array(transmission_functions).cuda()

    self.current_slice = nslice_
    if len(slices) < 1:
        return

    # Rescaling (ratio of pixel counts) applied when a Fourier-space result on
    # the multislice grid is stored back on the stored grid. Loop invariant.
    qspace_rescale = np.sqrt(
        np.prod(self.stored_gridshape) / np.prod(self.gridshape)
    )

    # Loop over the different plane wave components (or columns) of the
    # scattering matrix in batches of (at most) batch_size
    nbatches = int(np.ceil(self.nbeams / batch_size))
    for i in tqdm(range(nbatches), disable=tdisable, desc="Calculating S-matrix"):
        # builtin int replaces np.int, which was removed in NumPy 1.24
        beams = np.arange(
            i * batch_size, min((i + 1) * batch_size, self.nbeams), dtype=int
        )
        nbatch = beams.shape[0]

        if self.Fourier_space_output:
            # Expand S-matrix input to full grid for multislice propagation
            psi = torch.zeros(
                nbatch, *self.gridshape, 2, dtype=self.dtype, device=self.device
            )
            psi[:, self.bw_mapping[:, 0], self.bw_mapping[:, 1], :] = self.S[beams]
        else:
            # Fourier interpolate stored real space S-matrix column onto
            # multislice grid
            psi = fourier_interpolate_2d_torch(
                self.S[beams], self.gridshape, norm="conserve_norm"
            )

        if self.GPU_streaming:
            psi = ensure_torch_array(psi, dtype=self.dtype).to("cuda")

        output = multislice(
            psi[:nbatch],
            slices,
            propagators,
            transmission_functions,
            self.tiling,
            self.device,
            self.seed,
            return_numpy=False,
            qspace_in=self.Fourier_space_output,
            qspace_out=self.Fourier_space_output,
            transpose=self.transposed,
            output_to_bandwidth_limit=False,
            reverse=direction < 0,
        )

        if self.GPU_streaming:
            output = output.to(self.device)

        if self.Fourier_space_output:
            self.S[beams] = (
                output[:, self.bw_mapping[:, 0], self.bw_mapping[:, 1], :]
                * qspace_rescale
            )
        else:
            output = fourier_interpolate_2d_torch(
                output, self.stored_gridshape, norm="conserve_norm"
            )
            self.S[beams] = output
def STEM_with_GPU_streaming(self, detectors=None, FourD_STEM=None, datacube=None, STEM_image=None, nstreams=None, df=0, aberrations=[], ROI=[0.0, 0.0, 1.0, 1.0], device=None, scan_posns=None, showProgress=True)
-
Perform STEM with scattering matrix streamed between RAM and GPU memory.
This allows much larger fields of view to be calculated with relatively modest graphics card memory. The STEM raster is segmented into spatially close clusters and the probe positions in these clusters are processed sequentially, with the relevant part of the scattering matrix streamed from CPU to GPU memory.
Parameters
self
:scattering_matrix
- The scattering matrix object.
detectors
:(Ndet, Y, X) array_like
, optional- Diffraction plane detectors to perform conventional STEM imaging. If None is passed then no conventional STEM images will be returned.
FourD_STEM
:bool
or array_like
, optional- Pass FourD_STEM = True to perform 4D-STEM simulations. To save disk space a tuple containing the size in pixels and the diffraction space extent of the datacube can be passed in. For example ([64,64],[1.2,1.2]) will output diffraction patterns measuring 64 x 64 pixels and 1.2 x 1.2 inverse Angstroms.
datacube
:(Ny, Nx, Y, X) array_like
, optional- datacube for 4D-STEM output, if None is passed (default) this will be initialized in the function. If a datacube is passed then the result will be added by the STEM routine (useful for multiple frozen phonon iterations)
STEM_image
:(Ndet,Ny,Nx) array_like
, optional- Array that will contain the conventional STEM images, if not passed will be initialized within the function. If it is passed then the result will be accumulated within the function, which is useful for multiple frozen phonon iterations.
nstreams
:int
, optional- Number of streams (separate transfers from CPU to GPU memory). If None this will just be set to the product of the PRISM interpolation factors
df
:float
, optional- Defocus in Angstrom
aberrations
:list
, optional- A list containing a set of the class aberration, pass an empty list for an unaberrated contrast transfer function.
ROI
:(4,) array_like
- Fraction of the unit cell to be scanned. Should contain [y0,x0,y1,x1] where [x0,y0] and [x1,y1] are the bottom left and top right coordinates of the region of interest (ROI) expressed as a fraction of the total grid (or unit cell).
device
:torch.device
, optional- torch.device object which will determine which device (CPU or GPU) the calculations will run on.
scan_posns
:(…,2) array_like
, optional- Array containing the STEM scan positions in fractional coordinates. If provided scan_posns.shape[:-1] will give the shape of the STEM image.
showProgress
:str
or bool
, optional- Pass False to disable progress readout, pass 'notebook' to get correct progress bar behaviour inside a jupyter notebook
Expand source code
def STEM_with_GPU_streaming(
    self,
    detectors=None,
    FourD_STEM=None,
    datacube=None,
    STEM_image=None,
    nstreams=None,
    df=0,
    aberrations=[],
    ROI=[0.0, 0.0, 1.0, 1.0],
    device=None,
    scan_posns=None,
    showProgress=True,
):
    """
    Perform STEM with scattering matrix streamed between RAM and GPU memory.

    This allows much larger fields of view to be calculated with relatively
    modest graphics card memory. The STEM raster is segmented into spatially
    close clusters and the probe positions in these clusters are processed
    sequentially, with the relevant part of the scattering matrix streamed
    from CPU to GPU memory.

    Parameters
    ----------
    self : scattering_matrix
        The scattering matrix object.
    detectors : (Ndet, Y, X) array_like, optional
        Diffraction plane detectors to perform conventional STEM imaging. If
        None is passed then no conventional STEM images will be returned.
    FourD_STEM : bool or array_like, optional
        Pass FourD_STEM = True to perform 4D-STEM simulations. To save disk
        space a tuple containing the size in pixels and the diffraction space
        extent of the datacube can be passed in. For example
        ([64,64],[1.2,1.2]) will output diffraction patterns measuring
        64 x 64 pixels and 1.2 x 1.2 inverse Angstroms.
    datacube : (Ny, Nx, Y, X) array_like, optional
        datacube for 4D-STEM output, if None is passed (default) this will be
        initialized in the function. If a datacube is passed then the result
        will be added by the STEM routine (useful for multiple frozen phonon
        iterations)
    STEM_image : (Ndet,Ny,Nx) array_like, optional
        Array that will contain the conventional STEM images, if not passed
        will be initialized within the function. If it is passed then the
        result will be accumulated within the function, which is useful for
        multiple frozen phonon iterations.
    nstreams : int, optional
        Number of streams (separate transfers from CPU to GPU memory). If None
        this will just be set to the product of the PRISM interpolation
        factors
    df : float, optional
        Defocus in Angstrom
    aberrations : list, optional
        A list containing a set of the class aberration, pass an empty list for
        an unaberrated contrast transfer function.
    ROI : (4,) array_like
        Fraction of the unit cell to be scanned. Should contain [y0,x0,y1,x1]
        where [x0,y0] and [x1,y1] are the bottom left and top right coordinates
        of the region of interest (ROI) expressed as a fraction of the total
        grid (or unit cell).
    device : torch.device, optional
        torch.device object which will determine which device (CPU or GPU) the
        calculations will run on.
    scan_posns : (...,2) array_like, optional
        Array containing the STEM scan positions in fractional coordinates. If
        provided scan_posns.shape[:-1] will give the shape of the STEM image.
    showProgress : str or bool, optional
        Pass False to disable progress readout, pass 'notebook' to get correct
        progress bar behaviour inside a jupyter notebook

    Returns
    -------
    dict
        ``{"STEM images": STEM_image, "datacube": datacube}``; an entry is
        None if the corresponding output was not calculated.
    """
    device = get_device(device)
    tdisable, tqdm = tqdm_handler(showProgress)

    # Get indices of PRISM cropping window
    crop_ = [x.cpu().numpy() for x in self.PRISM_crop_window()]

    # Make the STEM probe
    probe = focused_probe(
        self.gridshape,
        self.rsize[:2],
        self.eV,
        self.alpha_,
        df=df,
        aberrations=aberrations,
        app_units="invA",
    )
    probe = cx_from_numpy(probe, device=device, dtype=self.dtype)

    # Make scan positions if none already provided
    if scan_posns is None:
        scan_posns = generate_STEM_raster(
            self.rsize, self.eV, self.alpha_, tiling=self.tiling, ROI=ROI, invA=True
        )

    # Get scan (and STEM image) array shape and total number of scan
    # positions. np.prod(()) is the float 1.0, so cast to int for use as an
    # array dimension (np.product was removed in NumPy 2.0).
    scan_shape = scan_posns.shape[:-1]
    nscan = int(np.prod(scan_shape))

    # Flatten scan positions to simplify iteration later on.
    scan_posns = scan_posns.reshape((nscan, 2))

    # Calculate default 4D-STEM diffraction pattern sampling
    if FourD_STEM is True:
        GS = self.stored_gridshape
        FourD_STEM = [GS, GS / self.rsize[:2]]

    # Allocate diffraction pattern datacube if not already provided
    if FourD_STEM:
        gridout = workout_4DSTEM_datacube_DP_size(
            FourD_STEM, self.rsize, self.gridshape
        )[0]
        if datacube is None:
            datacube = np.zeros((*scan_shape, *gridout))
    else:
        datacube = None

    # If detectors are provided then we are doing conventional STEM
    doConventionalSTEM = detectors is not None

    # Initialize STEM images if not provided
    if doConventionalSTEM:
        ndet = detectors.shape[0]
        if STEM_image is None:
            STEM_image = np.zeros((ndet, nscan))
        else:
            STEM_image = STEM_image.reshape((ndet, nscan))
    else:
        STEM_image = None

    if nstreams is None:
        # If the number of separate streams is not suggested by the
        # user, make this equal to the product of the PRISM factor
        nstreams = int(np.prod(self.PRISM_factor))

    # Divide up the scan positions into clusters based on Euclidean distance.
    # yhat must be a numpy array so that ``yhat == cluster`` below is an
    # elementwise comparison (a Python list compared with an int is False).
    if nscan > 1:
        from sklearn.cluster import Birch

        yhat = Birch(threshold=0.01, n_clusters=nstreams).fit_predict(scan_posns)
    else:
        yhat = np.zeros(nscan, dtype=int)
    clusters = np.unique(yhat)

    # Now do STEM with each of the scan position clusters, streaming
    # only the necessary bits of the scattering matrix to the GPU.
    Datacube_segment = None
    STEM_image_segment = None
    FlatS = self.S.reshape((self.nbeams, int(np.prod(self.stored_gridshape)), 2))

    # Loop-invariant stored grid shape as a tensor, used by scan_transform
    gshape = torch.as_tensor(self.stored_gridshape).to(device).type(self.dtype)

    # Loop over probe positions clusters. This would be a good candidate
    # for multi-GPU work.
    for cluster in tqdm(clusters, desc="Probe position clusters", disable=tdisable):
        # Get map of probe positions in cluster
        points = np.nonzero(yhat == cluster)[0]
        npoints = len(points)

        # Fresh, zeroed segments are handed to the STEM routine and then
        # added to the global outputs. Starting the image segment from the
        # existing values would count them twice if STEM accumulates.
        if doConventionalSTEM:
            STEM_image_segment = np.zeros_like(STEM_image[:, points])
        if FourD_STEM:
            Datacube_segment = np.zeros((1, npoints, 1, *gridout))

        pix_posn = scan_posns[points] * np.asarray(self.stored_gridshape)

        # Work out bounds of the rectangular region of the scattering
        # matrix to stream to the GPU
        ymin = int(np.floor(np.amin(pix_posn[:, 0]) + crop_[0][0]))
        ymax = int(np.ceil(np.amax(pix_posn[:, 0]) + crop_[0][-1]))
        xmin = int(np.floor(np.amin(pix_posn[:, 1]) + crop_[1][0]))
        xmax = int(np.ceil(np.amax(pix_posn[:, 1]) + crop_[1][-1]))
        size = np.asarray([ymax - ymin, xmax - xmin])

        # Get indices of region of scattering matrix to stream to GPU
        window = [np.arange(ymin, ymax), np.arange(xmin, xmax)]
        indices = crop_window_to_flattened_indices_torch(window, self.stored_gridshape)

        # Get segment of the scattering matrix to stream to GPU
        segmentshape = [len(x) for x in window]
        SegmentS = FlatS[:, indices, :].reshape((self.nbeams, *segmentshape, 2))

        # Define a function that will map probe positions for the global
        # scattering matrix to their correct place on the smaller scattering
        # matrix streamed to the GPU. The per-cluster values are bound as
        # defaults so the closure does not depend on late binding.
        Origin = torch.as_tensor([ymin, xmin]).to(device).type(self.dtype)
        segment_size = torch.as_tensor(size).to(device).type(self.dtype)

        def scan_transform(posn, origin=Origin, seg_size=segment_size):
            return (posn * gshape - origin) / seg_size

        # Keyword arguments to be passed to the __call__ function by the
        # STEM routine
        kwargs = {"Smat": SegmentS.to(device), "scan_transform": scan_transform}

        # Calculate STEM images
        STEM(
            self.rsize,
            probe,
            self.__call__,
            [1],
            self.eV,
            self.alpha_,
            detectors=detectors,
            FourD_STEM=FourD_STEM,
            datacube=Datacube_segment,
            scan_posn=scan_posns[points].reshape((npoints, 1, 2)),
            STEM_image=STEM_image_segment,
            method_kwargs=kwargs,
            showProgress=False,
            device=device,
        )

        if doConventionalSTEM:
            STEM_image[:, points] += STEM_image_segment

        if FourD_STEM:
            for point, Dp in zip(points, Datacube_segment[0]):
                # Tuple indexing works for scan grids of any dimensionality
                datacube[np.unravel_index(point, scan_shape)] += Dp[0]

    # Unflatten STEM image scan dimensions, use numpy squeeze to remove
    # superfluous dimensions (ones with length 1)
    if doConventionalSTEM:
        STEM_image = np.squeeze(STEM_image.reshape(ndet, *scan_shape))

    # Return STEM images and datacube as a dictionary. If either of these
    # objects were not calculated the dictionary will contain None for those
    # entries.
    return {"STEM images": STEM_image, "datacube": datacube}