Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions skl2onnx/operator_converters/quadratic_discriminant_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ def convert_quadratic_discriminant_analysis_classifier(
model = operator.raw_operator

n_classes = len(model.classes_)
if all(isinstance(i, str) for i in model.classes_):
class_type = onnx_proto.TensorProto.STRING
class_labels = [s.encode("utf-8") for s in model.classes_]
else:
class_type = onnx_proto.TensorProto.INT64
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we can have mixed types?

Copy link
Contributor Author

@alexarntzen alexarntzen Mar 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sklearn supports mixed types. However, the other converters in sklearn-onnx, like LinearDiscriminantAnalysis, do not. For instance:

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from skl2onnx.common.data_types import DoubleTensorType
from skl2onnx import convert_sklearn

X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]])
y = np.array([1, 1, "b", "b"])
skl_model = LinearDiscriminantAnalysis()
skl_model.fit(X, y)
skl_model.classes_ = [1, "b"]  # force mixed types
onnx_model = convert_sklearn(
    skl_model,
    "scikit-learn",
    initial_types=[("inputs", DoubleTensorType([None, X.shape[1]]))],
    target_opset=16,
    options={"zipmap": False},
)

outputs:

ValueError: Label types must be all integers or all strings.

class_labels = [int(i) for i in model.classes_]

proto_dtype = guess_proto_type(operator.inputs[0].type)
if proto_dtype != onnx_proto.TensorProto.DOUBLE:
Expand Down Expand Up @@ -148,10 +154,7 @@ def convert_quadratic_discriminant_analysis_classifier(
apply_argmax(scope, [decision_fun], [argmax_out], container, axis=1)

classes = scope.get_unique_variable_name("classes")
container.add_initializer(
classes, onnx_proto.TensorProto.INT64, [n_classes], model.classes_
)

container.add_initializer(classes, class_type, [n_classes], class_labels)
container.add_node(
"ArrayFeatureExtractor",
[classes, argmax_out],
Expand Down
14 changes: 10 additions & 4 deletions skl2onnx/shape_calculators/quadratic_discriminant_analysis.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
# SPDX-License-Identifier: Apache-2.0

from ..common._registration import register_shape_calculator
from ..common.data_types import Int64TensorType
from ..common.data_types import Int64TensorType, StringTensorType


def calculate_quadratic_discriminant_analysis_shapes(operator):
N = len(operator.raw_operator.classes_)
operator.outputs[0].type = Int64TensorType([1, N])
operator.outputs[1].type.shape = [None, N]
classes = operator.raw_operator.classes_
if all((isinstance(s, str)) for s in classes):
label_tensor_type = StringTensorType
else:
label_tensor_type = Int64TensorType

n_clasess = len(classes)
operator.outputs[0].type = label_tensor_type([1, None])
operator.outputs[1].type.shape = [None, n_clasess]


register_shape_calculator(
Expand Down
Loading