Skip to content

Commit f99cdff

Browse files
alexarntzen (Alexander Arntzen)
authored
Fix QDA converter crashing on string labels and incorrect shape calculation (#1169)
* fix crashing on string labels in QDA and incorrect shape calculation
  Signed-off-by: Alexander Arntzen <[email protected]>
* run black
  Signed-off-by: Alexander Arntzen <[email protected]>
* add test with String labels for QDA
  Signed-off-by: Alexander Arntzen <[email protected]>
---------
Signed-off-by: Alexander Arntzen <[email protected]>
Co-authored-by: Alexander Arntzen <[email protected]>
Co-authored-by: Alexander Arntzen <[email protected]>
1 parent 73413b1 commit f99cdff

File tree

3 files changed

+47
-8
lines changed

3 files changed

+47
-8
lines changed

skl2onnx/operator_converters/quadratic_discriminant_analysis.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ def convert_quadratic_discriminant_analysis_classifier(
3131
model = operator.raw_operator
3232

3333
n_classes = len(model.classes_)
34+
if all(isinstance(i, str) for i in model.classes_):
35+
class_type = onnx_proto.TensorProto.STRING
36+
class_labels = [s.encode("utf-8") for s in model.classes_]
37+
else:
38+
class_type = onnx_proto.TensorProto.INT64
39+
class_labels = [int(i) for i in model.classes_]
3440

3541
proto_dtype = guess_proto_type(operator.inputs[0].type)
3642
if proto_dtype != onnx_proto.TensorProto.DOUBLE:
@@ -148,10 +154,7 @@ def convert_quadratic_discriminant_analysis_classifier(
148154
apply_argmax(scope, [decision_fun], [argmax_out], container, axis=1)
149155

150156
classes = scope.get_unique_variable_name("classes")
151-
container.add_initializer(
152-
classes, onnx_proto.TensorProto.INT64, [n_classes], model.classes_
153-
)
154-
157+
container.add_initializer(classes, class_type, [n_classes], class_labels)
155158
container.add_node(
156159
"ArrayFeatureExtractor",
157160
[classes, argmax_out],

skl2onnx/shape_calculators/quadratic_discriminant_analysis.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
from ..common._registration import register_shape_calculator
4-
from ..common.data_types import Int64TensorType
4+
from ..common.data_types import Int64TensorType, StringTensorType
55

66

77
def calculate_quadratic_discriminant_analysis_shapes(operator):
8-
N = len(operator.raw_operator.classes_)
9-
operator.outputs[0].type = Int64TensorType([1, N])
10-
operator.outputs[1].type.shape = [None, N]
8+
classes = operator.raw_operator.classes_
9+
if all((isinstance(s, str)) for s in classes):
10+
label_tensor_type = StringTensorType
11+
else:
12+
label_tensor_type = Int64TensorType
13+
14+
n_clasess = len(classes)
15+
operator.outputs[0].type = label_tensor_type([1, None])
16+
operator.outputs[1].type.shape = [None, n_clasess]
1117

1218

1319
register_shape_calculator(

tests/test_sklearn_quadratic_discriminant_analysis_converter.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,36 @@ def test_model_qda_2c2f_float(self):
4949
basename="SklearnQDA_2c2f_Float",
5050
)
5151

52+
@unittest.skipIf(
53+
pv.Version(sklearn.__version__) < pv.Version("1.0"), reason="scikit-learn<1.0"
54+
)
55+
@unittest.skipIf(
56+
pv.Version(onnx_version) < pv.Version("1.11"), reason="fails with onnx 1.10"
57+
)
58+
def test_model_qda_2c2f_float_string_labels(self):
59+
# 2 classes, 2 features, string_labels
60+
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
61+
y = np.array(["apple", "apple", "apple", "banana", "banana", "banana"])
62+
X_test = np.array([[-0.8, -1], [0.8, 1]])
63+
64+
skl_model = QuadraticDiscriminantAnalysis()
65+
skl_model.fit(X, y)
66+
67+
onnx_model = convert_sklearn(
68+
skl_model,
69+
"scikit-learn QDA",
70+
[("input", FloatTensorType([None, X.shape[1]]))],
71+
target_opset=TARGET_OPSET,
72+
)
73+
74+
self.assertIsNotNone(onnx_model)
75+
dump_data_and_model(
76+
X_test.astype(np.float32),
77+
skl_model,
78+
onnx_model,
79+
basename="SklearnQDA_2c2f_Float_String_labels",
80+
)
81+
5282
@unittest.skipIf(
5383
pv.Version(sklearn.__version__) < pv.Version("1.0"), reason="scikit-learn<1.0"
5484
)

0 commit comments

Comments
 (0)