
I'm doing a test run of the Gradient Boosting Machine algorithm on the iris data with the caret package.

library(caret)
library(gbm)
data(iris)

set.seed(123)
inTraining <- createDataPartition(iris$Species, p = .75, list = FALSE)
training <- iris[ inTraining,]
testing  <- iris[-inTraining,]

# 3 x 10 x 5 x 6 = 900 candidate tuning combinations
gbmGrid <- expand.grid(interaction.depth = c(1, 2, 3),
                       n.trees = (1:10) * 1000,
                       shrinkage = c(0.001, 0.005, 0.01, 0.05, 0.1),
                       n.minobsinnode = c(1, 2, 5, 10, 15, 20))

fitControl <- trainControl(
  method = "repeatedcv",
  number = 10,
  repeats = 10,
  classProbs = TRUE,      # required for predict(..., type = "prob")
  allowParallel = TRUE)

set.seed(234)
gbmFit2 <- train(Species ~ ., 
                 data = training, 
                 method = "gbm", 
                 trControl = fitControl, 
                 verbose = FALSE, 
                 tuneGrid = gbmGrid)
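
After training finishes, the selected tuning combination and the resampled accuracy profile can be inspected with caret's standard accessors, for example:

gbmFit2$bestTune   # tuning combination with the best resampled accuracy
plot(gbmFit2)      # accuracy profile across the tuning grid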

I'm achieving excellent Accuracy metrics; however, the predicted probabilities for the Species values in the test data are fairly evenly split. I expected GBM to return predicted probabilities of 90%+ for the correctly predicted Species value rather than values in the 35%-40% range.
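
The accuracy itself can be checked with a confusion matrix on the held-out rows, along these lines (a sketch using caret::confusionMatrix; the exact call used for the figure above is an assumption):

pred_class <- predict(gbmFit2, newdata = testing)
confusionMatrix(data = pred_class, reference = testing$Species)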

predict(gbmFit2, newdata=testing, type="prob")
     setosa versicolor virginica
1 0.3826163  0.3086751 0.3087086
2 0.3826643  0.3086374 0.3086983
3 0.3826681  0.3086355 0.3086964
4 0.3811067  0.3114695 0.3074237
5 0.3811067  0.3114695 0.3074237
...
32 0.3077245  0.3568080 0.3354674
33 0.3153934  0.3275473 0.3570593
34 0.3097463  0.3525782 0.3376756
35 0.3065883  0.3151160 0.3782957
36 0.3078244  0.3122151 0.3799605

Did I misspecify my model?

  • I'm not familiar with how gradient boosted models compute probabilities, but my sense is that they are not calibrated (in the statistical sense of the word). That might be a reason why probability estimates are off. – Demetri Pananos Mar 22 at 4:46
  • @DemetriPananos Ah, thank you! I've found several articles on predicted probability calibration that ought to help. – RobertF Mar 22 at 4:59
  • I think I have answered exactly this question here: Biased prediction (overestimation) for xgboost. The references in the answer should be helpful. – usεr11852 Mar 22 at 11:00
  • @usεr11852 Thanks, very helpful. I'd like to use the GBM's predicted probabilities as weights in a follow-up case-control analysis matching treatments to controls. – RobertF Mar 22 at 14:15
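
The under-confidence the comments describe can be seen directly by comparing the model's average confidence in its predicted class with its observed accuracy on the test set (a rough check using the objects defined above; exact numbers will vary):

probs <- predict(gbmFit2, newdata = testing, type = "prob")
pred_class <- predict(gbmFit2, newdata = testing)
mean(apply(probs, 1, max))            # average confidence in the predicted class, roughly 0.35-0.40
mean(pred_class == testing$Species)   # observed accuracy, far higher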

I'm getting good results by applying Platt scaling to the Gradient Boosting Machine model's predicted probabilities for each of the iris Species classes. Because this is a three-class problem, I'm using multinomial logistic regression (via nnet::multinom) as the calibration model instead of binomial logistic regression.

library(nnet)

# Raw GBM probabilities on the test set
predict_gbm <- predict(gbmFit2, newdata = testing, type = "prob")
iris_preds <- data.frame(cbind(testing, predict_gbm))

# Calibration model: multinomial logistic regression of the true class on the raw probabilities
multinom_iris_calib <- multinom(Species ~ setosa + versicolor + virginica, data = iris_preds)

# Calibrated probabilities alongside the true Species
predict_multinom_iris_calib <- fitted(multinom_iris_calib)
predict_multinom_iris_calib <- data.frame(cbind(testing, predict_multinom_iris_calib))
predict_multinom_iris_calib[, 5:8]
           Species       setosa  versicolor    virginica
    1       setosa 0.9924330938 0.007566906 3.189546e-12
    5       setosa 0.9924869997 0.007513000 3.122451e-12
    7       setosa 0.9924908536 0.007509146 3.117173e-12
    13      setosa 0.9897471351 0.010252865 6.159896e-12
    14      setosa 0.9897471351 0.010252865 6.159896e-12
    19      setosa 0.9924455961 0.007554404 3.160900e-12
    20      setosa 0.9924369750 0.007563025 3.184155e-12
    26      setosa 0.9897471351 0.010252865 6.159896e-12
    30      setosa 0.9908937928 0.009106207 4.749171e-12
    34      setosa 0.9924330938 0.007566906 3.189546e-12
    44      setosa 0.9924920468 0.007507953 3.113992e-12
    47      setosa 0.9924330938 0.007566906 3.189546e-12
    53  versicolor 0.0111471614 0.430054907 5.587979e-01
    59  versicolor 0.0027915088 0.879689514 1.175190e-01
    62  versicolor 0.0067711749 0.941213020 5.201580e-02
    64  versicolor 0.0041798273 0.913490187 8.232999e-02
    66  versicolor 0.0072934320 0.944243978 4.846259e-02
    72  versicolor 0.0012548141 0.780996058 2.177491e-01
    76  versicolor 0.0064162669 0.938643715 5.494002e-02
    78  versicolor 0.0172668357 0.499051229 4.836819e-01
    80  versicolor 0.0008895809 0.738176308 2.609341e-01
    85  versicolor 0.0068591641 0.944825113 4.831572e-02
    89  versicolor 0.0068698892 0.944839179 4.829093e-02
    99  versicolor 0.0008884476 0.737538242 2.615733e-01
    104  virginica 0.0003896792 0.020226793 9.793835e-01
    105  virginica 0.0003855347 0.019409330 9.802051e-01
    106  virginica 0.0003543512 0.018632033 9.810136e-01
    107  virginica 0.0019856015 0.851050054 1.469643e-01
    116  virginica 0.0005269779 0.023493314 9.759797e-01
    119  virginica 0.0003167019 0.019191993 9.804913e-01
    126  virginica 0.0004884086 0.022714154 9.767974e-01
    127  virginica 0.0007718212 0.263689615 7.355386e-01
    135  virginica 0.0140170968 0.477610130 5.083728e-01
    139  virginica 0.0016218328 0.358758265 6.396199e-01
    148  virginica 0.0004809264 0.023503371 9.760157e-01
    149  virginica 0.0008173819 0.030148626 9.690340e-01
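
To confirm the calibration step actually helps, the multiclass log loss of the raw and calibrated probabilities can be compared; a sketch is below (the log_loss helper is hypothetical, not from any package). Note that because the calibration model here is fit and evaluated on the same test rows, an improvement is expected; a fairer comparison would hold out a separate calibration set.

# hypothetical helper: mean negative log-probability assigned to the true class
log_loss <- function(actual, probs) {
  probs <- as.matrix(probs)
  idx <- cbind(seq_along(actual), as.integer(actual))
  -mean(log(pmax(probs[idx], 1e-15)))
}

log_loss(testing$Species, predict_gbm)                                                            # raw GBM probabilities
log_loss(testing$Species, predict_multinom_iris_calib[, c("setosa", "versicolor", "virginica")])  # calibrated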
