Below are some thoughts on how we could create the optimal model usage workflow. The idea is to enable all models implemented as part of TabArena to be easily usable by anyone on new datasets, just like any other scikit-learn compatible model.
The current workflow is shown here: https://github.com/TabArena/tabarena_benchmarking_examples/blob/main/tabarena_minimal_example/run_tabarena_model.py
My ideal workflow would be something like this:
```python
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Get data
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)

# Import a TabArena model
from tabarena.pipelines import RealMLP  # or any other model

# Train the TabArena model
clf = RealMLP()
clf.fit(X=X_train, y=y_train)

# Predict and score (for a binary task, ROC AUC needs the positive-class column,
# assuming predict_proba returns an array of shape (n_samples, 2))
prediction_probabilities = clf.predict_proba(X=X_test)
print("ROC AUC:", roc_auc_score(y_test, prediction_probabilities[:, 1]))
```

Some requirements that would be nice to have:
- The pipeline supports model-agnostic preprocessing (enabled by default).
- The model should ideally be sklearn compatible up to a certain degree. We might also need explicit classifier/regressor wrapper interfaces as a result. See my PHE code as an example for this: https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/post_hoc_ensembles/sklearn_interface.py
- One should be able to set via a parameter whether the model is run standalone, as a bagging model, or as the full AG AutoML pipeline (i.e., the different options from here: https://auto.gluon.ai/stable/tutorials/tabular/advanced/tabular-custom-model.html)
- Ideally, the model interface would have a docstring, parameters, and documentation for all relevant hyperparameters of the model and easy access to a default search space (like RealMLPPipeline.search_space)
- Ideally, we could have (by default) almost no output/logging from running the model or bagged versions.
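
To make the requirements above more concrete, here is a rough sketch of what such a wrapper interface could look like. Everything in it is an assumption for illustration: `SketchClassifier`, its `mode`, `preprocess`, `n_bags`, and `verbose` parameters, and the contents of `search_space` are hypothetical names, and scikit-learn's `LogisticRegression` and `BaggingClassifier` only stand in for a real TabArena model and for AutoGluon-style bagging. The full AG AutoML mode is left out.

```python
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import BaggingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


class SketchClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical sklearn-style wrapper (not the TabArena API).

    Parameters
    ----------
    mode : {"standalone", "bagging"}, default="standalone"
        How the underlying model is run.
    preprocess : bool, default=True
        Whether to apply model-agnostic preprocessing (here: scaling only).
    C : float, default=1.0
        Hyperparameter of the stand-in model.
    n_bags : int, default=4
        Number of bagged copies when mode="bagging".
    verbose : bool, default=False
        Quiet by default; set to True for a one-line fit summary.
    random_state : int, default=0
        Seed used for bagging.
    """

    # Hypothetical default search space, exposed as a class attribute.
    search_space = {"C": [0.01, 0.1, 1.0, 10.0]}

    def __init__(self, mode="standalone", preprocess=True, C=1.0,
                 n_bags=4, verbose=False, random_state=0):
        self.mode = mode
        self.preprocess = preprocess
        self.C = C
        self.n_bags = n_bags
        self.verbose = verbose
        self.random_state = random_state

    def fit(self, X, y):
        if self.mode not in ("standalone", "bagging"):
            raise ValueError(f"Unknown mode: {self.mode!r}")
        # Stand-in for the actual TabArena model.
        model = LogisticRegression(C=self.C, max_iter=1000)
        if self.preprocess:
            model = make_pipeline(StandardScaler(), model)
        if self.mode == "bagging":
            # Stand-in for AutoGluon-style bagging.
            model = BaggingClassifier(
                model, n_estimators=self.n_bags, random_state=self.random_state
            )
        self.model_ = model.fit(X, y)
        self.classes_ = self.model_.classes_
        if self.verbose:
            print(f"Fitted in mode={self.mode!r} on {len(y)} samples")
        return self

    def predict_proba(self, X):
        return self.model_.predict_proba(X)

    def predict(self, X):
        return self.model_.predict(X)


X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=42
)
clf = SketchClassifier(mode="bagging").fit(X_train, y_train)
auc = roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
print("ROC AUC:", auc)
```

Because the parameters are plain constructor arguments stored unchanged, `get_params`, `set_params`, and `clone` from scikit-learn work without extra code, and a regressor variant could follow the same pattern with `RegressorMixin`.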