Utility functions
gprutils.py
Utility functions for the analysis of sparse image and hyperspectral data with Gaussian processes.
Author: Maxim Ziatdinov (email: maxim.ziatdinov@ai4microcopy.com)
-
prepare_training_data
(X, y=None, vector_valued=False, **kwargs)
Reshapes and converts data to torch tensors for GP analysis
- Parameters
X (ndarray) – Grid indices with dimensions \(c \times N \times M \times L\), where c is equal to the number of coordinates (for example, for xyz coordinates, c = 3)
y (ndarray) – Observations (data points) with dimensions \(N \times M \times L\)
**precision (str) – Choose between single (‘single’) and double (‘double’) precision
- Returns
PyTorch tensors with dimensions \(N \times M \times L \times c\) and \(N \times M \times L\)
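A minimal NumPy sketch of the reshaping described above, not the gprutils source. It assumes that the reshape moves the coordinate axis from the first to the last position (giving \(N \times M \times L \times c\)) and that `precision` only selects float32 or float64; the function name and the `"double"` default are hypothetical. The documented function returns torch tensors, which `torch.from_numpy` could produce from these arrays.

```python
import numpy as np

def prepare_training_data_sketch(X, y=None, precision="double"):
    """Hypothetical illustration, not the gprutils implementation.

    Moves the coordinate axis of X (c x N x M x L) to the end, giving
    N x M x L x c, and casts X and y to the requested floating-point type.
    """
    dtype = np.float32 if precision == "single" else np.float64
    X_out = np.moveaxis(X, 0, -1).astype(dtype)
    y_out = None if y is None else np.asarray(y, dtype=dtype)
    # torch.from_numpy(X_out) would wrap these arrays as tensors without copying.
    return X_out, y_out

# Usage: a 2 x 3 x 4 grid with xyz coordinates (c = 3)
X = np.stack(np.meshgrid(np.arange(2), np.arange(3), np.arange(4), indexing="ij"))
y = np.random.default_rng(0).normal(size=(2, 3, 4))
X_t, y_t = prepare_training_data_sketch(X, y, precision="single")
print(X_t.shape, X_t.dtype)  # (2, 3, 4, 3) float32
```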
-
prepare_test_data
(X, **kwargs)
Reshapes and converts data to torch tensors for GP analysis
- Parameters
X (ndarray) – Grid indices with dimensions \(c \times N \times M \times L\), where c is equal to the number of coordinates (for example, for xyz coordinates, c = 3)
**precision (str) – Choose between single (‘single’) and double (‘double’) precision
- Returns
PyTorch tensor with dimensions \(N \times M \times L \times c\)
-
get_grid_indices
(R, dense_x=1.0)
Returns full and sparse grid indices for 2D and 3D arrays
- Parameters
R (ndarray) – Sparse grid measurements as 2D or 3D numpy array
dense_x (float) – Determines grid density (can be increased at prediction stage)
-
get_full_grid
(R, extent=None, dense_x=1.0)
Creates grid indices for 2D-4D numpy arrays
- Parameters
R (ndarray) – Grid measurements as 2D-4D numpy array
extent (list of lists) – Define multi-dimensional data bounds. For example, for 2D data, the extent parameter is [[xmin, xmax], [ymin, ymax]]
dense_x (float) – Determines grid density (can be increased at prediction stage)
- Returns
Grid indices as numpy array
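The sketch below shows one way such a grid could be built with NumPy. It assumes that the default extent of each axis is `[0, size - 1]` and that `dense_x` scales the number of points per unit length, so `dense_x=2.0` roughly doubles the sampling along every axis. Both are assumptions, and `get_full_grid_sketch` is a hypothetical name.

```python
import numpy as np

def get_full_grid_sketch(R, extent=None, dense_x=1.0):
    """Hypothetical illustration, not the gprutils implementation."""
    if extent is None:
        # Assumed default: integer pixel positions along every axis of R
        extent = [[0, n - 1] for n in R.shape]
    axes = []
    for lo, hi in extent:
        npts = int(round((hi - lo) * dense_x)) + 1
        axes.append(np.linspace(lo, hi, npts))
    # Stack coordinate arrays into shape c x N' x M' (x L' ...)
    return np.stack(np.meshgrid(*axes, indexing="ij"))

R = np.zeros((4, 5))
print(get_full_grid_sketch(R).shape)               # (2, 4, 5)
print(get_full_grid_sketch(R, dense_x=2.0).shape)  # (2, 7, 9)
```

Densifying only at prediction time lets the GP be trained on the measured pixel grid and then evaluated on a finer one.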
-
get_sparse_grid
(R, extent=None)
Returns sparse grid for sparse image data
- Parameters
R (ndarray) – Sparse grid measurements (missing values are NaNs)
- Returns
Sparse grid indices
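A sketch of the idea, assuming the sparse grid is the full index grid with the coordinates of every missing (NaN) observation also set to NaN, so that only measured points keep valid coordinates. The `extent` argument is ignored here, and `get_sparse_grid_sketch` is a hypothetical name.

```python
import numpy as np

def get_sparse_grid_sketch(R):
    """Hypothetical illustration, not the gprutils implementation."""
    grid = np.indices(R.shape).astype(float)  # shape: R.ndim x R.shape
    grid[:, np.isnan(R)] = np.nan             # drop coordinates of missing points
    return grid

R = np.arange(9, dtype=float).reshape(3, 3)
R[1, 2] = np.nan
X_sparse = get_sparse_grid_sketch(R)
print(X_sparse[:, 1, 2])  # [nan nan]
print(X_sparse[:, 0, 1])  # [0. 1.]
```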
-
to_constrained_interval
(state_dict, lscale, amp)
Transforms the kernel’s unconstrained lengthscale and variance to their constrained domains (intervals)
- Parameters
state_dict (dict) – Kernel’s state dictionary; can be obtained from self.spgr.kernel.state_dict
lscale (list) – List of two lists with lower and upper bound(s) for the lengthscale prior. The number of elements in each list is usually equal to the number of (independent) input dimensions
amp (list) – List with two floats corresponding to lower and upper bounds for variance (square of amplitude) prior
- Returns
Lengthscale and variance in the constrained domain (interval)
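The transformation can be sketched as a sigmoid followed by an affine rescaling into `(lower, upper)`, which is how `torch.distributions.transform_to(constraints.interval(lower, upper))` maps unconstrained values. Assuming the kernel parameters were constrained this way, the sketch below reproduces the mapping in NumPy; `to_interval` and the raw values are hypothetical stand-ins for entries read from the kernel's state dictionary.

```python
import numpy as np

def to_interval(raw, lower, upper):
    """Map unconstrained value(s) into (lower, upper) with a rescaled sigmoid."""
    raw, lower, upper = (np.asarray(v, dtype=float) for v in (raw, lower, upper))
    return lower + (upper - lower) / (1.0 + np.exp(-raw))

# Hypothetical unconstrained values, e.g. read from a kernel's state dict
raw_lscale = np.array([0.0, 0.0])
raw_variance = 0.0
lscale = [[1.0, 2.0], [3.0, 6.0]]  # [lower bounds], [upper bounds]
amp = [0.5, 1.5]                   # lower bound, upper bound

print(to_interval(raw_lscale, lscale[0], lscale[1]))  # [2. 4.]
print(to_interval(raw_variance, amp[0], amp[1]))      # 1.0
```

An unconstrained value of 0 lands at the midpoint of the interval, and large positive or negative values approach the upper or lower bound.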
-
corrupt_data_xy
(X_true, R_true, prob=0.5, replace_w_zeros=False)
Replaces a certain percentage of 2D or 3D image data with NaNs; see gprutils.corrupt_image2d and gprutils.corrupt_image3d
- Parameters
X_true (ndarray) – Grid indices for 2D image or 3D hyperspectral data (3D and 4D numpy arrays, respectively)
R_true (ndarray) – Observations as 2D image or 3D hyperspectral data
prob (float) – Fraction of data points to corrupt (a value between 0 and 1)
replace_w_zeros (bool) – Corrupts data with zeros instead of NaNs
- Returns
ndarrays of grid indices (3D or 4D) and observations (2D or 3D)
-
corrupt_image2d
(X_true, R_true, prob, replace_w_zeros)
Replaces a certain percentage of 2D image data with NaNs.
- Parameters
X_true (ndarray) – 3D array with grid indices for 2D image
R_true (ndarray) – 2D image with observations
prob (float) – Fraction of data points to corrupt (a value between 0 and 1)
replace_w_zeros (bool) – Corrupts data with zeros instead of NaNs
- Returns
3D ndarray of grid coordinates and 2D ndarray of observations in which a fraction of the points is replaced with NaNs.
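A minimal sketch of 2D corruption, not the gprutils source. It assumes that exactly `int(prob * R_true.size)` pixels are chosen without replacement and that the matching grid coordinates in `X_true` are blanked as well; the `seed` argument and the function name are hypothetical additions for reproducibility.

```python
import numpy as np

def corrupt_image2d_sketch(X_true, R_true, prob, replace_w_zeros, seed=None):
    """Hypothetical illustration, not the gprutils implementation."""
    rng = np.random.default_rng(seed)
    X, R = X_true.astype(float), R_true.astype(float)  # astype returns copies
    n_bad = int(prob * R.size)
    flat_idx = rng.choice(R.size, size=n_bad, replace=False)
    mask = np.zeros(R.size, dtype=bool)
    mask[flat_idx] = True
    mask = mask.reshape(R.shape)
    fill = 0.0 if replace_w_zeros else np.nan
    R[mask] = fill
    X[:, mask] = fill  # corrupt the matching grid coordinates too (assumption)
    return X, R

X_true = np.indices((10, 10))
R_true = np.random.default_rng(1).normal(size=(10, 10))
X, R = corrupt_image2d_sketch(X_true, R_true, prob=0.25, replace_w_zeros=False, seed=0)
print(np.isnan(R).sum())  # 25
```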
-
corrupt_image3d
(X_true, R_true, prob, replace_w_zeros)
Replaces a certain percentage of 3D hyperspectral data with NaNs. The corruption is applied differently in the xy and z dimensions: for every corrupted (x, y) point, all z values associated with that point are removed.
- Parameters
X_true (ndarray) – 4D array with grid indices for 3D hyperspectral data
R_true (ndarray) – 3D hyperspectral data with observations
prob (float) – Fraction of data points to corrupt (a value between 0 and 1)
replace_w_zeros (bool) – Corrupts data with zeros instead of NaNs
- Returns
4D ndarray of grid coordinates and 3D ndarray of observations in which a certain percentage of points is replaced with NaNs (for every corrupted (x, y) point, all z values associated with that point are removed)
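The xy-only corruption can be sketched by sampling (x, y) positions and blanking the entire spectrum at each one. This is a hypothetical illustration under the same assumptions as a 2D version (exact count `int(prob * nx * ny)`, coordinates blanked too, hypothetical `seed` argument), not the gprutils source.

```python
import numpy as np

def corrupt_image3d_sketch(X_true, R_true, prob, replace_w_zeros, seed=None):
    """Hypothetical illustration: corrupt whole spectra at random (x, y) points."""
    rng = np.random.default_rng(seed)
    X, R = X_true.astype(float), R_true.astype(float)  # copies of the inputs
    nx, ny = R.shape[:2]
    n_bad = int(prob * nx * ny)
    flat_idx = rng.choice(nx * ny, size=n_bad, replace=False)
    ix, iy = np.unravel_index(flat_idx, (nx, ny))
    fill = 0.0 if replace_w_zeros else np.nan
    R[ix, iy, :] = fill     # every z value at a corrupted (x, y) point
    X[:, ix, iy, :] = fill  # and the matching grid coordinates
    return X, R

X_true = np.indices((8, 8, 16))
R_true = np.random.default_rng(2).normal(size=(8, 8, 16))
X_c, R_c = corrupt_image3d_sketch(X_true, R_true, prob=0.25, replace_w_zeros=False, seed=0)
print(np.isnan(R_c).all(axis=-1).sum())  # 16
```

Removing whole spectra rather than individual voxels mimics an experiment in which a spectroscopic curve is either measured at a location or not measured at all.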
-
open_edge_points
(R, R_true, s=6)
Opens measured curves at the edges of the field of view (FOV)
- Parameters
R (ndarray) – empty/sparse data
R_true (ndarray) – “ground truth”
s (int) – step value, which determines the density of opened edge points
- Returns
3D ndarray with opened edge points
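One plausible reading of "opening" edge points is copying ground-truth curves into the sparse array at every `s`-th position along the four borders of the FOV. The sketch below implements that reading; the sampling pattern and the function name are assumptions, not the gprutils source.

```python
import numpy as np

def open_edge_points_sketch(R, R_true, s=6):
    """Hypothetical illustration: reveal every s-th curve along the FOV borders."""
    R = R.copy()
    R[::s, 0] = R_true[::s, 0]    # first column (y = 0)
    R[::s, -1] = R_true[::s, -1]  # last column
    R[0, ::s] = R_true[0, ::s]    # first row (x = 0)
    R[-1, ::s] = R_true[-1, ::s]  # last row
    return R

R_true = np.random.default_rng(0).normal(size=(12, 12, 5))
R = np.full_like(R_true, np.nan)  # nothing measured yet
R_open = open_edge_points_sketch(R, R_true, s=6)
print((~np.isnan(R_open).any(axis=-1)).sum())  # 7
```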
-
plot_kernel_hyperparams
(hyperparams)
Plots evolution of kernel hyperparameters as a function of training steps
- Parameters
hyperparams (dict) – dictionary with kernel hyperparameters (see gpreg.gpr.reconstructor)
-
plot_mixture_hyperparams
(hyperparams)
Plots evolution of spectral mixture kernel hyperparameters as a function of training iterations
- Parameters
hyperparams (dict) – dictionary with kernel hyperparameters (see gpreg.skgpr.skreconstructor)
-
plot_raw_data
(raw_data, slice_number, pos, spec_window=2, norm=False, **kwargs)
Plots hyperspectral data as 2D image integrated over a certain range of energy/frequency and selected individual spectroscopic curves
- Parameters
raw_data (3D ndarray) – hyperspectral cube (the first two dimensions are xy coordinates and the last dimension is a “spectroscopic” dimension)
slice_number (int) – slice from datacube to visualize
pos (list of lists) – list with [x, y] coordinates of points where single spectroscopic curves will be extracted and visualized
spec_window (int) – window to integrate over in frequency dimension (for 2D “slices”)
**cmap (str) – cmap for 2D image (“slice”) plot
**z_vec (1D ndarray) – spectroscopic measurements values (e.g. frequency, bias)
**z_vec_label (str) – spectroscopic measurements label (e.g. frequency, bias voltage)
**z_vec_units (str) – spectroscopic measurements units (e.g. Hz, V)
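The two quantities this function displays can be computed as sketched below. It assumes the 2D "slice" is the sum over the half-open spectral window `[slice_number - spec_window, slice_number + spec_window)` and that each entry of `pos` indexes the first two axes directly; both are assumptions, and the helper names are hypothetical.

```python
import numpy as np

def integrated_slice(raw_data, slice_number, spec_window=2):
    """Sum the cube over spectral indices [slice_number - spec_window, slice_number + spec_window)."""
    lo = max(slice_number - spec_window, 0)  # avoid a negative start index
    hi = slice_number + spec_window
    return raw_data[:, :, lo:hi].sum(axis=-1)

def extract_curves(raw_data, pos):
    """Return the spectroscopic curve at each [x, y] position, one row per point."""
    return np.stack([raw_data[x, y, :] for x, y in pos])

cube = np.ones((16, 16, 64))
img = integrated_slice(cube, slice_number=20, spec_window=2)  # 2D image
curves = extract_curves(cube, pos=[[3, 5], [10, 12]])         # one curve per point
```

Integrating over a small window rather than showing a single slice reduces noise in the displayed image.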
-
plot_reconstructed_data2d
(R, mean, save_fig=False, **kwargs)
Plots original and GP-reconstructed data for 2D images
- Parameters
R (2D ndarray) – Input image for GP regression
mean (1D ndarray) – Predictive mean, usually an output of gpr.reconstructor or skgpr.skreconstructor. The array is flattened (the actual dimensions are the same as for R)
**cmap (str) – cmap for 2D image plot
**savedir (str) – directory to save output figure
**filepath (str) – name of input file (to create a unique filename for plot)
**sparsity (float) – indicates % of data points removed (used only for figure title)
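The predictive mean is documented as a flattened array with the same number of elements as R. The sketch below restores its shape and computes the fraction of missing pixels in R, one possible value for the `sparsity` keyword. Deriving `sparsity` from the NaNs in R is an assumption (the function only uses it for the figure title), and `restore_prediction` is a hypothetical helper.

```python
import numpy as np

def restore_prediction(R, mean):
    """Reshape a flattened predictive mean to the shape of R and report sparsity."""
    mean_img = np.asarray(mean).reshape(R.shape)
    sparsity = np.isnan(R).mean()  # fraction of missing pixels in the input
    return mean_img, sparsity

R = np.arange(16, dtype=float).reshape(4, 4)
R[::2, ::2] = np.nan    # remove 4 of 16 pixels
mean = np.zeros(R.size)  # stand-in for a flattened GP predictive mean
mean_img, sparsity = restore_prediction(R, mean)
print(mean_img.shape, sparsity)  # (4, 4) 0.25
```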
-
plot_reconstructed_data3d
(R, mean, sd, slice_number, pos, spec_window=2, save_fig=False, **kwargs)
Plots original and GP-reconstructed data for 3D images
- Parameters
R (3D ndarray) – Input image for GP regression
mean (1D ndarray) – Predictive mean, usually an output of gpr.reconstructor or skgpr.skreconstructor. The array is flattened (the actual dimensions are the same as for R)
sd (1D ndarray) – Standard deviation (can be flattened; actual dimensions are the same as in R)
slice_number (int) – slice from datacube to visualize
pos (list of lists) – list with [x, y] coordinates of points where single spectroscopic curves will be extracted and visualized
spec_window (int) – window to integrate over in frequency dimension (for 2D “slices”)
**cmap (str) – colormap for 2D image (“slices”) plots
**savedir (str) – directory to save output figure
**sparsity (float) – indicates % of data points removed (used only for figure title)
**filepath (str) – path/name of input file (to create a unique filename for plot)
**z_vec (1D ndarray) – spectroscopic measurements values (e.g. frequency, bias)
**z_vec_label (str) – spectroscopic measurements label (e.g. frequency, bias voltage)
**z_vec_units (str) – spectroscopic measurements units (e.g. Hz, V)
-
plot_exploration_results
(R_all, mean_all, sd_all, R_true, episodes, slice_number, pos, dist_edge, spec_window=2, mask_predictions=False, **kwargs)
Plots predictions at different stages (“episodes”) of maximum uncertainty-based sample exploration with GP
- Parameters
R_all (list with ndarrays) – Observed data points at each exploration step
mean_all (list of ndarrays) – Predictive mean at each exploration step
sd_all (list of ndarrays) – Integrated (along energy dimension) SD at each exploration step
R_true (ndarray) – 3D array with ground truth data (full observations) for simulated experiment OR a 3D array of zeros/NaNs for real experiment
episodes (list of ints) – list with the numbers indicating which iteration steps to visualize
slice_number (int) – slice from datacube to visualize
pos (list of lists) – list with [x, y] coordinates of points where single spectroscopic curves will be extracted and visualized
dist_edge (list with two integers) – this should be the same as in exploration analysis
spec_window (int) – window to integrate over in frequency dimension (for 2D “slices”)
mask_predictions (bool) – mask edge regions not used in max uncertainty evaluation in predictive mean plots
**sparsity (float) – indicates % of data points removed (used only for figure title)
**z_vec (1D ndarray) – spectroscopic measurements values (e.g. frequency, bias)
**z_vec_label (str) – spectroscopic measurements label (e.g. frequency, bias voltage)
**z_vec_units (str) – spectroscopic measurements units (e.g. Hz, V)
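With `mask_predictions=True`, edge regions excluded from the uncertainty search are hidden in the predictive-mean panels. A sketch of that masking and of selecting the requested `episodes` follows. It assumes `dist_edge = [dx, dy]` gives the number of excluded pixels from each x and y border and that each stored mean already has its image shape; `mask_edges` and the data are hypothetical.

```python
import numpy as np

def mask_edges(pred, dist_edge):
    """Replace predictions within dist_edge = [dx, dy] pixels of the xy borders with NaN."""
    dx, dy = dist_edge
    nx, ny = pred.shape[:2]
    masked = np.full(pred.shape, np.nan)
    masked[dx:nx - dx, dy:ny - dy] = pred[dx:nx - dx, dy:ny - dy]
    return masked

# Hypothetical predictive means from three exploration steps (2D for brevity)
mean_all = [np.full((6, 6), float(step)) for step in range(3)]
episodes = [0, 2]
shown = [mask_edges(mean_all[e], dist_edge=[1, 2]) for e in episodes]
```

Because trailing axes are carried along by the slicing, the same helper also works for 3D hyperspectral predictions.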
-
plot_inducing_points
(hyperparams, **kwargs)
Plots the evolution of inducing points during training
-
plot_inducing_points_2d
(hyperparams, **kwargs)
Plots 2D trajectories of inducing points
- Parameters
hyperparams (dict) – Dictionary of hyperparameters
**plot_from (int) – plot from specific step
**plot_to (int) – plot till specific step
**slice_step (int) – plot every nth inducing point
-
plot_inducing_points_3d
(hyperparams, **kwargs)
Plots 3D trajectories of inducing points during model training
- Parameters
hyperparams (dict) – dictionary of hyperparameters
**plot_from (int) – plot from specific step
**plot_to (int) – plot till specific step
**slice_step (int) – plot every nth inducing point
-
plot_query_points
(inds_all, **kwargs)
Plots the exploration path (all the query points) in GP-based Bayesian optimization. Currently supports only 2D data.
- Parameters
inds_all (list) – list of indices
**cmap (str) – colormap