ikpls.jax_ikpls_alg_2
Contains the PLS class, which implements partial least-squares regression using Improved Kernel PLS Algorithm #2 by Dayal and MacGregor: https://doi.org/10.1002/(SICI)1099-128X(199701)11:1%3C73::AID-CEM435%3E3.0.CO;2-%23
The class is implemented using JAX for end-to-end differentiability. Additionally, JAX allows CPU, GPU, and TPU execution.
Author: Ole-Christian Galbo Engstrøm
E-mail: ocge@foss.dk
Classes
PLS – Implements partial least-squares regression using Improved Kernel PLS Algorithm #2 by Dayal and MacGregor: https://doi.org/10.1002/(SICI)1099-128X(199701)11:1%3C73::AID-CEM435%3E3.0.CO;2-%23.
- class ikpls.jax_ikpls_alg_2.PLS(center_X: bool = True, center_Y: bool = True, scale_X: bool = True, scale_Y: bool = True, ddof: int = 1, copy: bool = True, dtype: str | type[Any] | dtype | SupportsDType = <class 'jax.numpy.float64'>, differentiable: bool = False, verbose: bool = False)
Bases: PLSBase
Implements partial least-squares regression using Improved Kernel PLS Algorithm #2 by Dayal and MacGregor: https://doi.org/10.1002/(SICI)1099-128X(199701)11:1%3C73::AID-CEM435%3E3.0.CO;2-%23.
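The kernel idea behind Algorithm #2 can be illustrated with a minimal sketch. It is not the library's implementation: the function name `ikpls2_sketch` is hypothetical, and it skips centering, scaling, and observation weights. It follows the published steps: X^T X and X^T Y are formed once, only X^T Y is deflated, and the coefficients accumulate as B = R Q^T. Matrices are returned in the transposed (A, K) layout used by this class.

```python
import jax

jax.config.update("jax_enable_x64", True)  # float64, as recommended for PLS
import jax.numpy as jnp


def ikpls2_sketch(X, Y, A):
    """Hypothetical minimal IKPLS Algorithm #2 (no preprocessing, no weights)."""
    if Y.ndim == 1:
        Y = Y[:, None]
    K, M = X.shape[1], Y.shape[1]
    XTX = X.T @ X  # (K, K): computed once and never deflated
    XTY = X.T @ Y  # (K, M): deflated after every component
    W, P, Q, R, B = [], [], [], [], []
    B_a = jnp.zeros((K, M))
    for _ in range(A):
        if M == 1:
            w = XTY[:, 0]
        else:
            # Dominant eigenvector of XTY^T XTY gives the Y-side direction.
            _, eigvecs = jnp.linalg.eigh(XTY.T @ XTY)  # ascending eigenvalues
            w = XTY @ eigvecs[:, -1]
        w = w / jnp.linalg.norm(w)
        # r maps the original X straight to the new score: t = X r.
        r = w
        for p_j, r_j in zip(P, R):
            r = r - (p_j @ w) * r_j
        tt = r @ XTX @ r  # t^T t without forming t
        p = (r @ XTX) / tt
        q = (r @ XTY) / tt
        XTY = XTY - tt * jnp.outer(p, q)  # deflate only X^T Y
        B_a = B_a + jnp.outer(r, q)  # cumulative B = R Q^T
        W.append(w); P.append(p); Q.append(q); R.append(r); B.append(B_a)
    return jnp.stack(B), jnp.stack(W), jnp.stack(P), jnp.stack(Q), jnp.stack(R)
```

With as many components as X has (full-rank) columns, the scores span the column space of X, so the last coefficient matrix should coincide with ordinary least squares; this is a convenient sanity check for any PLS implementation.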
- Parameters:
center_X (bool, default=True) – Whether to center X before fitting by subtracting its row of column-wise means from each row.
center_Y (bool, default=True) – Whether to center Y before fitting by subtracting its row of column-wise means from each row.
scale_X (bool, default=True) – Whether to scale X before fitting by dividing each row by the row of X’s column-wise standard deviations.
scale_Y (bool, default=True) – Whether to scale Y before fitting by dividing each row by the row of Y’s column-wise standard deviations.
ddof (int, default=1) – The delta degrees of freedom to use when computing the sample standard deviation. A value of 0 corresponds to the biased estimate of the sample standard deviation, while a value of 1 corresponds to Bessel’s correction for the sample standard deviation.
copy (bool, default=True) – Whether to copy X and Y in fit before potentially applying centering and scaling. If True, then the data is copied before fitting. If False, and dtype matches the type of X and Y, then centering and scaling are done in place, modifying both arrays.
dtype (DTypeLike, default=jnp.float64) – The float datatype to use in computation of the PLS algorithm. Using a lower precision than float64 will yield significantly worse results as the number of components increases, due to propagation of numerical errors.
differentiable (bool, default=False) – Whether to make the implementation end-to-end differentiable. The differentiable version is slightly slower. Results between the two versions are identical. If this is True, fit and stateless_fit will not issue a warning if the residual goes below machine epsilon, and max_stable_components will not be set.
verbose (bool, default=False) – If True, each sub-function will print a message when it is JIT-compiled. This can be useful to track whether recompilation is triggered by passing inputs with different shapes.
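The centering, scaling, and ddof parameters can be illustrated with a short sketch. The helper `center_and_scale` is hypothetical, not part of ikpls; it centers first and then scales, and keeps the statistics as (1, K) rows so they broadcast over every row of X.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def center_and_scale(X, center=True, scale=True, ddof=1):
    """Hypothetical helper: center first, then scale."""
    mean = X.mean(axis=0, keepdims=True) if center else None  # shape (1, K)
    std = X.std(axis=0, ddof=ddof, keepdims=True) if scale else None  # shape (1, K)
    if center:
        X = X - mean  # subtract the row of column-wise means from every row
    if scale:
        X = X / std  # divide every row by the row of column-wise std devs
    return X, mean, std


x = jnp.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]])
x_scaled, x_mean, x_std = center_and_scale(x, ddof=1)
_, _, x_std_biased = center_and_scale(x, ddof=0)
```

For the first column the squared deviations from the mean 2.5 sum to 5, so ddof=0 gives a variance of 5/4 = 1.25 while ddof=1 (Bessel’s correction) gives 5/3.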
Notes
Any centering and scaling is undone before predictions are returned, so that predictions from a model fitted with fit are on the original scale. If both centering and scaling are True, then the data is first centered and then scaled.
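Undoing the preprocessing amounts to applying the training statistics to new X and then inverting the Y transformation in reverse order. In this sketch, `predict_original_scale` is a hypothetical helper, and a least-squares solve in the preprocessed space stands in for a fitted PLS coefficient matrix; it only demonstrates the bookkeeping, not the PLS fit itself.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def predict_original_scale(X_new, B_scaled, X_mean, X_std, Y_mean, Y_std):
    """Hypothetical helper: preprocess X_new, predict, then undo Y preprocessing."""
    X_new = (X_new - X_mean) / X_std  # same centering, then scaling, as in fit
    Y_scaled = X_new @ B_scaled  # prediction in the preprocessed space
    return Y_scaled * Y_std + Y_mean  # undo scaling first, then centering


key1, key2 = jax.random.split(jax.random.PRNGKey(1))
X = jax.random.normal(key1, (40, 3)) * 5.0 + 2.0
B_true = jnp.array([[1.0], [-2.0], [0.5]])
Y = X @ B_true + 3.0  # exact linear relation with an intercept
X_mean, X_std = X.mean(axis=0, keepdims=True), X.std(axis=0, ddof=1, keepdims=True)
Y_mean, Y_std = Y.mean(axis=0, keepdims=True), Y.std(axis=0, ddof=1, keepdims=True)
# Stand-in for a fitted B: least squares on the centered and scaled data.
B_scaled = jnp.linalg.lstsq((X - X_mean) / X_std, (Y - Y_mean) / Y_std)[0]
X_new = jax.random.normal(key2, (5, 3))
Y_pred = predict_original_scale(X_new, B_scaled, X_mean, X_std, Y_mean, Y_std)
```

Because the relation is exactly linear, the un-scaled predictions reproduce X_new @ B_true + 3 including the intercept, even though the model itself never saw an intercept term.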
- fit(X: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, Y: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, A: int, weights: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None) → Self
Fits Improved Kernel PLS Algorithm #2 on X and Y using A components.
- Parameters:
X (Array of shape (N, K)) – Predictor variables.
Y (Array of shape (N, M) or (N,)) – Response variables.
A (int) – Number of components in the PLS model.
weights (Array of shape (N,) or None, optional, default=None) – Weights for each observation. If None, then all observations are weighted equally.
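A common way to include observation weights in a kernel algorithm is to replace X^T X and X^T Y with X^T W X and X^T W Y, where W = diag(weights). This page does not say that ikpls does exactly this, so treat the sketch as an assumption; `weighted_kernels` is a hypothetical helper, and it mirrors the documented ValueError for negative weights. It does not cover weighted means or standard deviations.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def weighted_kernels(X, Y, weights=None):
    """Hypothetical helper returning X^T W X and X^T W Y with W = diag(weights)."""
    if Y.ndim == 1:
        Y = Y[:, None]
    if weights is None:
        weights = jnp.ones(X.shape[0], dtype=X.dtype)  # equal weighting
    if jnp.any(weights < 0):
        raise ValueError("All weights must be non-negative.")
    WX = weights[:, None] * X  # scales row i by weights[i] without forming diag(weights)
    return WX.T @ X, WX.T @ Y
```

An integer weight behaves like duplicating that observation, which makes the effect of weighting easy to check.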
- A
Number of components in the PLS model.
- Type:
int
- max_stable_components
Maximum number of components that can be used without the residual going below machine epsilon. This is not set if differentiable is True.
- Type:
int
- R_Y
Mapping from the number of components to the PLS weights matrix that computes the scores U directly from the original Y. Keys range from 1 to A. Values are arrays of shape (M, n_components) where n_components is the key. Values are computed lazily and cached upon first access. See Notes for more information.
- Type:
Mapping[int, Array]
- X_mean
Mean of X. If centering is not performed, this is None.
- Type:
Array of shape (1, K) or None
- Y_mean
Mean of Y. If centering is not performed, this is None.
- Type:
Array of shape (1, M) or None
- X_std
Sample standard deviation of X. If scaling is not performed, this is None.
- Type:
Array of shape (1, K) or None
- Y_std
Sample standard deviation of Y. If scaling is not performed, this is None.
- Type:
Array of shape (1, M) or None
- Returns:
self – Fitted model.
- Return type:
Self
- Raises:
ValueError – If weights are provided and any weight is negative.
- Warns:
UserWarning – If at any point during iteration over the number of components A, the residual goes below machine epsilon.
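The warning and max_stable_components can be illustrated by monitoring the deflated cross-product during the component loop. This sketch assumes the residual is measured as the norm of the deflated X^T Y and handles a single response only; `stable_components_sketch` is hypothetical, and stopping at the first unstable component is a simplification, not documented library behavior.

```python
import warnings

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def stable_components_sketch(XTX, XTY, A):
    """Hypothetical PLS1 loop: count components computed before the norm of
    the deflated X^T Y falls below machine epsilon."""
    eps = jnp.finfo(XTY.dtype).eps
    P, R = [], []
    for a in range(A):
        norm = jnp.linalg.norm(XTY)
        if norm < eps:
            warnings.warn(f"Residual below machine epsilon after {a} components.", UserWarning)
            return a  # only the first a components are stable
        w = XTY[:, 0] / norm  # single response: normalized X^T y
        r = w
        for p_j, r_j in zip(P, R):
            r = r - (p_j @ w) * r_j
        tt = r @ XTX @ r
        p = (r @ XTX) / tt
        q = (r @ XTY) / tt
        XTY = XTY - tt * jnp.outer(p, q)
        P.append(p)
        R.append(r)
    return A


# y is fully explained by the first column of X, so one component exhausts X^T y.
X = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
y = jnp.array([[1.0], [0.0], [0.0]])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    n_stable = stable_components_sketch(X.T @ X, X.T @ y, A=2)
```

Here the deflation after the first component is exact, so requesting a second component triggers the warning and the stable count is 1.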
Notes
R_Y is provided for convenience only, as it is not required to derive B. Therefore, every value in R_Y is computed lazily and is only evaluated when it is first accessed by its key after a call to fit, either by the user or because transform needs it. Once a value has been computed, it is cached for fast future retrieval. R_Y is implemented as a concrete Mapping.
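The lazy, cached Mapping described above can be sketched with `collections.abc.Mapping`. The class `LazyCachedMapping` and its `compute` callback are hypothetical stand-ins; the sketch does not reproduce how ikpls actually computes each R_Y value.

```python
from collections.abc import Mapping


class LazyCachedMapping(Mapping):
    """Hypothetical read-only Mapping over keys 1..A whose values are
    computed on first access and cached afterwards."""

    def __init__(self, A, compute):
        self._A = A
        self._compute = compute  # callable: n_components -> value
        self._cache = {}

    def __getitem__(self, n_components):
        if not isinstance(n_components, int) or not 1 <= n_components <= self._A:
            raise KeyError(n_components)
        if n_components not in self._cache:
            self._cache[n_components] = self._compute(n_components)  # first access
        return self._cache[n_components]

    def __iter__(self):
        return iter(range(1, self._A + 1))  # keys only; values stay unevaluated

    def __len__(self):
        return self._A
```

Iterating over the keys or checking membership never triggers a computation; only indexing does, and each key is computed at most once.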
- stateless_fit(X: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, Y: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, A: int, weights: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None) → Tuple[Array, Array, Array, Array, Array, Array | None, Array | None, Array | None, Array | None]
Fits Improved Kernel PLS Algorithm #2 on X and Y using A components. Returns the internal matrices instead of storing them in the class instance.
- Parameters:
X (Array of shape (N, K)) – Predictor variables. Its dtype will be converted to float64 for reliable results.
Y (Array of shape (N, M) or (N,)) – Response variables. Its dtype will be converted to float64 for reliable results.
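A (int) – Number of components in the PLS model.
weights (Array of shape (N,) or None, optional, default=None) – Weights for each observation. If None, then all observations are weighted equally.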
- Returns:
B (Array of shape (A, K, M)) – PLS regression coefficients tensor.
W (Array of shape (A, K)) – PLS weights matrix for X.
P (Array of shape (A, K)) – PLS loadings matrix for X.
Q (Array of shape (A, M)) – PLS loadings matrix for Y.
R (Array of shape (A, K)) – PLS weights matrix to compute scores T directly from original X.
X_mean (Array of shape (1, K) or None) – Mean of X. If centering is not performed, this is None.
Y_mean (Array of shape (1, M) or None) – Mean of Y. If centering is not performed, this is None.
X_std (Array of shape (1, K) or None) – Sample standard deviation of X. If scaling is not performed, this is None.
Y_std (Array of shape (1, M) or None) – Sample standard deviation of Y. If scaling is not performed, this is None.
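Returning the matrices instead of storing them lets fitting be written as a pure function, which is what JAX transformations such as jax.grad and jax.jit operate on. The hypothetical sketch below uses a one-component, single-response PLS fit (`one_component_fit`, not the ikpls code) and differentiates a training loss with respect to a per-feature scaling vector, checking one gradient entry against a central finite difference.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def one_component_fit(X, y):
    """Hypothetical pure PLS1 fit with A=1; returns B of shape (K,)."""
    w = X.T @ y
    w = w / jnp.linalg.norm(w)  # PLS weight vector
    t = X @ w  # scores
    q = (t @ y) / (t @ t)  # Y loading
    return w * q  # B = w q for one component


def loss(feature_scale, X, y):
    Xs = X * feature_scale  # scale every column by its own factor
    B = one_component_fit(Xs, y)
    return jnp.mean((y - Xs @ B) ** 2)


key1, key2 = jax.random.split(jax.random.PRNGKey(2))
X = jax.random.normal(key1, (30, 4))
y = X @ jnp.array([1.0, 0.5, -1.0, 2.0]) + 0.1 * jax.random.normal(key2, (30,))
s = jnp.ones(4)
grad = jax.jit(jax.grad(loss))(s, X, y)  # gradient flows through the whole fit
```

The gradient flows through the normalization and the score computation, which is the kind of end-to-end differentiability that the differentiable parameter refers to.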
- Raises:
ValueError – If weights are provided and any weight is negative.
- Warns:
UserWarning – If at any point during iteration over the number of components A, the residual goes below machine epsilon.
Notes
For optimization purposes, the internal representation of all matrices (except B) is transposed from the usual representation.
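The transposed layout means each returned matrix has one row per component, so the usual (K, A) form is obtained with `.T`. This sketch uses random placeholder arrays with the documented shapes and assumes that B[a - 1] holds the coefficients for a components, built from the standard relation B_a = R_a Q_a^T. Under that assumption, X @ B[a - 1] equals the first a scores times the first a Y loadings.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

A, K, M, N = 3, 4, 2, 6
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(3), 3)
R = jax.random.normal(k1, (A, K))  # placeholder, stored transposed: (A, K)
Q = jax.random.normal(k2, (A, M))  # placeholder, stored transposed: (A, M)
X = jax.random.normal(k3, (N, K))

T = X @ R.T  # scores T = X R in the usual (N, A) layout
# Assumption: B[a - 1] = R_a Q_a^T uses the first a components.
B = jnp.stack([R[:a].T @ Q[:a] for a in range(1, A + 1)])
```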