
Conversation

@allglc (Collaborator) commented on Nov 4, 2025

Description

Add the ability to handle multi-dimensional thresholds (lambdas)

  • `predict_function` can now also be a general function `(X, *params) -> 0/1` (see the first sketch after this list)
  • `predict_params` is now an argument (even for a one-dimensional lambda), and the docstring should now be clearer
  • `best_predict_param` is a tuple for multi-dimensional parameters
  • added an automatic flag `is_multi_dimensional_param` (based on the dimension of `predict_params`)
  • `_get_predictions_per_param` handles general predict functions, but processes all parameter values sequentially (I don't know how to parallelize this easily).
  • `get_predictions_per_param` will check the prediction values during the calibration step (via a new argument `is_calibration_step`). The check is skipped at test time because it is not necessary there, and a single predicted probability at test time can legitimately be exactly 0 or 1, which would raise a spurious warning.
  • predictions are checked so that for one-dimensional parameters they are not all exactly 0 or 1 (they are expected to be probabilities), while for multi-dimensional parameters they must all be 0 or 1 (hard decisions); a rough sketch of this check follows the list (second sketch).
  • I don't think it's necessary to add a custom error message when the dimension of the parameters does not match the inputs of the general predict function, as the default message seems explicit enough (cf. test_error_multi_dim_params_dim_mismatch)
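
To make the new interface concrete, here is a minimal sketch of what a general predict function and a multi-dimensional `predict_params` grid could look like. The two-parameter rule, the classifier, and the way the grid is built are illustrative assumptions, not the actual implementation; only the `(X, *params) -> 0/1` convention and the roles of `predict_params` and `best_predict_param` come from this PR.

```python
import numpy as np
from itertools import product
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Toy data and a fitted binary classifier (purely illustrative).
X, y = make_classification(n_samples=500, random_state=0)
clf = LogisticRegression().fit(X, y)

# Hypothetical two-parameter decision rule following the (X, *params) -> 0/1
# convention: predict 1 when the positive-class probability exceeds `threshold`
# AND the gap between the two class probabilities exceeds `margin`.
def predict_function(X, threshold, margin):
    proba = clf.predict_proba(X)
    p_pos = proba[:, 1]
    gap = np.abs(proba[:, 1] - proba[:, 0])
    return ((p_pos >= threshold) & (gap >= margin)).astype(int)

# Multi-dimensional predict_params: one row per candidate (threshold, margin).
# Each row is unpacked as the *params of predict_function, and
# best_predict_param would then be one such (threshold, margin) tuple.
predict_params = np.array(list(product(
    np.linspace(0.1, 0.9, 9),   # candidate thresholds
    np.linspace(0.0, 0.4, 5),   # candidate margins
)))
print(predict_params.shape)  # (45, 2): second dimension > 1 => multi-dimensional
```

And a rough sketch of the calibration-time value check described above. The helper name is hypothetical, and whether each case warns or raises is an assumption; the PR only states that values are checked during calibration and that the check is skipped at test time.

```python
import warnings
import numpy as np

def _check_predictions(predictions, is_multi_dimensional_param):
    """Hypothetical calibration-only sanity check on prediction values."""
    predictions = np.asarray(predictions)
    all_binary = np.isin(predictions, (0, 1)).all()
    if is_multi_dimensional_param and not all_binary:
        # A general (X, *params) predict_function must return hard 0/1 decisions.
        warnings.warn("General predict_function should return only 0/1 values.")
    if not is_multi_dimensional_param and all_binary:
        # One-dimensional lambdas threshold probabilities, so the raw values
        # should not all be exactly 0 or 1 (that suggests hard labels were passed).
        warnings.warn("predict_function returned only 0/1 values; probabilities expected.")
```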

To manage the two types of predict functions in `__init__`, there are a few options:

  1. A manual flag in the arguments of `BinaryClassificationController`
  2. Auto-detection based on the function signature (does it take one argument or more), but this can cause issues when `predict_proba` itself takes several arguments (e.g. XGBoost)
  3. Adding another argument, e.g. `predict_function_general`, so the user has to provide at least this one or the original `predict_function`
  4. An automatic check of `predict_params.shape[1] > 1`: the user has to provide a custom array of `predict_params` to use multi-dimensional parameters

I will go with option 4.
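
A minimal sketch of how option 4 could be implemented; the helper name is hypothetical, and only the `predict_params.shape[1] > 1` criterion comes from the discussion above.

```python
import numpy as np

def _is_multi_dimensional(predict_params) -> bool:
    # Option 4: the shape of the parameter grid decides the mode.
    # A 1-D array (or a single-column 2-D array) is treated as plain
    # probability thresholds; a 2-D array with more than one column switches
    # to the general (X, *params) -> 0/1 predict_function path.
    predict_params = np.asarray(predict_params)
    return predict_params.ndim == 2 and predict_params.shape[1] > 1

print(_is_multi_dimensional(np.linspace(0, 1, 100)))              # False: 1-D thresholds
print(_is_multi_dimensional(np.array([[0.5, 0.0], [0.6, 0.2]])))  # True: multi-dimensional
```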

@codecov-commenter commented

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 100.00%. Comparing base (5ee9406) to head (72d7a5c).
⚠️ Report is 16 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##            master      #788    +/-   ##
==========================================
  Coverage   100.00%   100.00%            
==========================================
  Files           56        56            
  Lines         6325      6547   +222     
  Branches       360       378    +18     
==========================================
+ Hits          6325      6547   +222     


@allglc marked this pull request as ready for review on November 5, 2025 at 16:04
