Skip to content

Commit 00e9045

Browse files
author
Pierre Bartet
committed
Fix all selectors
Signed-off-by: Pierre Bartet <[email protected]>
1 parent f92052f commit 00e9045

File tree

5 files changed

+29
-21
lines changed

5 files changed

+29
-21
lines changed

skl2onnx/shape_calculators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from . import cross_decomposition
1111
from . import dict_vectorizer
1212
from . import ensemble_shapes
13+
from . import feature_selection
1314
from . import feature_hasher
1415
from . import flatten
1516
from . import function_transformer
@@ -63,6 +64,7 @@
6364
dict_vectorizer,
6465
ensemble_shapes,
6566
feature_hasher,
67+
feature_selection,
6668
flatten,
6769
function_transformer,
6870
gaussian_process,

skl2onnx/shape_calculators/array_feature_extractor.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,3 @@ def calculate_sklearn_array_feature_extractor(operator):
1616
register_shape_calculator(
1717
"SklearnArrayFeatureExtractor", calculate_sklearn_array_feature_extractor
1818
)
19-
20-
21-
def calculate_sklearn_select_k_best(operator):
22-
check_input_and_output_numbers(operator, output_count_range=1)
23-
i = operator.inputs[0]
24-
N = i.get_first_dimension()
25-
C = operator.raw_operator._get_support_mask().sum()
26-
operator.outputs[0].type = i.type.__class__([N, C])
27-
28-
29-
register_shape_calculator("SklearnSelectKBest", calculate_sklearn_select_k_best)

skl2onnx/shape_calculators/concat.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,3 @@ def more_generic(t1, t2):
9393

9494

9595
register_shape_calculator("SklearnConcat", calculate_sklearn_concat)
96-
register_shape_calculator("SklearnGenericUnivariateSelect", calculate_sklearn_concat)
97-
register_shape_calculator("SklearnRFE", calculate_sklearn_concat)
98-
register_shape_calculator("SklearnRFECV", calculate_sklearn_concat)
99-
register_shape_calculator("SklearnSelectFdr", calculate_sklearn_concat)
100-
register_shape_calculator("SklearnSelectFpr", calculate_sklearn_concat)
101-
register_shape_calculator("SklearnSelectFromModel", calculate_sklearn_concat)
102-
register_shape_calculator("SklearnSelectFwe", calculate_sklearn_concat)
103-
# register_shape_calculator("SklearnSelectKBest", calculate_sklearn_concat)
104-
register_shape_calculator("SklearnSelectPercentile", calculate_sklearn_concat)
105-
register_shape_calculator("SklearnVarianceThreshold", calculate_sklearn_concat)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
4+
from ..common._registration import register_shape_calculator
5+
from ..common.utils import check_input_and_output_numbers
6+
7+
8+
def calculate_sklearn_select(operator):
9+
check_input_and_output_numbers(operator, output_count_range=1)
10+
i = operator.inputs[0]
11+
N = i.get_first_dimension()
12+
C = operator.raw_operator._get_support_mask().sum()
13+
operator.outputs[0].type = i.type.__class__([N, C])
14+
15+
16+
register_shape_calculator("SklearnGenericUnivariateSelect", calculate_sklearn_select)
17+
register_shape_calculator("SklearnRFE", calculate_sklearn_select)
18+
register_shape_calculator("SklearnRFECV", calculate_sklearn_select)
19+
register_shape_calculator("SklearnSelectFdr", calculate_sklearn_select)
20+
register_shape_calculator("SklearnSelectFpr", calculate_sklearn_select)
21+
register_shape_calculator("SklearnSelectFromModel", calculate_sklearn_select)
22+
register_shape_calculator("SklearnSelectFwe", calculate_sklearn_select)
23+
register_shape_calculator("SklearnSelectKBest", calculate_sklearn_select)
24+
register_shape_calculator("SklearnSelectPercentile", calculate_sklearn_select)
25+
register_shape_calculator("SklearnVarianceThreshold", calculate_sklearn_select)

tests/test_sklearn_feature_selection_converters.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@
3333
class TestSklearnFeatureSelectionConverters(unittest.TestCase):
3434
def test_generic_univariate_select_int(self):
3535
model = GenericUnivariateSelect()
36+
3637
X = np.array(
3738
[[1, 2, 3, 1], [0, 3, 1, 4], [3, 5, 6, 1], [1, 2, 1, 5]], dtype=np.int64
3839
)
3940
y = np.array([0, 1, 0, 1])
4041
model.fit(X, y)
42+
4143
model_onnx = convert_sklearn(
4244
model,
4345
"generic univariate select",

0 commit comments

Comments
 (0)