Module pyms.utils.torch_utils
A set of utility functions for working with pytorch tensors.
Expand source code
"""A set of utility functions for working with pytorch tensors."""
import torch
import numpy as np
from itertools import product
re = np.s_[..., 0]
im = np.s_[..., 1]
def iscomplex(a: torch.Tensor):
"""Return True if a is complex, False otherwise."""
return a.shape[-1] == 2
def check_complex(A):
"""Raise a RuntimeWarning if tensor A is not complex."""
for a in A:
if not iscomplex(a):
raise RuntimeWarning(
"taking complex_mul of non-complex tensor! a.shape " + str(a.shape)
)
def to_complex(real, imag=None):
"""Convert real and imaginary tensors to a complex tensor."""
if imag is None:
return torch.stack(
[real, torch.zeros(real.size(), dtype=real.dtype, device=real.device)], -1
)
else:
return torch.stack([real, imag], -1)
def get_device(device_type=None):
"""Initialize device cuda if available, CPU if no cuda is available."""
if device_type is None and torch.cuda.is_available():
device = torch.device("cuda")
elif device_type is None:
device = torch.device("cpu")
else:
device = torch.device(device_type)
return device
def complex_matmul(a: torch.Tensor, b: torch.Tensor, conjugate=False) -> torch.Tensor:
"""
Complex matrix multiplication of tensors a and b.
Pass conjugate = True to conjugate tensor b in the multiplication.
"""
check_complex([a, b])
are = a[re]
aim = a[im]
bre = b[re]
bim = b[im]
if conjugate:
real = are @ bre + aim @ bim
imag = -are @ bim + aim @ bre
else:
real = are @ bre - aim @ bim
imag = are @ bim + aim @ bre
return torch.stack([real, imag], -1)
def complex_mul(a: torch.Tensor, b: torch.Tensor, conjugate=False) -> torch.Tensor:
"""
Complex array multiplication of tensors a and b.
Pass conjugate = True to conjugate tensor b in the multiplication.
"""
check_complex([a, b])
are = a[re]
aim = a[im]
bre = b[re]
bim = b[im]
if conjugate:
real = are * bre + aim * bim
imag = -are * bim + aim * bre
else:
real = are * bre - aim * bim
imag = are * bim + aim * bre
return torch.stack([real, imag], -1)
def torch_c_exp(angle):
"""Calculate exp(1j*angle)."""
if angle.size()[-1] != 2:
# Case of a real exponent
result = torch.zeros(*angle.shape, 2, dtype=angle.dtype, device=angle.device)
result[re] = torch.cos(angle)
result[im] = torch.sin(angle)
else:
# Case of a complex valued exponent
exp = torch.exp(-angle[im])
result = torch.zeros(*angle.shape, dtype=angle.dtype, device=angle.device)
result[re] = exp * torch.cos(angle[re])
result[im] = exp * torch.sin(angle[re])
return result
def sinc(x):
"""Calculate the sinc function ie. sin(pi x)/(pi x)."""
y = torch.where(torch.abs(x) < 1.0e-20, torch.tensor([1.0e-20], dtype=x.dtype), x)
return torch.sin(np.pi * y) / np.pi / y
def ensure_torch_array(array, dtype=torch.float, device=None):
"""
Ensure that the input array is a pytorch tensor.
Converts to a pytorch array if input is a numpy array and do nothing if the
input is a pytorch tensor
"""
from .. import (
layered_structure_propagators,
layered_structure_transmission_function,
)
if device is None:
device = get_device(device)
if isinstance(array, torch.Tensor):
return array.to(device)
elif isinstance(array, layered_structure_transmission_function):
for i in range(len(array.Ts)):
array.Ts[i] = array.Ts[i].to(device)
return array
elif isinstance(array, layered_structure_propagators):
for i in range(len(array.Ps)):
array.Ps[i] = array.Ps[i].to(device)
return array
else:
if np.iscomplexobj(np.asarray(array)):
return cx_from_numpy(np.asarray(array), dtype=dtype, device=device)
else:
return torch.from_numpy(np.asarray(array)).type(dtype).to(device)
def amplitude(r):
"""
Calculate the amplitude of a complex tensor.
If the tensor is not complex then calculate square.
"""
if r.size(-1) == 2:
return r[..., 0] * r[..., 0] + r[..., 1] * r[..., 1]
else:
return r * r
# def roll_n(X, axis, n):
# """Roll a pytorch tensor X n entries along a given axis."""
# f_idx = tuple(
# slice(None, None, None) if i != axis % X.dim() else slice(0, n, None)
# for i in range(X.dim())
# )
# b_idx = tuple(
# slice(None, None, None) if i != axis % X.dim() else slice(n, None, None)
# for i in range(X.dim())
# )
# front = X[f_idx]
# back = X[b_idx]
# return torch.cat([back, front], axis)
def cx_from_numpy(
x: np.array, dtype=torch.float32, device=get_device()
) -> torch.Tensor:
"""
Turn a complex numpy array into the required pytorch array format.
Parameters
----------
x : complex np.ndarray
A complex numpy array
Keyword arguments
-----------------
dtype : torch.dtype
The datatype of the output array
device : torch.device
The device (CPU or GPU) of the output array
"""
if "complex" in str(x.dtype):
out = torch.zeros(*x.shape, 2)
out[re] = torch.from_numpy(x.real)
out[im] = torch.from_numpy(x.imag)
else:
if x.shape[-1] != 2:
out = torch.zeros(x.shape + (2,))
out[re] = torch.from_numpy(x.real)
else:
out = torch.zeros(x.shape + (2,))
out[re] = torch.from_numpy(x[re])
out[im] = torch.from_numpy(x[im])
return out.to(device).type(dtype)
def cx_to_numpy(x: torch.Tensor) -> np.ndarray:
"""Convert a complex pytorch tensor to a complex numpy array."""
check_complex(x)
return x[re].cpu().numpy() + 1j * x[im].cpu().numpy()
def fftfreq(n, dtype=torch.float, device=torch.device("cpu")):
"""
Generate an array of Fourier coordinates in units of pixels.
Same as numpy.fft.fftfreq(n)*n but for a torch array.
"""
return (torch.arange(n, dtype=dtype, device=device) + n // 2) % n - n // 2
def torch_dtype_to_numpy(dtype):
"""Convert a torch datatype to a numpy datatype."""
scratch_array = torch.zeros(1, dtype=dtype)
return scratch_array.cpu().numpy().dtype
def fourier_shift_array_1d(
y, posn, dtype=torch.float, device=torch.device("cpu"), units="pixels"
):
"""Apply Fourier shift theorem for sub-pixel shift to a 1 dimensional array."""
ramp = torch.empty(y, 2, dtype=dtype, device=device)
ky = 2 * np.pi * fftfreq(y) * posn
if units == "pixels":
ky /= y
ramp[..., 0] = torch.cos(ky)
ramp[..., 1] = -torch.sin(ky)
return ramp
def fourier_shift_torch(
array,
posn,
dtype=torch.float32,
device=torch.device("cpu"),
qspace_in=False,
qspace_out=False,
units="pixels",
):
"""
Apply Fourier shift theorem for sub-pixel shifts to array.
Parameters
-----------
array : torch.tensor (...,Y,X,2)
Complex array to be Fourier shifted
posn : torch.tensor (K x 2) or (2,)
Shift(s) to be applied
"""
if not qspace_in:
array = torch.fft(array, signal_ndim=2)
array = complex_mul(
array,
fourier_shift_array(
array.size()[-3:-1],
posn,
dtype=array.dtype,
device=array.device,
units=units,
),
)
if qspace_out:
return array
return torch.ifft(array, signal_ndim=2)
def fourier_shift_array(
size, posn, dtype=torch.float, device=torch.device("cpu"), units="pixels"
):
"""
Create Fourier shift theorem array to (pixel) position given by list posn.
Parameters
----------
size : array_like
size of the array (Y,X)
posn : array_like
can be a K x 2 array to give a K x Y x X shift arrays
posn
"""
# Get number of dimensions
nn = len(posn.shape)
# Get size of array
y, x = size
if nn == 1:
# Make y ramp exp(-2pi i ky y)
yramp = fourier_shift_array_1d(
y, posn[0], units=units, dtype=dtype, device=device
)
# Make y ramp exp(-2pi i kx x)
xramp = fourier_shift_array_1d(
x, posn[1], units=units, dtype=dtype, device=device
)
# Multiply both arrays together, view statements for
# appropriate broadcasting to 2D
return complex_mul(yramp.view(y, 1, 2), xramp.view(1, x, 2))
else:
K = posn.shape[0]
# Make y ramp exp(-2pi i ky y)
yramp = torch.empty(K, y, 2, dtype=dtype, device=device)
ky = (
2
* np.pi
* fftfreq(y, dtype=dtype, device=device).view(1, y)
* posn[:, 0].view(K, 1)
)
if units == "pixels":
ky /= y
yramp[..., 0] = torch.cos(ky)
yramp[..., 1] = -torch.sin(ky)
# Make y ramp exp(-2pi i kx x)
xramp = torch.empty(K, x, 2, dtype=dtype, device=device)
kx = (
2
* np.pi
* fftfreq(x, dtype=dtype, device=device).view(1, x)
* posn[:, 1].view(K, 1)
)
if units == "pixels":
kx /= x
xramp[..., 0] = torch.cos(kx)
xramp[..., 1] = -torch.sin(kx)
# Multiply both arrays together, view statements for
# appropriate broadcasting to 2D
return complex_mul(yramp.view(K, y, 1, 2), xramp.view(K, 1, x, 2))
def crop_window_to_periodic_indices(win, shape):
"""
Create indices for a rectangular subset of a larger array.
If indices exceed the size of the larger array then these indices will wrap
around to the other side of the grid providing two or more rectangular
subsets of the larger array. Designed to be used in conjunction with
the torch.narrow function to choose subsets of the square array to evaluate
the PRISM algorithm on.
Assumes that the requested window is smaller than the array size
Parameters
----------
win : (4,) array_like
contains (y0,y,x0,x) the lower y index and y length and lower x index
and x length
shape : (2,) array_like
Shape of the larger array
Examples
--------
>>>> crop_window_to_periodic_indices([2,2,1,3],[5,5])
(([2,2],[1,3]),)
>>>> crop_window_to_periodic_indices([-1,3,1,3],[5,5])
(([4,1],[1,3]),([0,2],[1,3]))
>>>> crop_window_to_periodic_indices([4,4,1,3],[5,5])
(([4,1],[1,3]),([0,3],[1,3]))
>>>> list(crop_window_to_periodic_indices([4,4,3,3],[5,5]))
(([4,1],[3,2]),([0,3],[3,2]),([4,1],[0,1]),([0,3],[0,1]))
"""
def oneDindices(start, step, bound):
if start + step > bound - 1:
return [start, bound - start], [0, start + step - bound]
elif start < 0:
return [start % bound, bound - start % bound], [0, (start + step) % bound]
else:
return [[start, step]]
y = oneDindices(*win[:2], shape[0])
x = oneDindices(*win[2:], shape[1])
return tuple(product(y, x))
def crop_window_to_flattened_indices_torch(indices: torch.Tensor, shape: list):
"""
Create (flattened) indices for a rectangular subset of a larger array.
Useful, for example for scattering matrix calculations where only a rectangular
subset of the array is used in the PRISM interpolation routine
Array indices exceeding the bounds of the array are wrapped to be consistent
with periodic boundary conditions.
Parameters
----------
indices : torch.Tensor
The centers of each of the cropping windows
shape : array_like
Size of the cropping windows
Examples
--------
>>> indices = torch.as_tensor([[2,3,4],[1,2,3]])
>>> gridshape = [4,4]
>>> win = [3,3]
>>> grid = torch.zeros(gridshape,dtype=torch.Long)
tensor([[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]])
>>> grid = grid.flatten()
>>> ind = pyms.utils.crop_window_to_flattened_indices_torch(indices,gridshape)
>>> grid[ind] = 1
>>> grid.view(gridshape)
tensor([[0, 1, 1, 1],
[0, 0, 0, 0],
[0, 1, 1, 1],
[0, 1, 1, 1]])
"""
xind = torch.as_tensor(indices[-1]).view(1, len(indices[-1])) % shape[-1]
yind = torch.as_tensor(indices[-2]).view(len(indices[-2]), 1) % shape[-2]
return (xind + yind * shape[-1]).flatten().type(torch.LongTensor)
def crop_to_bandwidth_limit_torch(
array: torch.Tensor,
limit=2 / 3,
qspace_in=True,
qspace_out=True,
norm="conserve_L2",
):
"""Crop an array to its bandwidth limit (remove superfluous array entries)."""
# Check if array is complex or not
complx = iscomplex(array)
# Get array shape, taking into account final dimension of size 2 if the array
# is complex
gridshape = array.shape[-2 - int(complx) :][:2]
# New shape of final dimensions
newshape = tuple([int(round(gridshape[i] * limit)) for i in range(2)])
return fourier_interpolate_2d_torch(
array, newshape, norm=norm, qspace_in=qspace_in, qspace_out=qspace_out
)
def size_of_bandwidth_limited_array(shape):
"""Get the size of an array after band-width limiting."""
return list(crop_to_bandwidth_limit_torch(torch.zeros(*shape)).size())
def detect(detector, diffraction_pattern):
"""
Apply a detector to a diffraction pattern.
Calculates the signal in a diffraction pattern detector even if the size
of the diffraction pattern and the detector are mismatched, assumes that
the zeroth coordinate in reciprocal space is in the top-left hand corner
of the array.
"""
minsize = min(detector.size()[-2:], diffraction_pattern.size()[-2:])
wind = [fftfreq(minsize[i], torch.long, detector.device) for i in [0, 1]]
Dwind = crop_window_to_flattened_indices_torch(wind, detector.size())
DPwind = crop_window_to_flattened_indices_torch(wind, diffraction_pattern.size())
return torch.sum(
detector.flatten(-2, -1)[:, None, Dwind]
* diffraction_pattern.flatten(-2, -1)[None, :, DPwind],
dim=-1,
)
def fourier_interpolate_2d_torch(
ain, shapeout, norm="conserve_val", qspace_in=False, qspace_out=False
):
"""
Fourier interpolation of array ain to shape shapeout.
If shapeout is smaller than ain.shape then Fourier downsampling is
performed
Parameters
----------
ain : (...,Ny,Nx,2) torch.tensor
Input array
shapeout : (2,) array_like
Shape of output array
norm : str, optional {'conserve_val','conserve_norm','conserve_L1'}
Normalization of output. If 'conserve_val' then array values are preserved
if 'conserve_norm' L2 norm is conserved under interpolation and if
'conserve_L1' L1 norm is conserved under interpolation
qspace_in : bool, optional
If True expect a Fourier space input, otherwise (default) expect a
real space input
qspace_out : bool, optional
If True return a Fourier space output, otherwise (default) return in
real space
"""
dtype = ain.dtype
inputComplex = iscomplex(ain)
# Make input complex
aout = torch.zeros(
ain.shape[: -2 - int(inputComplex)] + (np.prod(shapeout), 2),
dtype=dtype,
device=ain.device,
)
# Get input dimensions
npiyin, npixin = ain.size()[-2 - int(inputComplex) :][:2]
npiyout, npixout = shapeout
# Get Fourier interpolation masks
# PyTorch does not yet do element-wise logic operations, so we have to do
# this bit in numpy. Additionally, in Windows pytorch does not support
# bool types so we have to convert this to a unsigned 8-bit integer.
from .numpy_utils import Fourier_interpolation_masks
maskin, maskout = [
torch.from_numpy(x).flatten()
for x in Fourier_interpolation_masks(npiyin, npixin, npiyout, npixout)
]
# Now transfer over Fourier coefficients from input to output array
if inputComplex:
ain_ = ain
else:
ain_ = to_complex(ain)
if not qspace_in:
ain_ = torch.fft(ain_, signal_ndim=2)
aout[..., maskout, :] = ain_.flatten(-3, -2)[..., maskin, :]
# Fourier transform result with appropriate normalization
if norm == "conserve_val":
factor = npiyout * npixout / (npiyin * npixin)
elif norm == "conserve_norm":
factor = np.sqrt(npiyout * npixout / (npiyin * npixin))
else:
factor = 1
# Fourier transform result with appropriate normalization
aout = factor * aout.reshape(
ain.shape[: -2 - int(inputComplex)] + tuple(shapeout) + (2,)
)
if not qspace_out:
aout = torch.ifft(aout, signal_ndim=2)
# Return correct array data type
if inputComplex:
return aout
return aout[re]
def crop_torch(arrayin, shapeout):
"""
Crop the last two dimensions of arrayin to grid size shapeout.
For entries of shapeout which are larger than the shape of the input array,
perform zero-padding.
"""
C = iscomplex(arrayin)
# Number of dimensions in input array
ndim = arrayin.ndim
# Number of dimensions not covered by shapeout (ie not to be cropped)
nUntouched = ndim - 2 - C
# Shape of output array
shapeout_ = arrayin.shape[:nUntouched] + tuple(shapeout)
if C:
shapeout_ += (2,)
arrayout = torch.zeros(shapeout_, dtype=arrayin.dtype, device=arrayin.device)
y, x = arrayin.shape[-2 - C :][:2]
y_, x_ = shapeout[-2:]
def indices(y, y_):
if y > y_:
# Crop in y dimension
y1, y2 = [(y - y_) // 2, (y + y_) // 2]
y1_, y2_ = [0, y_]
else:
# Zero pad in y dimension
y1, y2 = [0, y]
y1_, y2_ = [(y_ - y) // 2, (y + y_) // 2]
return y1, y2, y1_, y2_
y1, y2, y1_, y2_ = indices(y, y_)
x1, x2, x1_, x2_ = indices(x, x_)
if C:
arrayout[..., y1_:y2_, x1_:x2_, :] = arrayin[..., y1:y2, x1:x2, :]
else:
arrayout[..., y1_:y2_, x1_:x2_] = arrayin[..., y1:y2, x1:x2]
return arrayout
Functions
def amplitude(r)
-
Calculate the amplitude of a complex tensor.
If the tensor is not complex then calculate square.
Expand source code
def amplitude(r): """ Calculate the amplitude of a complex tensor. If the tensor is not complex then calculate square. """ if r.size(-1) == 2: return r[..., 0] * r[..., 0] + r[..., 1] * r[..., 1] else: return r * r
def check_complex(A)
-
Raise a RuntimeWarning if tensor A is not complex.
Expand source code
def check_complex(A): """Raise a RuntimeWarning if tensor A is not complex.""" for a in A: if not iscomplex(a): raise RuntimeWarning( "taking complex_mul of non-complex tensor! a.shape " + str(a.shape) )
def complex_matmul(a: torch.Tensor, b: torch.Tensor, conjugate=False) ‑> torch.Tensor
-
Complex matrix multiplication of tensors a and b.
Pass conjugate = True to conjugate tensor b in the multiplication.
Expand source code
def complex_matmul(a: torch.Tensor, b: torch.Tensor, conjugate=False) -> torch.Tensor: """ Complex matrix multiplication of tensors a and b. Pass conjugate = True to conjugate tensor b in the multiplication. """ check_complex([a, b]) are = a[re] aim = a[im] bre = b[re] bim = b[im] if conjugate: real = are @ bre + aim @ bim imag = -are @ bim + aim @ bre else: real = are @ bre - aim @ bim imag = are @ bim + aim @ bre return torch.stack([real, imag], -1)
def complex_mul(a: torch.Tensor, b: torch.Tensor, conjugate=False) ‑> torch.Tensor
-
Complex array multiplication of tensors a and b.
Pass conjugate = True to conjugate tensor b in the multiplication.
Expand source code
def complex_mul(a: torch.Tensor, b: torch.Tensor, conjugate=False) -> torch.Tensor: """ Complex array multiplication of tensors a and b. Pass conjugate = True to conjugate tensor b in the multiplication. """ check_complex([a, b]) are = a[re] aim = a[im] bre = b[re] bim = b[im] if conjugate: real = are * bre + aim * bim imag = -are * bim + aim * bre else: real = are * bre - aim * bim imag = are * bim + aim * bre return torch.stack([real, imag], -1)
def crop_to_bandwidth_limit_torch(array: torch.Tensor, limit=0.6666666666666666, qspace_in=True, qspace_out=True, norm='conserve_L2')
-
Crop an array to its bandwidth limit (remove superfluous array entries).
Expand source code
def crop_to_bandwidth_limit_torch( array: torch.Tensor, limit=2 / 3, qspace_in=True, qspace_out=True, norm="conserve_L2", ): """Crop an array to its bandwidth limit (remove superfluous array entries).""" # Check if array is complex or not complx = iscomplex(array) # Get array shape, taking into account final dimension of size 2 if the array # is complex gridshape = array.shape[-2 - int(complx) :][:2] # New shape of final dimensions newshape = tuple([int(round(gridshape[i] * limit)) for i in range(2)]) return fourier_interpolate_2d_torch( array, newshape, norm=norm, qspace_in=qspace_in, qspace_out=qspace_out )
def crop_torch(arrayin, shapeout)
-
Crop the last two dimensions of arrayin to grid size shapeout.
For entries of shapeout which are larger than the shape of the input array, perform zero-padding.
Expand source code
def crop_torch(arrayin, shapeout): """ Crop the last two dimensions of arrayin to grid size shapeout. For entries of shapeout which are larger than the shape of the input array, perform zero-padding. """ C = iscomplex(arrayin) # Number of dimensions in input array ndim = arrayin.ndim # Number of dimensions not covered by shapeout (ie not to be cropped) nUntouched = ndim - 2 - C # Shape of output array shapeout_ = arrayin.shape[:nUntouched] + tuple(shapeout) if C: shapeout_ += (2,) arrayout = torch.zeros(shapeout_, dtype=arrayin.dtype, device=arrayin.device) y, x = arrayin.shape[-2 - C :][:2] y_, x_ = shapeout[-2:] def indices(y, y_): if y > y_: # Crop in y dimension y1, y2 = [(y - y_) // 2, (y + y_) // 2] y1_, y2_ = [0, y_] else: # Zero pad in y dimension y1, y2 = [0, y] y1_, y2_ = [(y_ - y) // 2, (y + y_) // 2] return y1, y2, y1_, y2_ y1, y2, y1_, y2_ = indices(y, y_) x1, x2, x1_, x2_ = indices(x, x_) if C: arrayout[..., y1_:y2_, x1_:x2_, :] = arrayin[..., y1:y2, x1:x2, :] else: arrayout[..., y1_:y2_, x1_:x2_] = arrayin[..., y1:y2, x1:x2] return arrayout
def crop_window_to_flattened_indices_torch(indices: torch.Tensor, shape: list)
-
Create (flattened) indices for a rectangular subset of a larger array.
Useful, for example for scattering matrix calculations where only a rectangular subset of the array is used in the PRISM interpolation routine
Array indices exceeding the bounds of the array are wrapped to be consistent with periodic boundary conditions.
Parameters
indices
:torch.Tensor
- The centers of each of the cropping windows
shape
:array_like
- Size of the cropping windows
Examples
>>> indices = torch.as_tensor([[2,3,4],[1,2,3]]) >>> gridshape = [4,4] >>> win = [3,3] >>> grid = torch.zeros(gridshape,dtype=torch.Long) tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]) >>> grid = grid.flatten() >>> ind = pyms.utils.crop_window_to_flattened_indices_torch(indices,gridshape) >>> grid[ind] = 1 >>> grid.view(gridshape) tensor([[0, 1, 1, 1], [0, 0, 0, 0], [0, 1, 1, 1], [0, 1, 1, 1]])
Expand source code
def crop_window_to_flattened_indices_torch(indices: torch.Tensor, shape: list): """ Create (flattened) indices for a rectangular subset of a larger array. Useful, for example for scattering matrix calculations where only a rectangular subset of the array is used in the PRISM interpolation routine Array indices exceeding the bounds of the array are wrapped to be consistent with periodic boundary conditions. Parameters ---------- indices : torch.Tensor The centers of each of the cropping windows shape : array_like Size of the cropping windows Examples -------- >>> indices = torch.as_tensor([[2,3,4],[1,2,3]]) >>> gridshape = [4,4] >>> win = [3,3] >>> grid = torch.zeros(gridshape,dtype=torch.Long) tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]) >>> grid = grid.flatten() >>> ind = pyms.utils.crop_window_to_flattened_indices_torch(indices,gridshape) >>> grid[ind] = 1 >>> grid.view(gridshape) tensor([[0, 1, 1, 1], [0, 0, 0, 0], [0, 1, 1, 1], [0, 1, 1, 1]]) """ xind = torch.as_tensor(indices[-1]).view(1, len(indices[-1])) % shape[-1] yind = torch.as_tensor(indices[-2]).view(len(indices[-2]), 1) % shape[-2] return (xind + yind * shape[-1]).flatten().type(torch.LongTensor)
def crop_window_to_periodic_indices(win, shape)
-
Create indices for a rectangular subset of a larger array.
If indices exceed the size of the larger array then these indices will wrap around to the other side of the grid providing two or more rectangular subsets of the larger array. Designed to be used in conjunction with the torch.narrow function to choose subsets of the square array to evaluate the PRISM algorithm on.
Assumes that the requested window is smaller than the array size
Parameters
win
:(4,) array_like
- contains (y0,y,x0,x) the lower y index and y length and lower x index and x length
shape
:(2,) array_like
- Shape of the larger array
Examples
crop_window_to_periodic_indices([2,2,1,3],[5,5]) (([2,2],[1,3]),) crop_window_to_periodic_indices([-1,3,1,3],[5,5]) (([4,1],[1,3]),([0,2],[1,3])) crop_window_to_periodic_indices([4,4,1,3],[5,5]) (([4,1],[1,3]),([0,3],[1,3])) list(crop_window_to_periodic_indices([4,4,3,3],[5,5])) (([4,1],[3,2]),([0,3],[3,2]),([4,1],[0,1]),([0,3],[0,1]))
Expand source code
def crop_window_to_periodic_indices(win, shape): """ Create indices for a rectangular subset of a larger array. If indices exceed the size of the larger array then these indices will wrap around to the other side of the grid providing two or more rectangular subsets of the larger array. Designed to be used in conjunction with the torch.narrow function to choose subsets of the square array to evaluate the PRISM algorithm on. Assumes that the requested window is smaller than the array size Parameters ---------- win : (4,) array_like contains (y0,y,x0,x) the lower y index and y length and lower x index and x length shape : (2,) array_like Shape of the larger array Examples -------- >>>> crop_window_to_periodic_indices([2,2,1,3],[5,5]) (([2,2],[1,3]),) >>>> crop_window_to_periodic_indices([-1,3,1,3],[5,5]) (([4,1],[1,3]),([0,2],[1,3])) >>>> crop_window_to_periodic_indices([4,4,1,3],[5,5]) (([4,1],[1,3]),([0,3],[1,3])) >>>> list(crop_window_to_periodic_indices([4,4,3,3],[5,5])) (([4,1],[3,2]),([0,3],[3,2]),([4,1],[0,1]),([0,3],[0,1])) """ def oneDindices(start, step, bound): if start + step > bound - 1: return [start, bound - start], [0, start + step - bound] elif start < 0: return [start % bound, bound - start % bound], [0, (start + step) % bound] else: return [[start, step]] y = oneDindices(*win[:2], shape[0]) x = oneDindices(*win[2:], shape[1]) return tuple(product(y, x))
def cx_from_numpy(x:
, dtype=torch.float32, device=device(type='cuda')) ‑> torch.Tensor -
Turn a complex numpy array into the required pytorch array format.
Parameters
x
:complex np.ndarray
- A complex numpy array
Keyword Arguments
dtype : torch.dtype The datatype of the output array device : torch.device The device (CPU or GPU) of the output array
Expand source code
def cx_from_numpy( x: np.array, dtype=torch.float32, device=get_device() ) -> torch.Tensor: """ Turn a complex numpy array into the required pytorch array format. Parameters ---------- x : complex np.ndarray A complex numpy array Keyword arguments ----------------- dtype : torch.dtype The datatype of the output array device : torch.device The device (CPU or GPU) of the output array """ if "complex" in str(x.dtype): out = torch.zeros(*x.shape, 2) out[re] = torch.from_numpy(x.real) out[im] = torch.from_numpy(x.imag) else: if x.shape[-1] != 2: out = torch.zeros(x.shape + (2,)) out[re] = torch.from_numpy(x.real) else: out = torch.zeros(x.shape + (2,)) out[re] = torch.from_numpy(x[re]) out[im] = torch.from_numpy(x[im]) return out.to(device).type(dtype)
def cx_to_numpy(x: torch.Tensor) ‑> numpy.ndarray
-
Convert a complex pytorch tensor to a complex numpy array.
Expand source code
def cx_to_numpy(x: torch.Tensor) -> np.ndarray: """Convert a complex pytorch tensor to a complex numpy array.""" check_complex(x) return x[re].cpu().numpy() + 1j * x[im].cpu().numpy()
def detect(detector, diffraction_pattern)
-
Apply a detector to a diffraction pattern.
Calculates the signal in a diffraction pattern detector even if the size of the diffraction pattern and the detector are mismatched, assumes that the zeroth coordinate in reciprocal space is in the top-left hand corner of the array.
Expand source code
def detect(detector, diffraction_pattern): """ Apply a detector to a diffraction pattern. Calculates the signal in a diffraction pattern detector even if the size of the diffraction pattern and the detector are mismatched, assumes that the zeroth coordinate in reciprocal space is in the top-left hand corner of the array. """ minsize = min(detector.size()[-2:], diffraction_pattern.size()[-2:]) wind = [fftfreq(minsize[i], torch.long, detector.device) for i in [0, 1]] Dwind = crop_window_to_flattened_indices_torch(wind, detector.size()) DPwind = crop_window_to_flattened_indices_torch(wind, diffraction_pattern.size()) return torch.sum( detector.flatten(-2, -1)[:, None, Dwind] * diffraction_pattern.flatten(-2, -1)[None, :, DPwind], dim=-1, )
def ensure_torch_array(array, dtype=torch.float32, device=None)
-
Ensure that the input array is a pytorch tensor.
Converts to a pytorch array if input is a numpy array and do nothing if the input is a pytorch tensor
Expand source code
def ensure_torch_array(array, dtype=torch.float, device=None): """ Ensure that the input array is a pytorch tensor. Converts to a pytorch array if input is a numpy array and do nothing if the input is a pytorch tensor """ from .. import ( layered_structure_propagators, layered_structure_transmission_function, ) if device is None: device = get_device(device) if isinstance(array, torch.Tensor): return array.to(device) elif isinstance(array, layered_structure_transmission_function): for i in range(len(array.Ts)): array.Ts[i] = array.Ts[i].to(device) return array elif isinstance(array, layered_structure_propagators): for i in range(len(array.Ps)): array.Ps[i] = array.Ps[i].to(device) return array else: if np.iscomplexobj(np.asarray(array)): return cx_from_numpy(np.asarray(array), dtype=dtype, device=device) else: return torch.from_numpy(np.asarray(array)).type(dtype).to(device)
def fftfreq(n, dtype=torch.float32, device=device(type='cpu'))
-
Generate an array of Fourier coordinates in units of pixels.
Same as numpy.fft.fftfreq(n)*n but for a torch array.
Expand source code
def fftfreq(n, dtype=torch.float, device=torch.device("cpu")): """ Generate an array of Fourier coordinates in units of pixels. Same as numpy.fft.fftfreq(n)*n but for a torch array. """ return (torch.arange(n, dtype=dtype, device=device) + n // 2) % n - n // 2
def fourier_interpolate_2d_torch(ain, shapeout, norm='conserve_val', qspace_in=False, qspace_out=False)
-
Fourier interpolation of array ain to shape shapeout.
If shapeout is smaller than ain.shape then Fourier downsampling is performed
Parameters
ain
:(…,Ny,Nx,2) torch.tensor
- Input array
shapeout
:(2,) array_like
- Shape of output array
norm
:str
, optional{'conserve_val','conserve_norm','conserve_L1'}
- Normalization of output. If 'conserve_val' then array values are preserved if 'conserve_norm' L2 norm is conserved under interpolation and if 'conserve_L1' L1 norm is conserved under interpolation
qspace_in
:bool
, optional- If True expect a Fourier space input, otherwise (default) expect a real space input
qspace_out
:bool
, optional- If True return a Fourier space output, otherwise (default) return in real space
Expand source code
def fourier_interpolate_2d_torch( ain, shapeout, norm="conserve_val", qspace_in=False, qspace_out=False ): """ Fourier interpolation of array ain to shape shapeout. If shapeout is smaller than ain.shape then Fourier downsampling is performed Parameters ---------- ain : (...,Ny,Nx,2) torch.tensor Input array shapeout : (2,) array_like Shape of output array norm : str, optional {'conserve_val','conserve_norm','conserve_L1'} Normalization of output. If 'conserve_val' then array values are preserved if 'conserve_norm' L2 norm is conserved under interpolation and if 'conserve_L1' L1 norm is conserved under interpolation qspace_in : bool, optional If True expect a Fourier space input, otherwise (default) expect a real space input qspace_out : bool, optional If True return a Fourier space output, otherwise (default) return in real space """ dtype = ain.dtype inputComplex = iscomplex(ain) # Make input complex aout = torch.zeros( ain.shape[: -2 - int(inputComplex)] + (np.prod(shapeout), 2), dtype=dtype, device=ain.device, ) # Get input dimensions npiyin, npixin = ain.size()[-2 - int(inputComplex) :][:2] npiyout, npixout = shapeout # Get Fourier interpolation masks # PyTorch does not yet do element-wise logic operations, so we have to do # this bit in numpy. Additionally, in Windows pytorch does not support # bool types so we have to convert this to a unsigned 8-bit integer. from .numpy_utils import Fourier_interpolation_masks maskin, maskout = [ torch.from_numpy(x).flatten() for x in Fourier_interpolation_masks(npiyin, npixin, npiyout, npixout) ] # Now transfer over Fourier coefficients from input to output array if inputComplex: ain_ = ain else: ain_ = to_complex(ain) if not qspace_in: ain_ = torch.fft(ain_, signal_ndim=2) aout[..., maskout, :] = ain_.flatten(-3, -2)[..., maskin, :] # Fourier transform result with appropriate normalization if norm == "conserve_val": factor = npiyout * npixout / (npiyin * npixin) elif norm == "conserve_norm": factor = np.sqrt(npiyout * npixout / (npiyin * npixin)) else: factor = 1 # Fourier transform result with appropriate normalization aout = factor * aout.reshape( ain.shape[: -2 - int(inputComplex)] + tuple(shapeout) + (2,) ) if not qspace_out: aout = torch.ifft(aout, signal_ndim=2) # Return correct array data type if inputComplex: return aout return aout[re]
def fourier_shift_array(size, posn, dtype=torch.float32, device=device(type='cpu'), units='pixels')
-
Create Fourier shift theorem array to (pixel) position given by list posn.
Parameters
size
:array_like
- size of the array (Y,X)
posn
:array_like
- can be a K x 2 array to give a K x Y x X shift arrays
posn
Expand source code
def fourier_shift_array( size, posn, dtype=torch.float, device=torch.device("cpu"), units="pixels" ): """ Create Fourier shift theorem array to (pixel) position given by list posn. Parameters ---------- size : array_like size of the array (Y,X) posn : array_like can be a K x 2 array to give a K x Y x X shift arrays posn """ # Get number of dimensions nn = len(posn.shape) # Get size of array y, x = size if nn == 1: # Make y ramp exp(-2pi i ky y) yramp = fourier_shift_array_1d( y, posn[0], units=units, dtype=dtype, device=device ) # Make y ramp exp(-2pi i kx x) xramp = fourier_shift_array_1d( x, posn[1], units=units, dtype=dtype, device=device ) # Multiply both arrays together, view statements for # appropriate broadcasting to 2D return complex_mul(yramp.view(y, 1, 2), xramp.view(1, x, 2)) else: K = posn.shape[0] # Make y ramp exp(-2pi i ky y) yramp = torch.empty(K, y, 2, dtype=dtype, device=device) ky = ( 2 * np.pi * fftfreq(y, dtype=dtype, device=device).view(1, y) * posn[:, 0].view(K, 1) ) if units == "pixels": ky /= y yramp[..., 0] = torch.cos(ky) yramp[..., 1] = -torch.sin(ky) # Make y ramp exp(-2pi i kx x) xramp = torch.empty(K, x, 2, dtype=dtype, device=device) kx = ( 2 * np.pi * fftfreq(x, dtype=dtype, device=device).view(1, x) * posn[:, 1].view(K, 1) ) if units == "pixels": kx /= x xramp[..., 0] = torch.cos(kx) xramp[..., 1] = -torch.sin(kx) # Multiply both arrays together, view statements for # appropriate broadcasting to 2D return complex_mul(yramp.view(K, y, 1, 2), xramp.view(K, 1, x, 2))
def fourier_shift_array_1d(y, posn, dtype=torch.float32, device=device(type='cpu'), units='pixels')
-
Apply Fourier shift theorem for sub-pixel shift to a 1 dimensional array.
Expand source code
def fourier_shift_array_1d( y, posn, dtype=torch.float, device=torch.device("cpu"), units="pixels" ): """Apply Fourier shift theorem for sub-pixel shift to a 1 dimensional array.""" ramp = torch.empty(y, 2, dtype=dtype, device=device) ky = 2 * np.pi * fftfreq(y) * posn if units == "pixels": ky /= y ramp[..., 0] = torch.cos(ky) ramp[..., 1] = -torch.sin(ky) return ramp
def fourier_shift_torch(array, posn, dtype=torch.float32, device=device(type='cpu'), qspace_in=False, qspace_out=False, units='pixels')
-
Apply Fourier shift theorem for sub-pixel shifts to array.
Parameters
array
:torch.tensor (…,Y,X,2)
- Complex array to be Fourier shifted
posn
:torch.tensor (K x 2)
or(2,)
- Shift(s) to be applied
Expand source code
def fourier_shift_torch( array, posn, dtype=torch.float32, device=torch.device("cpu"), qspace_in=False, qspace_out=False, units="pixels", ): """ Apply Fourier shift theorem for sub-pixel shifts to array. Parameters ----------- array : torch.tensor (...,Y,X,2) Complex array to be Fourier shifted posn : torch.tensor (K x 2) or (2,) Shift(s) to be applied """ if not qspace_in: array = torch.fft(array, signal_ndim=2) array = complex_mul( array, fourier_shift_array( array.size()[-3:-1], posn, dtype=array.dtype, device=array.device, units=units, ), ) if qspace_out: return array return torch.ifft(array, signal_ndim=2)
def get_device(device_type=None)
-
Initialize device cuda if available, CPU if no cuda is available.
Expand source code
def get_device(device_type=None): """Initialize device cuda if available, CPU if no cuda is available.""" if device_type is None and torch.cuda.is_available(): device = torch.device("cuda") elif device_type is None: device = torch.device("cpu") else: device = torch.device(device_type) return device
def iscomplex(a: torch.Tensor)
-
Return True if a is complex, False otherwise.
Expand source code
def iscomplex(a: torch.Tensor): """Return True if a is complex, False otherwise.""" return a.shape[-1] == 2
def sinc(x)
-
Calculate the sinc function ie. sin(pi x)/(pi x).
Expand source code
def sinc(x): """Calculate the sinc function ie. sin(pi x)/(pi x).""" y = torch.where(torch.abs(x) < 1.0e-20, torch.tensor([1.0e-20], dtype=x.dtype), x) return torch.sin(np.pi * y) / np.pi / y
def size_of_bandwidth_limited_array(shape)
-
Get the size of an array after band-width limiting.
Expand source code
def size_of_bandwidth_limited_array(shape): """Get the size of an array after band-width limiting.""" return list(crop_to_bandwidth_limit_torch(torch.zeros(*shape)).size())
def to_complex(real, imag=None)
-
Convert real and imaginary tensors to a complex tensor.
Expand source code
def to_complex(real, imag=None): """Convert real and imaginary tensors to a complex tensor.""" if imag is None: return torch.stack( [real, torch.zeros(real.size(), dtype=real.dtype, device=real.device)], -1 ) else: return torch.stack([real, imag], -1)
def torch_c_exp(angle)
-
Calculate exp(1j*angle).
Expand source code
def torch_c_exp(angle): """Calculate exp(1j*angle).""" if angle.size()[-1] != 2: # Case of a real exponent result = torch.zeros(*angle.shape, 2, dtype=angle.dtype, device=angle.device) result[re] = torch.cos(angle) result[im] = torch.sin(angle) else: # Case of a complex valued exponent exp = torch.exp(-angle[im]) result = torch.zeros(*angle.shape, dtype=angle.dtype, device=angle.device) result[re] = exp * torch.cos(angle[re]) result[im] = exp * torch.sin(angle[re]) return result
def torch_dtype_to_numpy(dtype)
-
Convert a torch datatype to a numpy datatype.
Expand source code
def torch_dtype_to_numpy(dtype): """Convert a torch datatype to a numpy datatype.""" scratch_array = torch.zeros(1, dtype=dtype) return scratch_array.cpu().numpy().dtype