Utility functions

gprutils.py

Utility functions for the analysis of sparse image and hyperspectral data with Gaussian processes.

Author: Maxim Ziatdinov (email: maxim.ziatdinov@ai4microcopy.com)

prepare_training_data(X, y=None, vector_valued=False, **kwargs)

Reshapes and converts data to torch tensors for GP analysis

Parameters
  • X (ndarray) – Grid indices with dimensions \(c \times N \times M \times L\), where c is equal to the number of coordinates (for example, for xyz coordinates, c = 3)

  • y (ndarray) – Observations (data points) with dimensions \(N \times M \times L\)

  • **precision (str) – Choose between single (‘single’) and double (‘double’) precision

Returns

PyTorch tensors with dimensions \(N \times M \times L \times c\) and \(N \times M \times L\)
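The change of layout described above (coordinate axis moved from first to last position) can be illustrated with plain NumPy. This is a minimal sketch, not the library's implementation: the helper name to_channels_last is hypothetical, and the conversion to torch tensors is omitted so that the example depends on NumPy only.

```python
import numpy as np

def to_channels_last(X, y=None, precision="double"):
    """Move the coordinate axis of X from first to last position (illustrative only)."""
    dtype = np.float64 if precision == "double" else np.float32
    X_out = np.moveaxis(np.asarray(X, dtype=dtype), 0, -1)
    y_out = None if y is None else np.asarray(y, dtype=dtype)
    return X_out, y_out

# xyz grid indices (c = 3) for a 4 x 5 x 6 data cube
X = np.array(np.mgrid[:4, :5, :6], dtype=float)   # shape (3, 4, 5, 6)
y = np.random.rand(4, 5, 6)
X_out, y_out = to_channels_last(X, y)
print(X_out.shape, y_out.shape)   # (4, 5, 6, 3) (4, 5, 6)
```

After the move, X_out[i, j, k] holds the coordinate vector of the observation y_out[i, j, k].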

prepare_test_data(X, **kwargs)

Reshapes and converts data to torch tensors for GP analysis

Parameters
  • X (ndarray) – Grid indices with dimensions \(c \times N \times M \times L\) where c is equal to the number of coordinates (for example, for xyz coordinates, c = 3)

  • **precision (str) – Choose between single (‘single’) and double (‘double’) precision

Returns

PyTorch tensor with dimensions \(N \times M \times L \times c\)

get_grid_indices(R, dense_x=1.0)

Returns full and sparse grid indices for 2D and 3D arrays

Parameters
  • R (ndarray) – Sparse grid measurements as 2D or 3D numpy array

  • dense_x (float) – Determines grid density (can be increased at prediction stage)

get_full_grid(R, extent=None, dense_x=1.0)

Creates grid indices for 2D-4D numpy arrays

Parameters
  • R (ndarray) – Grid measurements as 2D-4D numpy array

  • extent (list of lists) – Define multi-dimensional data bounds. For example, for 2D data, the extent parameter is [[xmin, xmax], [ymin, ymax]]

  • dense_x (float) – Determines grid density (can be increased at prediction stage)

Returns

Grid indices as numpy array
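How extent and dense_x could shape the grid is sketched below with np.mgrid. The function name full_grid_sketch is hypothetical and the code is not taken from the library. In particular, it assumes that dense_x multiplies the grid step, so that values below 1 give a denser grid.

```python
import numpy as np

def full_grid_sketch(R, extent=None, dense_x=1.0):
    """Build grid indices for an N-dimensional array R (illustrative only)."""
    if extent is None:
        # Default bounds: pixel indices along every axis
        extent = [[0, n] for n in R.shape]
    slices = []
    for (lo, hi), n in zip(extent, R.shape):
        # Assumption: dense_x scales the step, so dense_x < 1 gives a denser grid
        step = dense_x * (hi - lo) / n
        slices.append(slice(lo, hi, step))
    return np.array(np.mgrid[tuple(slices)])

R = np.zeros((4, 5))
X_full = full_grid_sketch(R)                               # shape (2, 4, 5)
X_dense = full_grid_sketch(R, dense_x=0.5)                 # shape (2, 8, 10)
X_scaled = full_grid_sketch(R, extent=[[0, 2], [0, 10]])   # bounds in physical units
```

The first axis of the result indexes the coordinate (x, y, ...), which matches the \(c \times N \times M \times L\) layout expected by prepare_training_data and prepare_test_data.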

get_sparse_grid(R, extent=None)

Returns sparse grid for sparse image data

Parameters

R (ndarray) – Sparse grid measurements (missing values are NaNs)

Returns

Sparse grid indices
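One way to picture a sparse grid is as a full grid in which the coordinates of missing measurements are masked out. The sketch below follows that reading. The name sparse_grid_sketch is hypothetical, and marking missing points with NaN coordinates is an assumption rather than documented behavior.

```python
import numpy as np

def sparse_grid_sketch(R):
    """Return grid indices with NaNs wherever the measurement in R is missing."""
    X = np.array(np.mgrid[tuple(slice(0, n) for n in R.shape)], dtype=float)
    X[:, np.isnan(R)] = np.nan   # mask every coordinate of a missing point
    return X

R = np.array([[1.0, np.nan],
              [np.nan, 4.0]])
X_sparse = sparse_grid_sketch(R)   # shape (2, 2, 2); two of four points are masked
```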

to_constrained_interval(state_dict, lscale, amp)

Transforms kernel’s unconstrained lengthscale and variance to their constrained domains (intervals)

Parameters
  • state_dict (dict) – Kernel’s state dictionary; can be obtained from self.spgr.kernel.state_dict

  • lscale (list) – List of two lists with lower and upper bound(s) for lengthscale prior. Number of elements in each list is usually equal to the number of (independent) input dimensions

  • amp (list) – List with two floats corresponding to lower and upper bounds for variance (square of amplitude) prior

Returns

Lengthscale and variance in the constrained domain (interval)
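The mapping from an unconstrained value to an interval can be sketched as follows. The source does not state which transform is used, so the sigmoid-based mapping here is an assumption, and to_interval_sketch is a hypothetical helper that works on plain numbers rather than on a kernel state dictionary.

```python
import numpy as np

def to_interval_sketch(unconstrained, lower, upper):
    """Map unconstrained value(s) onto (lower, upper) with a sigmoid (assumed transform)."""
    unconstrained = np.asarray(unconstrained, dtype=float)
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    sigmoid = 1.0 / (1.0 + np.exp(-unconstrained))
    return lower + (upper - lower) * sigmoid

# lscale and amp follow the layout described above: [lower bound(s), upper bound(s)]
lscale = [[1.0, 1.0], [20.0, 30.0]]
amp = [1e-4, 10.0]
ls = to_interval_sketch([0.0, 2.0], lscale[0], lscale[1])
var = to_interval_sketch(0.0, amp[0], amp[1])
```

An unconstrained value of zero lands at the midpoint of the interval, and large positive or negative values approach the upper or lower bound without reaching it.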

corrupt_data_xy(X_true, R_true, prob=0.5, replace_w_zeros=False)

Replaces a certain percentage of 2D or 3D image data with NaNs; see gprutils.corrupt_image2d and gprutils.corrupt_image3d

Parameters
  • X_true (ndarray) – Grid indices for 2D image or 3D hyperspectral data (3D and 4D numpy arrays, respectively)

  • R_true (ndarray) – Observations as 2D image or 3D hyperspectral data

  • prob (float) – Fraction of data points to corrupt (a value between 0 and 1)

  • replace_w_zeros (bool) – Corrupts data with zeros instead of NaNs

Returns

ndarrays of grid indices (3D or 4D) and observations (2D or 3D)

corrupt_image2d(X_true, R_true, prob, replace_w_zeros)

Replaces a certain percentage of 2D image data with NaNs.

Parameters
  • X_true (ndarray) – 3D array with grid indices for 2D image

  • R_true (ndarray) – 2D image with observations

  • prob (float) – Fraction of data points to corrupt (a value between 0 and 1)

  • replace_w_zeros (bool) – Corrupts data with zeros instead of NaNs

Returns

3D ndarray of grid coordinates and 2D ndarray of observations in which a fraction of the points is replaced with NaNs.
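The corruption step for 2D data can be sketched with a random boolean mask. This is an illustration under stated assumptions, not the library's code: corrupt_2d_sketch and its seed argument are hypothetical, and each pixel is assumed to be dropped independently with probability prob, so the corrupted fraction is only approximately equal to prob.

```python
import numpy as np

def corrupt_2d_sketch(X_true, R_true, prob=0.5, replace_w_zeros=False, seed=0):
    """Replace a random subset of pixels (and their grid indices) with NaNs or zeros."""
    rng = np.random.default_rng(seed)
    mask = rng.random(R_true.shape) < prob   # True marks a corrupted pixel
    X = X_true.astype(float)                 # astype returns a copy
    R = R_true.astype(float)
    fill = 0.0 if replace_w_zeros else np.nan
    R[mask] = fill
    X[:, mask] = fill
    return X, R

R_true = np.ones((8, 8))
X_true = np.array(np.mgrid[:8, :8], dtype=float)   # shape (2, 8, 8)
X, R = corrupt_2d_sketch(X_true, R_true, prob=0.5)
```

The inputs are left unchanged, and the same mask is applied to the observations and to both coordinate planes, so grid indices and data stay aligned.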

corrupt_image3d(X_true, R_true, prob, replace_w_zeros)

Replaces a certain percentage of 3D hyperspectral data with NaNs. Corruption is applied in the xy plane only: for every corrupted (x, y) point, all z values associated with that point are removed.

Parameters
  • X_true (ndarray) – 4D array with grid indices for 3D hyperspectral data

  • R_true (ndarray) – 3D hyperspectral data with observations

  • prob (float) – Fraction of data points to corrupt (a value between 0 and 1)

  • replace_w_zeros (bool) – Corrupts data with zeros instead of NaNs

Returns

4D ndarray of grid coordinates and 3D ndarray of observations in which a certain percentage of points is replaced with NaNs (note that for every corrupted (x, y) point, all z values associated with this point are removed)
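The rule that a corrupted (x, y) point loses its entire z curve can be sketched by drawing the random mask in the xy plane only. The name corrupt_3d_sketch and the seed argument are hypothetical, and independent per-point sampling is an assumption.

```python
import numpy as np

def corrupt_3d_sketch(X_true, R_true, prob=0.5, seed=0):
    """Remove whole spectroscopic curves at randomly chosen (x, y) positions."""
    rng = np.random.default_rng(seed)
    mask_xy = rng.random(R_true.shape[:2]) < prob   # mask has no z dimension
    X = X_true.astype(float)
    R = R_true.astype(float)
    R[mask_xy] = np.nan      # a 2D mask on a 3D array selects all z values
    X[:, mask_xy] = np.nan   # same positions in every coordinate volume
    return X, R

R_true = np.ones((6, 6, 10))
X_true = np.array(np.mgrid[:6, :6, :10], dtype=float)   # shape (3, 6, 6, 10)
X, R = corrupt_3d_sketch(X_true, R_true, prob=0.5)
nans_per_curve = np.isnan(R).sum(axis=-1)   # each entry is either 0 or 10
```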

open_edge_points(R, R_true, s=6)

Opens measured curves at the edges of the field of view (FOV)

Parameters
  • R (ndarray) – empty/sparse data

  • R_true (ndarray) – “ground truth”

  • s (int) – step value, which determines the density of opened edge points

Returns

3D ndarray with opened edge points
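A possible reading of "opening" edge points is copying ground-truth curves into the sparse array at every s-th position along each edge. The sketch below implements that reading. open_edges_sketch is a hypothetical name, and the exact set of edge positions used by the library may differ.

```python
import numpy as np

def open_edges_sketch(R, R_true, s=6):
    """Copy ground-truth curves into R at every s-th point along the four edges."""
    R = R.copy()
    R[0, ::s] = R_true[0, ::s]      # first row
    R[-1, ::s] = R_true[-1, ::s]    # last row
    R[::s, 0] = R_true[::s, 0]      # first column
    R[::s, -1] = R_true[::s, -1]    # last column
    return R

R_true = np.ones((12, 12, 4))
R_empty = np.full((12, 12, 4), np.nan)
R_open = open_edges_sketch(R_empty, R_true, s=6)
```

A smaller s opens more edge points, which matches the description of s as controlling their density.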

plot_kernel_hyperparams(hyperparams)

Plots evolution of kernel hyperparameters as a function of training steps

Parameters

hyperparams (dict) – dictionary with kernel hyperparameters (see gpreg.gpr.reconstructor)

plot_mixture_hyperparams(hyperparams)

Plots evolution of spectral mixture kernel hyperparameters as a function of training iterations

Parameters

hyperparams (dict) – dictionary with kernel hyperparameters (see gpreg.skgpr.skreconstructor)

plot_raw_data(raw_data, slice_number, pos, spec_window=2, norm=False, **kwargs)

Plots hyperspectral data as 2D image integrated over a certain range of energy/frequency and selected individual spectroscopic curves

Parameters
  • raw_data (3D ndarray) – hyperspectral cube (the first two dimensions are xy coordinates and the last dimension is a “spectroscopic” dimension)

  • slice_number (int) – slice from datacube to visualize

  • pos (list of lists) – list with [x, y] coordinates of points where single spectroscopic curves will be extracted and visualized

  • spec_window (int) – window to integrate over in frequency dimension (for 2D “slices”)

  • **cmap (str) – cmap for 2D image (“slice”) plot

  • **z_vec (1D ndarray) – spectroscopic measurements values (e.g. frequency, bias)

  • **z_vec_label (str) – spectroscopic measurements label (e.g. frequency, bias voltage)

  • **z_vec_units (str) – spectroscopic measurements units (e.g. Hz, V)
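The quantities that such a plot is built from, an integrated 2D slice and individual curves at the positions in pos, can be computed as in the sketch below. Both helper names are hypothetical. The window is assumed to span spec_window channels on each side of slice_number and the integration is assumed to be a plain sum; neither detail is stated in the source.

```python
import numpy as np

def integrated_slice_sketch(raw_data, slice_number, spec_window=2):
    """Sum the cube over an assumed window around slice_number along the last axis."""
    lo = max(slice_number - spec_window, 0)
    hi = min(slice_number + spec_window, raw_data.shape[-1])
    return raw_data[:, :, lo:hi].sum(axis=-1)

def extract_curves_sketch(raw_data, pos):
    """Return the spectroscopic curve at each [x, y] position."""
    return [raw_data[x, y, :] for x, y in pos]

# Cube whose value equals the channel index, so the sums are easy to check
raw_data = np.broadcast_to(np.arange(10.0), (3, 3, 10)).copy()
img = integrated_slice_sketch(raw_data, slice_number=5, spec_window=2)   # channels 3..6
curves = extract_curves_sketch(raw_data, pos=[[0, 0], [2, 1]])
```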

plot_reconstructed_data2d(R, mean, save_fig=False, **kwargs)

Plots original and GP-reconstructed data for 2D images

Parameters
  • R (2D ndarray) – Input image for GP regression

  • mean (1D ndarray) – Predictive mean, usually an output of gpr.reconstructor or skgpr.skreconstructor. The array is flattened (the actual dimensions are the same as for R)

  • **cmap (str) – cmap for 2D image plot

  • **savedir (str) – directory to save output figure

  • **filepath (str) – name of input file (to create a unique filename for plot)

  • **sparsity (float) – indicates % of data points removed (used only for figure title)
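Because mean arrives flattened, it has to be reshaped to the dimensions of R before it can be shown next to the input image. A minimal sketch of that step is given below, assuming row-major (C-order) flattening; the numbers stand in for a real GP prediction.

```python
import numpy as np

R = np.array([[1.0, np.nan, 3.0],
              [np.nan, 5.0, 6.0]])     # sparse 2D input
mean = np.arange(6, dtype=float)       # flattened stand-in for a predictive mean
mean_img = mean.reshape(R.shape)       # back to image dimensions
abs_error = np.abs(mean_img - R)       # NaN wherever R has no measurement
```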

plot_reconstructed_data3d(R, mean, sd, slice_number, pos, spec_window=2, save_fig=False, **kwargs)

Plots original and GP-reconstructed data for 3D images

Parameters
  • R (3D ndarray) – Input image for GP regression

  • mean (1D ndarray) – Predictive mean, usually an output of gpr.reconstructor or skgpr.skreconstructor. The array is flattened (the actual dimensions are the same as for R)

  • sd (1D ndarray) – Standard deviation (can be flattened; actual dimensions are the same as in R)

  • slice_number (int) – slice from datacube to visualize

  • pos (list of lists) – list with [x, y] coordinates of points where single spectroscopic curves will be extracted and visualized

  • spec_window (int) – window to integrate over in frequency dimension (for 2D “slices”)

  • **cmap (str) – colormap for 2D image (“slices”) plots

  • **savedir (str) – directory to save output figure

  • **sparsity (float) – indicates % of data points removed (used only for figure title)

  • **filepath (str) – path/name of input file (to create a unique filename for plot)

  • **z_vec (1D ndarray) – spectroscopic measurements values (e.g. frequency, bias)

  • **z_vec_label (str) – spectroscopic measurements label (e.g. frequency, bias voltage)

  • **z_vec_units (str) – spectroscopic measurements units (e.g. Hz, V)

plot_exploration_results(R_all, mean_all, sd_all, R_true, episodes, slice_number, pos, dist_edge, spec_window=2, mask_predictions=False, **kwargs)

Plots predictions at different stages (“episodes”) of maximum uncertainty-based sample exploration with GP

Parameters
  • R_all (list of ndarrays) – Observed data points at each exploration step

  • mean_all (list of ndarrays) – Predictive mean at each exploration step

  • sd_all (list of ndarrays) – Integrated (along energy dimension) SD at each exploration step

  • R_true (ndarray) – 3D array with ground truth data (full observations) for simulated experiment OR a 3D array of zeros/NaNs for real experiment

  • episodes (list of ints) – list with the numbers indicating which iteration steps to visualize

  • slice_number (int) – slice from datacube to visualize

  • pos (list of lists) – list with [x, y] coordinates of points where single spectroscopic curves will be extracted and visualized

  • dist_edge (list with two integers) – this should be the same as in exploration analysis

  • spec_window (int) – window to integrate over in frequency dimension (for 2D “slices”)

  • mask_predictions (bool) – mask edge regions not used in max uncertainty evaluation in predictive mean plots

  • **sparsity (float) – indicates % of data points removed (used only for figure title)

  • **z_vec (1D ndarray) – spectroscopic measurements values (e.g. frequency, bias)

  • **z_vec_label (str) – spectroscopic measurements label (e.g. frequency, bias voltage)

  • **z_vec_units (str) – spectroscopic measurements units (e.g. Hz, V)
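The role of dist_edge and mask_predictions can be illustrated by masking the border of an integrated SD map and picking the point of maximum uncertainty from what remains. This is a sketch under assumptions: mask_edges_sketch is a hypothetical helper, and dist_edge is interpreted as the number of pixels excluded at each edge along the two image axes.

```python
import numpy as np

def mask_edges_sketch(sd, dist_edge):
    """Keep only the interior of a 2D map; edge regions become NaN."""
    d1, d2 = dist_edge
    n1, n2 = sd.shape
    masked = np.full(sd.shape, np.nan)
    masked[d1:n1 - d1, d2:n2 - d2] = sd[d1:n1 - d1, d2:n2 - d2]
    return masked

sd = np.arange(36, dtype=float).reshape(6, 6)   # stand-in for an integrated SD map
masked = mask_edges_sketch(sd, dist_edge=[1, 2])
# Position of maximum uncertainty, ignoring the masked edge regions
next_point = np.unravel_index(np.nanargmax(masked), masked.shape)
```

In this toy map the global maximum sits in a corner, but the selected point lies inside the unmasked region, which is why the same dist_edge should be used for plotting and for the exploration analysis.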

plot_inducing_points(hyperparams, **kwargs)

Plots inducing points evolution during training

plot_inducing_points_2d(hyperparams, **kwargs)

Plots 2D trajectories of inducing points

Parameters
  • hyperparams (dict) – Dictionary of hyperparameters

  • **plot_from (int) – plot from specific step

  • **plot_to (int) – plot till specific step

  • **slice_step (int) – plot every nth inducing point

plot_inducing_points_3d(hyperparams, **kwargs)

Plots 3D trajectories of inducing points during model training

Parameters
  • hyperparams (dict) – dictionary of hyperparameters

  • **plot_from (int) – plot from specific step

  • **plot_to (int) – plot till specific step

  • **slice_step (int) – plot every nth inducing point

plot_query_points(inds_all, **kwargs)

Plots the exploration path (all the query points) in GP-based Bayesian optimization. Currently supports only 2D data.

Parameters
  • inds_all (list) – list of indices

  • **cmap (str) – colormap