# Utility functions

## gprutils.py

Utility functions for the analysis of sparse image and hyperspectral data with Gaussian processes.

Author: Maxim Ziatdinov (email: maxim.ziatdinov@ai4microcopy.com)

prepare_training_data(X, y=None, vector_valued=False, **kwargs)

Reshapes and converts data to torch tensors for GP analysis

Parameters
• X (ndarray) – Grid indices with dimensions $$c \times N \times M \times L$$, where c is equal to the number of coordinates (for example, for xyz coordinates, c = 3)

• y (ndarray) – Observations (data points) with dimensions $$N \times M \times L$$

• **precision (str) – Choose between single (‘single’) and double (‘double’) precision

Returns

PyTorch tensors with dimensions $$N \times M \times L \times c$$ and $$N \times M \times L$$
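As a rough illustration of the shape convention described above, the NumPy-only sketch below starts from coordinate-first grid indices and moves the coordinate axis to the end. It is a sketch under assumptions, not the library's implementation: the helper name `coords_last` is hypothetical, the conversion to torch tensors is omitted, and the real function may also flatten the grid or drop missing points.

```python
import numpy as np

def coords_last(X, y=None, precision="double"):
    """Hypothetical helper: move the coordinate axis (axis 0) to the end
    and cast to single or double precision."""
    dtype = np.float32 if precision == "single" else np.float64
    X_out = np.moveaxis(X, 0, -1).astype(dtype)
    y_out = None if y is None else y.astype(dtype)
    return X_out, y_out

# Grid indices for an N x M x L cube with c = 3 coordinates
N, M, L = 4, 5, 6
X = np.mgrid[0:N, 0:M, 0:L].astype(float)            # shape (3, N, M, L)
y = np.random.default_rng(0).normal(size=(N, M, L))  # shape (N, M, L)

X_out, y_out = coords_last(X, y, precision="single")
print(X.shape, "->", X_out.shape, "| y:", y_out.shape, y_out.dtype)
```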

prepare_test_data(X, **kwargs)

Reshapes and converts data to torch tensors for GP analysis

Parameters
• X (ndarray) – Grid indices with dimensions $$c \times N \times M \times L$$ where c is equal to the number of coordinates (for example, for xyz coordinates, c = 3)

• **precision (str) – Choose between single (‘single’) and double (‘double’) precision

Returns

PyTorch tensor with dimensions $$N \times M \times L \times c$$

get_grid_indices(R, dense_x=1.0)

Returns full and sparse grid indices for 2D and 3D arrays

Parameters
• R (ndarray) – Sparse grid measurements as 2D or 3D numpy array

• dense_x (float) – Determines grid density (can be increased at prediction stage)

get_full_grid(R, extent=None, dense_x=1.0)

Creates grid indices for 2D-4D numpy arrays

Parameters
• R (ndarray) – Grid measurements as 2D-4D numpy array

• extent (list of lists) – Define multi-dimensional data bounds. For example, for 2D data, the extent parameter is [[xmin, xmax], [ymin, ymax]]

• dense_x (float) – Determines grid density (can be increased at prediction stage)

Returns

Grid indices as numpy array
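The sketch below shows one way such coordinate-first grid indices could be built with NumPy. The function name `full_grid_sketch` is hypothetical, and the handling of `dense_x` (treated here as the grid step relative to the original pixel spacing, so values below 1 give a denser grid) and of `extent` are assumptions rather than the library's documented behavior.

```python
import numpy as np

def full_grid_sketch(R, extent=None, dense_x=1.0):
    """Hypothetical sketch: coordinate-first grid indices for an n-D array."""
    if extent is None:
        # Default bounds: plain array indices along each dimension
        extent = [[0, n] for n in R.shape]
    axes = []
    for (lo, hi), n in zip(extent, R.shape):
        # Assumption: dense_x = 0.5 doubles the number of grid points
        num = int(round(n / dense_x))
        axes.append(np.linspace(lo, hi, num, endpoint=False))
    # Stack per-dimension coordinate arrays into shape (c, N, M, ...)
    return np.stack(np.meshgrid(*axes, indexing="ij"))

R = np.zeros((4, 5))
X_full = full_grid_sketch(R)                  # shape (2, 4, 5)
X_dense = full_grid_sketch(R, dense_x=0.5)    # shape (2, 8, 10)
X_scaled = full_grid_sketch(R, extent=[[0, 1], [0, 2]])
print(X_full.shape, X_dense.shape, X_scaled.shape)
```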

get_sparse_grid(R, extent=None)

Returns sparse grid for sparse image data

Parameters

R (ndarray) – Sparse grid measurements (missing values are NaNs)

Returns

Sparse grid indices
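To make the idea of a sparse grid concrete, the following sketch builds index arrays for an image and blanks out the coordinates of every missing (NaN) measurement. The name `sparse_grid_sketch` is hypothetical, and marking missing grid points with NaNs is an assumption based on the parameter description above.

```python
import numpy as np

def sparse_grid_sketch(R):
    """Hypothetical sketch: grid indices with NaNs wherever R is missing."""
    axes = [np.arange(n, dtype=float) for n in R.shape]
    X = np.stack(np.meshgrid(*axes, indexing="ij"))  # shape (c, N, M, ...)
    # Blank out all c coordinates at positions where the measurement is NaN
    X[:, np.isnan(R)] = np.nan
    return X

R = np.array([[1.0, np.nan],
              [3.0, 4.0]])
X_sparse = sparse_grid_sketch(R)
print(X_sparse)
```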

to_constrained_interval(state_dict, lscale, amp)

Transforms kernel’s unconstrained lengthscale and variance to their constrained domains (intervals)

Parameters
• state_dict (dict) – Kernel’s state dictionary; can be obtained from self.spgr.kernel.state_dict

• lscale (list) – List of two lists with lower and upper bound(s) for lengthscale prior. Number of elements in each list is usually equal to the number of (independent) input dimensions

• amp (list) – List with two floats corresponding to lower and upper bounds for variance (square of amplitude) prior

Returns

Lengthscale and variance in the constrained domain (interval)
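A common way to map an unconstrained real value into an interval is a sigmoid followed by an affine rescaling, which is also the convention PyTorch uses for interval constraints. The sketch below assumes that convention; the function `to_interval` and the example values are illustrative and do not read from a kernel state dictionary.

```python
import numpy as np

def to_interval(unconstrained, lower, upper):
    """Map real values into (lower, upper): lower + (upper - lower) * sigmoid(u)."""
    u = np.asarray(unconstrained, dtype=float)
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    return lower + (upper - lower) / (1.0 + np.exp(-u))

# Bounds in the format described above: [lower bounds, upper bounds]
lscale = [[1.0, 1.0], [20.0, 30.0]]
amp = [1e-4, 10.0]

# Illustrative unconstrained values (assumed, not taken from a real model)
ls = to_interval([0.0, 2.0], lscale[0], lscale[1])
var = to_interval(0.0, amp[0], amp[1])
print(ls, var)
```

An unconstrained value of zero lands at the midpoint of the interval, and large positive or negative values approach the upper or lower bound without reaching it.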

corrupt_data_xy(X_true, R_true, prob=0.5, replace_w_zeros=False)

Replaces certain % of 2D or 3D image data with NaNs; see gprutils.corrupt_image2d and gprutils.corrupt_image3d

Parameters
• X_true (ndarray) – Grid indices for 2D image or 3D hyperspectral data (3D and 4D numpy arrays, respectively)

• R_true (ndarray) – Observations as 2D image or 3D hyperspectral data

• prob (float) – Controls % of data to be corrupted (takes values between 0 and 1)

• replace_w_zeros (bool) – Corrupts data with zeros instead of NaNs

Returns

ndarrays of grid indices (3D or 4D) and observations (2D or 3D)

corrupt_image2d(X_true, R_true, prob, replace_w_zeros)

Replaces certain % of 2D image data with NaNs.

Parameters
• X_true (ndarray) – 3D array with grid indices for 2D image

• R_true (ndarray) – 2D image with observations

• prob (float) – Controls % of data to be corrupted (takes values between 0 and 1)

• replace_w_zeros (bool) – Corrupts data with zeros instead of NaNs

Returns

3D ndarray of grid coordinates and 2D ndarray of observations where a fraction of the points is replaced with NaNs.
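The corruption step for a 2D image can be sketched with a random Boolean mask, as below. This is a NumPy illustration under assumptions, not the library's code: the name `corrupt_2d_sketch` and the `seed` argument are hypothetical, and setting the grid indices to NaN only when NaNs (rather than zeros) are used is a guess at the intended behavior.

```python
import numpy as np

def corrupt_2d_sketch(X_true, R_true, prob=0.5, replace_w_zeros=False, seed=0):
    """Hypothetical sketch: remove roughly prob * 100 % of pixels at random."""
    rng = np.random.default_rng(seed)
    mask = rng.random(R_true.shape) < prob   # True marks a corrupted pixel
    R = R_true.astype(float)                 # astype returns a copy
    X = X_true.astype(float)
    if replace_w_zeros:
        R[mask] = 0.0
    else:
        R[mask] = np.nan
        X[:, mask] = np.nan                  # assumption: grid is blanked too
    return X, R

X_true = np.mgrid[0:10, 0:12].astype(float)  # shape (2, 10, 12)
R_true = np.ones((10, 12))
X, R = corrupt_2d_sketch(X_true, R_true, prob=0.5)
Xz, Rz = corrupt_2d_sketch(X_true, R_true, prob=0.5, replace_w_zeros=True)
print(np.isnan(R).mean(), (Rz == 0).mean())
```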

corrupt_image3d(X_true, R_true, prob, replace_w_zeros)

Replaces certain % of 3D hyperspectral data with NaNs. Applies differently in xy and in z dimensions. Specifically, for every corrupted (x, y) point we remove all z values associated with this point.

Parameters
• X_true (ndarray) – 4D array with grid indices for 3D hyperspectral data

• R_true (ndarray) – 3D hyperspectral data with observations

• prob (float) – Controls % of data to be corrupted (takes values between 0 and 1)

• replace_w_zeros (bool) – Corrupts data with zeros instead of NaNs

Returns

4D ndarray of grid coordinates and 3D ndarray of observations where a certain % of points is replaced with NaNs (note that for every corrupted (x, y) point we remove all z values associated with this point)
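The difference between the xy and z dimensions can be illustrated by drawing the random mask over (x, y) only and applying it to the whole spectroscopic axis. In this sketch the name `corrupt_3d_sketch` and the `seed` argument are hypothetical, and blanking the grid indices only in the NaN case is an assumption.

```python
import numpy as np

def corrupt_3d_sketch(X_true, R_true, prob=0.5, replace_w_zeros=False, seed=0):
    """Hypothetical sketch: drop whole spectra at randomly chosen (x, y) points."""
    rng = np.random.default_rng(seed)
    mask_xy = rng.random(R_true.shape[:2]) < prob  # one draw per (x, y) point
    R = R_true.astype(float)                       # astype returns a copy
    X = X_true.astype(float)
    if replace_w_zeros:
        R[mask_xy, :] = 0.0
    else:
        R[mask_xy, :] = np.nan                     # all z values are removed
        X[:, mask_xy, :] = np.nan                  # assumption: grid blanked too
    return X, R

X_true = np.mgrid[0:8, 0:9, 0:5].astype(float)     # shape (3, 8, 9, 5)
R_true = np.ones((8, 9, 5))
X, R = corrupt_3d_sketch(X_true, R_true, prob=0.5)
missing = np.isnan(R)
print("corrupted (x, y) points:", int(missing.all(axis=-1).sum()))
```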

open_edge_points(R, R_true, s=6)

Opens measured curves at the edges of FOV

Parameters
• R (ndarray) – empty/sparse data

• R_true (ndarray) – “ground truth”

• s (int) – step value, which determines the density of opened edge points

Returns

3D ndarray with opened edge points
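One plausible reading of "opening" edge points is to copy every `s`-th ground-truth spectrum along the four borders of the field of view into the sparse array. The sketch below implements that reading only; the function name `open_edges_sketch` is hypothetical and the actual selection rule in the library may differ.

```python
import numpy as np

def open_edges_sketch(R, R_true, s=6):
    """Hypothetical sketch: reveal every s-th spectrum on the image borders."""
    R = R.copy()
    R[0, ::s, :] = R_true[0, ::s, :]      # first row
    R[-1, ::s, :] = R_true[-1, ::s, :]    # last row
    R[::s, 0, :] = R_true[::s, 0, :]      # first column
    R[::s, -1, :] = R_true[::s, -1, :]    # last column
    return R

R_true = np.random.default_rng(0).random((12, 12, 4))
R_empty = np.full(R_true.shape, np.nan)   # nothing measured yet
R_open = open_edges_sketch(R_empty, R_true, s=6)
print("opened (x, y) points:", int((~np.isnan(R_open)).all(axis=-1).sum()))
```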

plot_kernel_hyperparams(hyperparams)

Plots evolution of kernel hyperparameters as a function of training steps

Parameters

hyperparams (dict) – dictionary with kernel hyperparameters (see gpreg.gpr.reconstructor)

plot_mixture_hyperparams(hyperparams)

Plots evolution of spectral mixture kernel hyperparameters as a function of training iterations

Parameters

hyperparams (dict) – dictionary with kernel hyperparameters (see gpreg.skgpr.skreconstructor)

plot_raw_data(raw_data, slice_number, pos, spec_window=2, norm=False, **kwargs)

Plots hyperspectral data as 2D image integrated over a certain range of energy/frequency and selected individual spectroscopic curves

Parameters
• raw_data (3D ndarray) – hyperspectral cube (the first two dimensions are xy coordinates and the last dimension is a “spectroscopic” dimension)

• slice_number (int) – slice from datacube to visualize

• pos (list of lists) – list with [x, y] coordinates of points where single spectroscopic curves will be extracted and visualized

• spec_window (int) – window to integrate over in frequency dimension (for 2D “slices”)

• **cmap (str) – cmap for 2D image (“slice”) plot

• **z_vec (1D ndarray) – spectroscopic measurements values (e.g. frequency, bias)

• **z_vec_label (str) – spectroscopic measurements label (e.g. frequency, bias voltage)

• **z_vec_units (str) – spectroscopic measurements units (e.g. Hz, V)
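The two quantities shown in this plot, a 2D "slice" integrated over a spectral window and individual curves at selected points, can be computed as in the sketch below. The helper names are hypothetical, and both the symmetric window around `slice_number` and the mapping of `[x, y]` to the first and second array axes are assumptions.

```python
import numpy as np

def integrate_slice(raw_data, slice_number, spec_window=2):
    """Hypothetical helper: sum the cube over a window around slice_number."""
    lo = max(slice_number - spec_window, 0)
    hi = min(slice_number + spec_window, raw_data.shape[-1])
    return raw_data[:, :, lo:hi].sum(axis=-1)

def extract_curves(raw_data, pos):
    """Hypothetical helper: spectroscopic curves at the given [x, y] points."""
    return [raw_data[x, y, :] for x, y in pos]

raw_data = np.arange(3 * 4 * 10, dtype=float).reshape(3, 4, 10)
image = integrate_slice(raw_data, slice_number=5, spec_window=2)
curves = extract_curves(raw_data, pos=[[1, 2], [0, 0]])
print(image.shape, len(curves), curves[0].shape)
```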

plot_reconstructed_data2d(R, mean, save_fig=False, **kwargs)

Plots original and GP-reconstructed data for 2D images

Parameters
• R (2D ndarray) – Input image for GP regression

• mean (1D ndarray) – Predictive mean, usually an output of gpr.reconstructor or skgpr.skreconstructor. The array is flattened (the actual dimensions are the same as for R)

• **cmap (str) – cmap for 2D image plot

• **savedir (str) – directory to save output figure

• **filepath (str) – name of input file (to create a unique filename for plot)

• **sparsity (float) – indicates % of data points removed (used only for figure title)
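Because the predictive mean arrives flattened, it has to be reshaped to the dimensions of `R` before the two images can be compared side by side. The sketch below shows that step on toy data; computing the fraction of missing points from the NaNs in `R` is only an illustrative way to obtain a value for `sparsity`, and the stand-in `mean` array is not a real GP prediction.

```python
import numpy as np

# Toy sparse input image (NaNs mark missing pixels)
R = np.array([[1.0, np.nan, 3.0],
              [np.nan, 5.0, 6.0]])

# Stand-in for a flattened predictive mean with R.size elements
mean = np.arange(R.size, dtype=float)

mean_2d = mean.reshape(R.shape)   # back to the image dimensions
sparsity = np.isnan(R).mean()     # fraction of missing pixels
print(mean_2d.shape, f"{100 * sparsity:.0f}% of points missing")
```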

plot_reconstructed_data3d(R, mean, sd, slice_number, pos, spec_window=2, save_fig=False, **kwargs)

Plots original and GP-reconstructed data for 3D images

Parameters
• R (3D ndarray) – Input image for GP regression

• mean (1D ndarray) – Predictive mean, usually an output of gpr.reconstructor or skgpr.skreconstructor. The array is flattened (the actual dimensions are the same as for R)

• sd (1D ndarray) – Standard deviation (can be flattened; actual dimensions are the same as in R)

• slice_number (int) – slice from datacube to visualize

• pos (list of lists) – list with [x, y] coordinates of points where single spectroscopic curves will be extracted and visualized

• spec_window (int) – window to integrate over in frequency dimension (for 2D “slices”)

• **cmap (str) – colormap for 2D image (“slices”) plots

• **savedir (str) – directory to save output figure

• **sparsity (float) – indicates % of data points removed (used only for figure title)

• **filepath (str) – path/name of input file (to create a unique filename for plot)

• **z_vec (1D ndarray) – spectroscopic measurements values (e.g. frequency, bias)

• **z_vec_label (str) – spectroscopic measurements label (e.g. frequency, bias voltage)

• **z_vec_units (str) – spectroscopic measurements units (e.g. Hz, V)

plot_exploration_results(R_all, mean_all, sd_all, R_true, episodes, slice_number, pos, dist_edge, spec_window=2, mask_predictions=False, **kwargs)

Plots predictions at different stages (“episodes”) of maximum uncertainty-based sample exploration with GP

Parameters
• R_all (list with ndarrays) – Observed data points at each exploration step

• mean_all (list of ndarrays) – Predictive mean at each exploration step

• sd_all (list of ndarrays) – Integrated (along energy dimension) SD at each exploration step

• R_true (ndarray) – 3D array with ground truth data (full observations) for simulated experiment OR a 3D array of zeros/NaNs for real experiment

• episodes (list of ints) – list with the numbers indicating which iteration steps to visualize

• slice_number (int) – slice from datacube to visualize

• pos (list of lists) – list with [x, y] coordinates of points where single spectroscopic curves will be extracted and visualized

• dist_edge (list with two integers) – this should be the same as in exploration analysis

• spec_window (int) – window to integrate over in frequency dimension (for 2D “slices”)

• mask_predictions (bool) – mask edge regions not used in max uncertainty evaluation in predictive mean plots

• **sparsity (float) – indicates % of data points removed (used only for figure title)

• **z_vec (1D ndarray) – spectroscopic measurements values (e.g. frequency, bias)

• **z_vec_label (str) – spectroscopic measurements label (e.g. frequency, bias voltage)

• **z_vec_units (str) – spectroscopic measurements units (e.g. Hz, V)
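Masking edge regions in a predictive-mean image can be sketched as follows. The function name `mask_edges` is hypothetical, and the interpretation of `dist_edge` as the number of pixels excluded at each border along the first and second image axes is an assumption.

```python
import numpy as np

def mask_edges(img, dist_edge):
    """Hypothetical sketch: keep the interior of img, set the borders to NaN."""
    d1, d2 = dist_edge
    n1, n2 = img.shape
    masked = np.full(img.shape, np.nan)
    masked[d1:n1 - d1, d2:n2 - d2] = img[d1:n1 - d1, d2:n2 - d2]
    return masked

mean_slice = np.ones((6, 8))
masked = mask_edges(mean_slice, dist_edge=[1, 2])
print(int(np.isfinite(masked).sum()), "interior pixels kept")
```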

plot_inducing_points(hyperparams, **kwargs)

Plots inducing points evolution during training

plot_inducing_points_2d(hyperparams, **kwargs)

Plots 2D trajectories of inducing points

Parameters
• hyperparams (dict) – Dictionary of hyperparameters

• **plot_from (int) – plot from specific step

• **plot_to (int) – plot till specific step

• **slice_step (int) – plot every nth inducing point
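The three keyword arguments amount to slicing the recorded trajectories along the training-step axis and the inducing-point axis. The sketch below assumes the trajectories are stored as an array of shape (steps, points, 2); the source does not state the layout of the `hyperparams` dictionary, so the array and the helper `select_trajectories` are stand-ins.

```python
import numpy as np

def select_trajectories(points, plot_from=0, plot_to=None, slice_step=1):
    """Hypothetical helper: keep steps plot_from..plot_to and every
    slice_step-th inducing point."""
    return points[plot_from:plot_to, ::slice_step]

# Stand-in trajectories: 100 training steps, 20 inducing points, 2 coordinates
points = np.zeros((100, 20, 2))
selected = select_trajectories(points, plot_from=10, plot_to=60, slice_step=5)
print(selected.shape)
```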

plot_inducing_points_3d(hyperparams, **kwargs)

Plots 3D trajectories of inducing points during model training

Parameters
• hyperparams (dict) – dictionary of hyperparameters

• **plot_from (int) – plot from specific step

• **plot_to (int) – plot till specific step

• **slice_step (int) – plot every nth inducing point

plot_query_points(inds_all, **kwargs)

Plots the exploration path (all the query points) in GP-based Bayesian optimization. Currently supports only 2D data.

Parameters
• inds_all (list) – list of indices

• **cmap (str) – colormap