Skip to content

Commit ccf23dc

Browse files
authored
Fix PCA sign flip (#7331)
Closes #4560. This PR implements the PCA sign-flip algorithm so that, in `components`, the entry with the largest absolute value in each row is always positive, while `trans_input` is left unchanged, just like `sklearn.decomposition.PCA`. In the `_mg` version of PCA, `components` is not chunked; thus, the `_mg` version can reuse the sign flipping from the single-GPU version (by setting `stream = streams[0]`). In this PR, `cpp/include/cuml/decomposition/sign_flip_mg.hpp` and `cpp/src/pca/sign_flip_mg.cu` are not in use (but the files are not removed). Authors: - https://github.com/zhuxr11 - Divye Gala (https://github.com/divyegala) - Simon Adorf (https://github.com/csadorf) - Jim Crist-Harif (https://github.com/jcrist) Approvers: - Simon Adorf (https://github.com/csadorf) - Divye Gala (https://github.com/divyegala) URL: #7331
1 parent da72b46 commit ccf23dc

File tree

27 files changed

+713
-139
lines changed

27 files changed

+713
-139
lines changed

cpp/include/cuml/decomposition/pca.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2018-2021, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2018-2025, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -21,7 +21,8 @@ void pcaFit(raft::handle_t& handle,
2121
float* singular_vals,
2222
float* mu,
2323
float* noise_vars,
24-
const paramsPCA& prms);
24+
const paramsPCA& prms,
25+
bool flip_signs_based_on_U);
2526
void pcaFit(raft::handle_t& handle,
2627
double* input,
2728
double* components,
@@ -30,7 +31,8 @@ void pcaFit(raft::handle_t& handle,
3031
double* singular_vals,
3132
double* mu,
3233
double* noise_vars,
33-
const paramsPCA& prms);
34+
const paramsPCA& prms,
35+
bool flip_signs_based_on_U);
3436
void pcaFitTransform(raft::handle_t& handle,
3537
float* input,
3638
float* trans_input,
@@ -40,7 +42,8 @@ void pcaFitTransform(raft::handle_t& handle,
4042
float* singular_vals,
4143
float* mu,
4244
float* noise_vars,
43-
const paramsPCA& prms);
45+
const paramsPCA& prms,
46+
bool flip_signs_based_on_U);
4447
void pcaFitTransform(raft::handle_t& handle,
4548
double* input,
4649
double* trans_input,
@@ -50,7 +53,8 @@ void pcaFitTransform(raft::handle_t& handle,
5053
double* singular_vals,
5154
double* mu,
5255
double* noise_vars,
53-
const paramsPCA& prms);
56+
const paramsPCA& prms,
57+
bool flip_signs_based_on_U);
5458
void pcaInverseTransform(raft::handle_t& handle,
5559
float* trans_input,
5660
float* components,

cpp/include/cuml/decomposition/pca_mg.hpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2020-2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -27,6 +27,8 @@ namespace opg {
2727
* @param[out] noise_vars: variance of the noise
2828
* @param[in] prms: data structure that includes all the parameters from input size to algorithm
2929
* @param[in] verbose
30+
* @param[in] flip_signs_based_on_U: Whether to use U-based decision for sign flipping (for sklearn
31+
* < 1.5)
3032
*/
3133
void fit(raft::handle_t& handle,
3234
std::vector<MLCommon::Matrix::Data<float>*>& input_data,
@@ -38,7 +40,8 @@ void fit(raft::handle_t& handle,
3840
float* mu,
3941
float* noise_vars,
4042
paramsPCAMG prms,
41-
bool verbose = false);
43+
bool verbose = false,
44+
bool flip_signs_based_on_U = false);
4245

4346
void fit(raft::handle_t& handle,
4447
std::vector<MLCommon::Matrix::Data<double>*>& input_data,
@@ -50,7 +53,8 @@ void fit(raft::handle_t& handle,
5053
double* mu,
5154
double* noise_vars,
5255
paramsPCAMG prms,
53-
bool verbose = false);
56+
bool verbose = false,
57+
bool flip_signs_based_on_U = false);
5458

5559
/**
5660
* @brief performs MNMG fit and transform operation for the pca
@@ -67,6 +71,8 @@ void fit(raft::handle_t& handle,
6771
* @param[out] noise_vars: variance of the noise
6872
* @param[in] prms: data structure that includes all the parameters from input size to algorithm
6973
* @param[in] verbose
74+
* @param[in] flip_signs_based_on_U: Whether to use U-based decision for sign flipping (for sklearn
75+
* < 1.5)
7076
*/
7177
void fit_transform(raft::handle_t& handle,
7278
MLCommon::Matrix::RankSizePair** rank_sizes,
@@ -80,7 +86,8 @@ void fit_transform(raft::handle_t& handle,
8086
float* mu,
8187
float* noise_vars,
8288
paramsPCAMG prms,
83-
bool verbose);
89+
bool verbose,
90+
bool flip_signs_based_on_U);
8491

8592
void fit_transform(raft::handle_t& handle,
8693
MLCommon::Matrix::RankSizePair** rank_sizes,
@@ -94,7 +101,8 @@ void fit_transform(raft::handle_t& handle,
94101
double* mu,
95102
double* noise_vars,
96103
paramsPCAMG prms,
97-
bool verbose);
104+
bool verbose,
105+
bool flip_signs_based_on_U);
98106

99107
/**
100108
* @brief performs MNMG transform operation for the pca

cpp/include/cuml/decomposition/sign_flip_mg.hpp

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2020-2022, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -13,6 +13,43 @@ namespace ML {
1313
namespace PCA {
1414
namespace opg {
1515

16+
/**
17+
* @brief sign flip for PCA and tSVD. This is used to stabilize the sign of column major eigen
18+
* vectors
19+
* @param[in] handle: the internal cuml handle object
20+
* @param[in] input_data: input matrix that will be used to determine the sign.
21+
* @param[in] input_desc: MNMG description of the input
22+
* @param[out] components: components matrix.
23+
* @param[in] n_samples: number of rows of input matrix
24+
* @param[in] n_features: number of columns of input/components matrix
25+
* @param[in] n_components: number of rows of components matrix
26+
* @param[in] streams: cuda streams
27+
* @param[in] n_stream: number of streams
28+
* @param[in] center: whether to center input_data by columns
29+
* @{
30+
*/
31+
void sign_flip_components_u(raft::handle_t& handle,
32+
std::vector<MLCommon::Matrix::Data<float>*>& input_data,
33+
MLCommon::Matrix::PartDescriptor& input_desc,
34+
float* components,
35+
std::size_t n_samples,
36+
std::size_t n_features,
37+
std::size_t n_components,
38+
cudaStream_t* streams,
39+
std::uint32_t n_stream,
40+
bool center);
41+
42+
void sign_flip_components_u(raft::handle_t& handle,
43+
std::vector<MLCommon::Matrix::Data<double>*>& input_data,
44+
MLCommon::Matrix::PartDescriptor& input_desc,
45+
double* components,
46+
std::size_t n_samples,
47+
std::size_t n_features,
48+
std::size_t n_components,
49+
cudaStream_t* streams,
50+
std::uint32_t n_stream,
51+
bool center);
52+
1653
/**
1754
* @brief sign flip for PCA and tSVD. This is used to stabilize the sign of column major eigen
1855
* vectors

cpp/include/cuml/decomposition/tsvd.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2018-2021, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2018-2025, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -17,12 +17,14 @@ void tsvdFit(raft::handle_t& handle,
1717
float* input,
1818
float* components,
1919
float* singular_vals,
20-
const paramsTSVD& prms);
20+
const paramsTSVD& prms,
21+
bool flip_signs_based_on_U);
2122
void tsvdFit(raft::handle_t& handle,
2223
double* input,
2324
double* components,
2425
double* singular_vals,
25-
const paramsTSVD& prms);
26+
const paramsTSVD& prms,
27+
bool flip_signs_based_on_U);
2628
void tsvdInverseTransform(raft::handle_t& handle,
2729
float* trans_input,
2830
float* components,
@@ -50,14 +52,16 @@ void tsvdFitTransform(raft::handle_t& handle,
5052
float* explained_var,
5153
float* explained_var_ratio,
5254
float* singular_vals,
53-
const paramsTSVD& prms);
55+
const paramsTSVD& prms,
56+
bool flip_signs_based_on_U);
5457
void tsvdFitTransform(raft::handle_t& handle,
5558
double* input,
5659
double* trans_input,
5760
double* components,
5861
double* explained_var,
5962
double* explained_var_ratio,
6063
double* singular_vals,
61-
const paramsTSVD& prms);
64+
const paramsTSVD& prms,
65+
bool flip_signs_based_on_U);
6266

6367
} // namespace ML

cpp/include/cuml/decomposition/tsvd_mg.hpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2020-2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -24,6 +24,8 @@ namespace opg {
2424
* @param[out] singular_vals: singular values of the data
2525
* @param[in] prms: data structure that includes all the parameters from input size to algorithm
2626
* @param[in] verbose
27+
* @param[in] flip_signs_based_on_U: Whether to use U-based decision for sign flipping (for sklearn
28+
* < 1.5)
2729
*/
2830
void fit(raft::handle_t& handle,
2931
MLCommon::Matrix::RankSizePair** rank_sizes,
@@ -32,7 +34,8 @@ void fit(raft::handle_t& handle,
3234
float* components,
3335
float* singular_vals,
3436
paramsTSVDMG& prms,
35-
bool verbose = false);
37+
bool verbose = false,
38+
bool flip_signs_based_on_U = false);
3639

3740
void fit(raft::handle_t& handle,
3841
MLCommon::Matrix::RankSizePair** rank_sizes,
@@ -41,7 +44,8 @@ void fit(raft::handle_t& handle,
4144
double* components,
4245
double* singular_vals,
4346
paramsTSVDMG& prms,
44-
bool verbose = false);
47+
bool verbose = false,
48+
bool flip_signs_based_on_U = false);
4549

4650
/**
4751
* @brief performs MNMG fit and transform operation for the tsvd.
@@ -56,6 +60,8 @@ void fit(raft::handle_t& handle,
5660
* @param[out] singular_vals: singular values of the data
5761
* @param[in] prms: data structure that includes all the parameters from input size to algorithm
5862
* @param[in] verbose
63+
* @param[in] flip_signs_based_on_U: Whether to use U-based decision for sign flipping (for sklearn
64+
* < 1.5)
5965
*/
6066
void fit_transform(raft::handle_t& handle,
6167
std::vector<MLCommon::Matrix::Data<float>*>& input_data,
@@ -67,7 +73,8 @@ void fit_transform(raft::handle_t& handle,
6773
float* explained_var_ratio,
6874
float* singular_vals,
6975
paramsTSVDMG& prms,
70-
bool verbose);
76+
bool verbose,
77+
bool flip_signs_based_on_U);
7178

7279
void fit_transform(raft::handle_t& handle,
7380
std::vector<MLCommon::Matrix::Data<double>*>& input_data,
@@ -79,7 +86,8 @@ void fit_transform(raft::handle_t& handle,
7986
double* explained_var_ratio,
8087
double* singular_vals,
8188
paramsTSVDMG& prms,
82-
bool verbose);
89+
bool verbose,
90+
bool flip_signs_based_on_U);
8391

8492
/**
8593
* @brief performs MNMG transform operation for the tsvd.

cpp/src/pca/pca.cu

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2018-2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2018-2025, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -19,7 +19,8 @@ void pcaFit(raft::handle_t& handle,
1919
float* singular_vals,
2020
float* mu,
2121
float* noise_vars,
22-
const paramsPCA& prms)
22+
const paramsPCA& prms,
23+
bool flip_signs_based_on_U = false)
2324
{
2425
pcaFit(handle,
2526
input,
@@ -30,7 +31,8 @@ void pcaFit(raft::handle_t& handle,
3031
mu,
3132
noise_vars,
3233
prms,
33-
handle.get_stream());
34+
handle.get_stream(),
35+
flip_signs_based_on_U);
3436
}
3537

3638
void pcaFit(raft::handle_t& handle,
@@ -41,7 +43,8 @@ void pcaFit(raft::handle_t& handle,
4143
double* singular_vals,
4244
double* mu,
4345
double* noise_vars,
44-
const paramsPCA& prms)
46+
const paramsPCA& prms,
47+
bool flip_signs_based_on_U = false)
4548
{
4649
pcaFit(handle,
4750
input,
@@ -52,7 +55,8 @@ void pcaFit(raft::handle_t& handle,
5255
mu,
5356
noise_vars,
5457
prms,
55-
handle.get_stream());
58+
handle.get_stream(),
59+
flip_signs_based_on_U);
5660
}
5761

5862
void pcaFitTransform(raft::handle_t& handle,
@@ -64,7 +68,8 @@ void pcaFitTransform(raft::handle_t& handle,
6468
float* singular_vals,
6569
float* mu,
6670
float* noise_vars,
67-
const paramsPCA& prms)
71+
const paramsPCA& prms,
72+
bool flip_signs_based_on_U = false)
6873
{
6974
pcaFitTransform(handle,
7075
input,
@@ -76,7 +81,8 @@ void pcaFitTransform(raft::handle_t& handle,
7681
mu,
7782
noise_vars,
7883
prms,
79-
handle.get_stream());
84+
handle.get_stream(),
85+
flip_signs_based_on_U);
8086
}
8187

8288
void pcaFitTransform(raft::handle_t& handle,
@@ -88,7 +94,8 @@ void pcaFitTransform(raft::handle_t& handle,
8894
double* singular_vals,
8995
double* mu,
9096
double* noise_vars,
91-
const paramsPCA& prms)
97+
const paramsPCA& prms,
98+
bool flip_signs_based_on_U = false)
9299
{
93100
pcaFitTransform(handle,
94101
input,
@@ -100,7 +107,8 @@ void pcaFitTransform(raft::handle_t& handle,
100107
mu,
101108
noise_vars,
102109
prms,
103-
handle.get_stream());
110+
handle.get_stream(),
111+
flip_signs_based_on_U);
104112
}
105113

106114
void pcaInverseTransform(raft::handle_t& handle,

0 commit comments

Comments
 (0)