Skip to content

Commit 469a18b

Browse files
Pierre-BartetPierre Bartet
andauthored
Fix empty column selector (#1159)
* Add a test to trigger the bug Signed-off-by: Pierre Bartet <[email protected]> * Fix _parse.py Signed-off-by: Pierre Bartet <[email protected]> * Fix formatting Signed-off-by: Pierre Bartet <[email protected]> --------- Signed-off-by: Pierre Bartet <[email protected]> Co-authored-by: Pierre Bartet <[email protected]>
1 parent e0799d3 commit 469a18b

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

skl2onnx/_parse.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,11 @@ def _parse_sklearn_column_transformer(scope, model, inputs, custom_parsers=None)
362362
elif isinstance(column_indices, (int, str)):
363363
column_indices = [column_indices]
364364
names = get_column_indices(column_indices, inputs, multiple=True)
365+
366+
# Skip transforms which apply to no columns at all
367+
if len(names) == 0:
368+
continue
369+
365370
transform_inputs = []
366371
for onnx_var, onnx_is in names.items():
367372
tr_inputs = _fetch_input_slice(scope, [inputs[onnx_var]], onnx_is)

tests/test_sklearn_pipeline.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,6 +1335,30 @@ def test_pipeline_make_column_selector(self):
13351335
)
13361336
assert_almost_equal(expected, got[0])
13371337

1338+
@unittest.skipIf(TARGET_OPSET < 11, reason="SequenceConstruct not available")
1339+
@unittest.skipIf(not check_scikit_version(), reason="Scikit 0.21 too old")
1340+
@ignore_warnings(category=(FutureWarning, UserWarning))
1341+
def test_pipeline_empty_make_column_selector(self):
1342+
X = pandas.DataFrame({"city": ["London", "London", "Paris", "Sallisaw"]})
1343+
1344+
ct = make_column_transformer(
1345+
(StandardScaler(), make_column_selector(dtype_include=numpy.number)),
1346+
(OneHotEncoder(), make_column_selector(dtype_include=object)),
1347+
)
1348+
expected = ct.fit_transform(X)
1349+
onx = to_onnx(ct, X, target_opset=TARGET_OPSET)
1350+
sess = InferenceSession(
1351+
onx.SerializeToString(), providers=["CPUExecutionProvider"]
1352+
)
1353+
names = [i.name for i in sess.get_inputs()]
1354+
got = sess.run(
1355+
None,
1356+
{
1357+
names[0]: X[names[0]].values.reshape((-1, 1)),
1358+
},
1359+
)
1360+
assert_almost_equal(expected, got[0])
1361+
13381362
@unittest.skipIf(not check_scikit_version(), reason="Scikit 0.21 too old")
13391363
def test_feature_selector_no_converter(self):
13401364
class ColumnSelector(TransformerMixin, BaseEstimator):

0 commit comments

Comments
 (0)