
Unable to use convert_sklearn for an XGBoost CalibratedClassifierCV #1173

@anyuzhang2024

Description


I get the following error when I try to use convert_sklearn from skl2onnx to convert a calibrated XGBoost classifier (a CalibratedClassifierCV wrapping an XGBClassifier) to an ONNX model:


 <class 'xgboost.sklearn.XGBClassifier'>
File <command-3767976754122279>, line 51
     49 # Convert extracted XGBoost model to ONNX
     50 initial_type = [("float_input", FloatTensorType([None, X_train.shape[1]]))]
---> 51 onnx_model = convert_sklearn(calibrated_xgb, initial_types=initial_type)
     53 # Save ONNX model
     54 onnx_filename = "xgb_extracted_model.onnx"
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-7ca10a6f-be3f-434e-afeb-993e59ebdc60/lib/python3.10/site-packages/skl2onnx/convert.py:210, in convert_sklearn(model, name, initial_types, doc_string, target_opset, custom_conversion_functions, custom_shape_calculators, custom_parsers, options, intermediate, white_op, black_op, final_types, dtype, naming, model_optim, verbose)
    208 if verbose >= 1:
    209     print("[convert_sklearn] convert_topology")
--> 210 onnx_model = convert_topology(
    211     topology,
    212     name,
    213     doc_string,
    214     target_opset,
    215     options=options,
    216     remove_identity=model_optim and not intermediate,
    217     verbose=verbose,
    218 )

I also tried converting only the underlying XGBoost model with onnx_model = convert_sklearn(xgb_model_extracted, initial_types=initial_type), where xgb_model_extracted = calibrated_xgb.base_estimator, but I still get the following error:

MissingShapeCalculator: Unable to find a shape calculator for type '<class 'xgboost.sklearn.XGBClassifier'>'.
It usually means the pipeline being converted contains a
transformer or a predictor with no corresponding converter
implemented in sklearn-onnx. If the converted is implemented
in another library, you need to register
the converted so that it can be used by sklearn-onnx (function
update_registered_converter). If the model is not yet covered
by sklearn-onnx, you may raise an issue to
https://github.com/onnx/sklearn-onnx/issues
to get the converter implemented or even contribute to the
project. If the model is a custom model, a new converter must
be implemented. Examples can be found in the gallery.
File <command-3767976754122279>, line 54

Is this a bug? And what is the best way to convert a CalibratedClassifierCV wrapping an XGBoost model to ONNX? Thank you!

Below is example code that reproduces the error:

%pip install onnx skl2onnx onnxruntime
import numpy as np
import xgboost as xgb
import joblib
import onnx
import skl2onnx
import onnxruntime as ort
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, log_loss
from tensorflow.keras.datasets import mnist

# Load Keras dataset (example: MNIST)
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Preprocess data: Flatten images and normalize
X_train = X_train.reshape(X_train.shape[0], -1) / 255.0
X_test = X_test.reshape(X_test.shape[0], -1) / 255.0

# Further split training data for calibration
X_train, X_calib, y_train, y_calib = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

# Define XGBoost model
xgb_model = xgb.XGBClassifier(
    objective="multi:softprob",  # Multi-class with probability output (predict_proba is needed for calibration)
    num_class=10,  # 10 classes for MNIST
    eval_metric="mlogloss"
)

# Train XGBoost model
xgb_model.fit(X_train, y_train)

# Apply sigmoid calibration
calibrated_xgb = CalibratedClassifierCV(xgb_model, method="sigmoid", cv="prefit")
calibrated_xgb.fit(X_calib, y_calib)

# Save calibrated model to joblib format
joblib.dump(calibrated_xgb, "calibrated_xgb_model.joblib")

# Load the joblib model
calibrated_xgb = joblib.load("calibrated_xgb_model.joblib")

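# Convert class labels to strings (this only changes the wrapper's classes_,
# not the classes_ of the underlying XGBoost model)
calibrated_xgb.classes_ = calibrated_xgb.classes_.astype(str)
# With cv="prefit" the fitted base model is the wrapped estimator itself
# (scikit-learn >= 1.2 calls it `estimator`; older versions used `base_estimator`)
xgb_model_extracted = calibrated_xgb.estimator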

# Convert extracted XGBoost model to ONNX
initial_type = [("float_input", FloatTensorType([None, X_train.shape[1]]))]
onnx_model = convert_sklearn(xgb_model_extracted, initial_types=initial_type)
#onnx_model = convert_sklearn(calibrated_xgb, initial_types=initial_type)

# Save ONNX model
onnx_filename = "xgb_extracted_model.onnx"
with open(onnx_filename, "wb") as f:
    f.write(onnx_model.SerializeToString())

print(f"Extracted XGBoost Model successfully converted to ONNX and saved as {onnx_filename}")

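# Load ONNX model and make raw predictions
ort_session = ort.InferenceSession(onnx_filename, providers=["CPUExecutionProvider"])
onnx_inputs = {ort_session.get_inputs()[0].name: X_test.astype(np.float32)}
onnx_outputs = ort_session.run(None, onnx_inputs)
onnx_raw_prob = onnx_outputs[1]  # Output 0 is the predicted label, output 1 the uncalibrated probabilities
if isinstance(onnx_raw_prob, list):  # ZipMap output: one {class: probability} dict per row
    onnx_raw_prob = np.array([[row[c] for c in sorted(row)] for row in onnx_raw_prob])

# Apply sigmoid calibration manually: one calibrator per class, then renormalize each row
calibrators = calibrated_xgb.calibrated_classifiers_[0].calibrators
calibrated_prob = np.column_stack([cal.predict(onnx_raw_prob[:, k]) for k, cal in enumerate(calibrators)])
calibrated_prob /= calibrated_prob.sum(axis=1, keepdims=True)
onnx_pred = np.argmax(calibrated_prob, axis=1)  # Convert to class labels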

# Evaluate ONNX model after applying calibration
accuracy_onnx_calibrated = accuracy_score(y_test, onnx_pred)
logloss_onnx_calibrated = log_loss(y_test, calibrated_prob)
print(f"Calibrated XGBoost Accuracy (ONNX, after calibration): {accuracy_onnx_calibrated:.4f}")
print(f"Calibrated XGBoost Log Loss (ONNX, after calibration): {logloss_onnx_calibrated:.4f}")

# Compare probability distributions with the original (joblib) scikit-learn model
y_prob_joblib = calibrated_xgb.predict_proba(X_test)
prob_diff = np.mean(np.abs(y_prob_joblib - calibrated_prob))
print(f"Average absolute probability difference between Joblib and ONNX-calibrated models: {prob_diff:.6f}")
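For reference, the manual calibration step above can be written without scikit-learn objects. This is a sketch assuming scikit-learn's multiclass sigmoid behavior: each class column f_k is mapped through 1 / (1 + exp(a_k * f_k + b_k)) with its own fitted parameters, and rows are then renormalized to sum to 1. The helper name apply_sigmoid_calibration is hypothetical, and reading a_k, b_k from the private calibrator attributes a_ and b_ is an assumption about scikit-learn internals.

```python
import numpy as np


def apply_sigmoid_calibration(raw_prob, a, b):
    """Hypothetical helper: per-class sigmoid (Platt) calibration, multiclass case.

    raw_prob: (n_samples, n_classes) uncalibrated probabilities.
    a, b: per-class parameters, shape (n_classes,).
    """
    raw_prob = np.asarray(raw_prob, dtype=np.float64)
    # 1 / (1 + exp(a*f + b)) == expit(-(a*f + b)); a and b broadcast over columns.
    calibrated = 1.0 / (1.0 + np.exp(a * raw_prob + b))
    denom = calibrated.sum(axis=1, keepdims=True)
    n_classes = raw_prob.shape[1]
    # Renormalize rows; fall back to uniform probabilities if a row sums to 0.
    return np.divide(
        calibrated, denom,
        out=np.full_like(calibrated, 1.0 / n_classes),
        where=denom != 0,
    )


raw = np.array([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]])
a = np.array([-5.0, -5.0, -5.0])
b = np.array([2.0, 2.0, 2.0])
calibrated = apply_sigmoid_calibration(raw, a, b)
```

With a fitted model, the parameters could be collected as a = np.array([c.a_ for c in calibrators]) and b = np.array([c.b_ for c in calibrators]). Keeping this step as plain NumPy means only the XGBoost part has to live in the ONNX graph.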

I also tried to_onnx, which did not work either:

AttributeError: 'numpy.uint8' object has no attribute 'encode'
File <command-3767976754122452>, line 47
     45 # Convert to ONNX format
     46 initial_type = [("float_input", FloatTensorType([None, X_train.shape[1]]))]
---> 47 onnx_model = to_onnx(calibrated_xgb, initial_types=initial_type)
     49 # Save ONNX model
     50 onnx_filename = "calibrated_xgb_model.onnx"
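The to_onnx traceback mentions numpy.uint8, which is the dtype keras returns for MNIST labels. Assuming the error comes from skl2onnx not treating uint8 class labels as integers (and then trying to encode them as strings), one workaround is to cast the labels to int64 right after mnist.load_data(), before train_test_split and any fit() call, so every estimator's classes_ ends up int64. The helper as_int64_labels below is hypothetical.

```python
import numpy as np


def as_int64_labels(y):
    """Hypothetical helper: cast integer class labels (e.g. MNIST's uint8) to int64.

    Assumption: skl2onnx handles int64 (or string) class labels, but not uint8.
    """
    y = np.asarray(y)
    if not np.issubdtype(y.dtype, np.integer):
        raise TypeError(f"expected integer labels, got {y.dtype}")
    return y.astype(np.int64)


# Stand-in for the uint8 labels returned by mnist.load_data()
y_train = np.array([5, 0, 4, 1], dtype=np.uint8)
y_train = as_int64_labels(y_train)
print(y_train.dtype)  # int64
```

With int64 labels, the calibrated_xgb.classes_ = calibrated_xgb.classes_.astype(str) workaround should no longer be necessary, since it only changed the wrapper and not the inner XGBoost model.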

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
