Implement importance weighting in stacked ensembles #233

Open
@mark-burdon

Feature

Similar to tidymodels/probably#159.

When producing a stacked ensemble, the base models may have been trained using importance weights, but these weights don't seem to be passed on to the stacking (blending) step. In my use case, I use importance weights to upweight the minority class in a highly imbalanced dataset.
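
To illustrate what passing the weights on could look like, here is a minimal sketch that calls glmnet directly, since `blend_predictions()` fits its meta-learner with glmnet. Everything in it is an assumption for illustration rather than stacks internals: the candidate columns `cand_a` and `cand_b` are simulated stand-ins for the out-of-fold predictions in a data stack, and the fixed `lambda` is arbitrary.

```r
library(glmnet)

set.seed(2)
n <- 1000

# Simulated outcome with roughly 10% minority cases
minority <- rbinom(n, 1, 0.1)
y <- factor(ifelse(minority == 1, "Minority", "Majority"),
            levels = c("Majority", "Minority"))

# Hypothetical candidate members: noisy minority-class probabilities
x <- cbind(cand_a = plogis(-2 + 1.5 * minority + rnorm(n)),
           cand_b = plogis(-2 + 1.0 * minority + rnorm(n)))

# Same weighting scheme as in this issue: upweight the minority class
w <- ifelse(minority == 1, 0.85, 0.15)

# Non-negative penalised logistic blend, without and with case weights
blend_unweighted <- glmnet(x, y, family = "binomial",
                           lower.limits = 0, lambda = 0.01)
blend_weighted <- glmnet(x, y, family = "binomial", weights = w,
                         lower.limits = 0, lambda = 0.01)

p_unweighted <- predict(blend_unweighted, newx = x, type = "response")
p_weighted <- predict(blend_weighted, newx = x, type = "response")

# Average predicted probability of the minority class under each blend
c(unweighted = mean(p_unweighted), weighted = mean(p_weighted))
```

On simulated data like this, the weighted blend should assign a noticeably higher average minority probability than the unweighted one, which is the behaviour I would expect a weight-aware stack to show.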

Here's a reprex. It isn't a perfect test (it compares the two "best" base learners with a stacked ensemble of several candidates), but it demonstrates that the weights aren't being passed on: the stacked ensemble is much less likely to predict the (highly weighted) minority class than the base learners are:

library(tidyverse)
library(tidymodels)
set.seed(100)

# Create noisy, imbalanced data, add weight column giving roughly equal overall weight
df <- caret::twoClassSim(n = 1000, intercept = -12, linearVars = 2, noiseVars = 5,
                         corrVars = 3, corrType = "AR1") |>
  dplyr::mutate(weights = dplyr::if_else(Class == "Class1",
                                         true =  0.15,
                                         false = 0.85)) |>
  dplyr::mutate(weights = hardhat::importance_weights(weights),
                Class = case_match(Class, "Class1" ~ "Majority",
                                   "Class2" ~ "Minority"))

# Create recipe and logistic regression specification
glm_recipe <- recipes::recipe(x = df, formula = Class ~ .)

glm_spec <- parsnip::logistic_reg(mode = "classification",
                                  engine = "glm")

# Combine into workflow
glm_wf <- workflows::workflow(preprocessor = glm_recipe,
                              spec = glm_spec) |>
  workflows::add_case_weights(col = weights)

# Now do the same for an XGBoost model
# Deliberately leave out some variables so that it isn't too accurate
xgb_recipe <-  recipes::recipe(x = df,
                               formula = Class ~ Linear1 + Linear2 + Nonlinear1 + Noise1 + Noise2 + weights)

xgb_spec <- parsnip::boost_tree(mode = "classification",
                                engine = "xgboost",
                                trees = 100,
                                tree_depth = tune(),
                                learn_rate = 0.1)

xgb_wf <- workflows::workflow(preprocessor = xgb_recipe,
                              spec = xgb_spec) |>
  workflows::add_case_weights(col = weights)

# Create resamples for model fitting
resamples <- rsample::vfold_cv(data = df,
                               v = 5,
                               strata = Class)

# Cross-validate the models
glm_cv <- fit_resamples(glm_wf, resamples = resamples, control = control_stack_resamples())
xgb_cv <- tune_grid(xgb_wf, resamples = resamples, control = control_stack_resamples())

# Stack the ensemble
model_stack <- stacks() |>
  add_candidates(candidates = glm_cv) |>
  add_candidates(candidates = xgb_cv) |>
  blend_predictions() |>
  fit_members()

# Add stacked predictions to dataframe
df <- df |>
  bind_cols(predict(model_stack, df, type = "prob") |>
              rename(.pred_Majority_stack = .pred_Majority,
                     .pred_Minority_stack = .pred_Minority))

# Fit best individual models
best_glm <- fit_best(glm_cv)
best_xgb <- fit_best(xgb_cv)

# Append individual model predictions to the data
df <- df |>
  bind_cols(predict(best_glm, df, type = "prob") |>
              rename(.pred_Majority_glm = .pred_Majority,
                     .pred_Minority_glm = .pred_Minority)) |>
  bind_cols(predict(best_xgb, df, type = "prob") |>
              rename(.pred_Majority_xgb = .pred_Majority,
                     .pred_Minority_xgb = .pred_Minority))

# Bring together and visualise the predictions
df |>
  select(Class, starts_with(".pred_Min")) |>
  pivot_longer(starts_with(".pred_Min"),
               names_to = "model",
               values_to = ".pred_min",
               names_prefix = ".pred_Minority_") |>
  ggplot(aes(x = .pred_min, colour = model, fill = model)) +
  geom_density(alpha = 0.2) +
  theme_bw() +
  labs(x = "Predicted probability of being in the minority class") +
  facet_wrap(~Class, nrow = 2)
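
A numeric summary can complement the density plot. The sketch below defines a hypothetical helper, `summarise_minority_preds()` (not part of stacks or tidymodels), that reports the mean predicted minority probability and the share of predictions above 0.5 for each model and true class. The `toy` tibble holds placeholder values only so that the block runs on its own; they are not results from the reprex.

```r
library(dplyr)
library(tidyr)

# Hypothetical helper: summarise minority-class predictions by model and true class
summarise_minority_preds <- function(data, truth = "Class",
                                     prefix = ".pred_Minority_") {
  data |>
    select(all_of(truth), starts_with(prefix)) |>
    pivot_longer(starts_with(prefix),
                 names_to = "model",
                 values_to = ".pred_min",
                 names_prefix = prefix) |>
    group_by(across(all_of(truth)), model) |>
    summarise(mean_pred = mean(.pred_min),
              share_above_half = mean(.pred_min > 0.5),
              .groups = "drop")
}

# Placeholder input with the same column layout as `df` in the reprex
toy <- tibble(
  Class = c("Majority", "Majority", "Minority", "Minority"),
  .pred_Minority_stack = c(0.1, 0.2, 0.3, 0.4),
  .pred_Minority_glm = c(0.2, 0.4, 0.6, 0.8)
)

toy_summary <- summarise_minority_preds(toy)
toy_summary
```

After running the reprex, `summarise_minority_preds(df)` should give one row per model and class, which makes the gap between the stack and the base learners easier to quote than the plot alone.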
