Description
The problem
Workflows that use dynamic case weights (calculated from non-predictor columns) error when attempting to predict().
This follows a string of issues, most recently tidymodels/hardhat#240, and is one of the causes of tidymodels/hardhat#242 (though there's another issue there, too). The basic idea is that for some forms of modeling, for instance species abundance modeling (and other forms of presence/background data), it makes sense to calculate case weights "on the fly" for each fold (here, as the ratio of presence observations to background observations in each analysis set). The suggestion was to do that with step_mutate(), so that the case weights would be updated during cross-validation. This seems to cause some issues with the resulting workflow, though.
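To make the weighting idea concrete, here is a minimal sketch of the per-analysis-set calculation using plain dplyr. The data frame `pb`, its columns, and the choice to give presence rows a weight of 1 are all made up for illustration; only the presence-to-background ratio comes from the description above.

```r
library(dplyr)

# Hypothetical presence/background data: 1 = presence, 0 = background.
pb <- tibble(
  presence  = c(1, 1, 0, 0, 0, 0, 0, 0),
  elevation = c(120, 340, 90, 410, 230, 150, 300, 275)
)

# Assumption for this sketch: presence rows get weight 1, background rows get
# n_presence / n_background, so both classes carry the same total weight
# within this analysis set.
pb_weighted <- pb |>
  mutate(
    wts = if_else(
      presence == 1,
      1,
      sum(presence == 1) / sum(presence == 0)
    )
  )

pb_weighted
```

Because the ratio depends on how many presence and background rows land in each analysis set, it has to be recomputed per fold, which is why it would live in a recipe step rather than be fixed up front.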
Reproducible example
set.seed(1107)
library(parsnip)
library(recipes)
#> Loading required package: dplyr
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
#>
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#>
#> step
library(workflows)
data(ames, package = "modeldata")
ames_model <- ames |>
  mutate(cwts = hardhat::importance_weights(NA))
ames_recipe <- recipe(
  formula = Sale_Price ~ Longitude + Latitude, # not to be used in real life...
  data = ames_model
) |>
  recipes::step_mutate(
    cwts = hardhat::importance_weights(abs(Sale_Price - mean(Sale_Price))),
    role = "case_weights"
  )
ames_wflow <- workflow(preprocessor = ames_recipe) |>
  add_model(linear_reg()) |>
  add_case_weights(cwts) |>
  fit(ames_model)
predict(ames_wflow, ames)
#> Error in `dplyr::mutate()`:
#> ℹ In argument: `cwts = hardhat::importance_weights(abs(Sale_Price -
#> mean(Sale_Price)))`.
#> Caused by error:
#> ! object 'Sale_Price' not found
#> Backtrace:
#> ▆
#> 1. ├─stats::predict(ames_wflow, ames)
#> 2. ├─workflows:::predict.workflow(ames_wflow, ames)
#> 3. │ └─workflows:::forge_predictors(new_data, workflow)
#> 4. │ ├─hardhat::forge(new_data, blueprint = mold$blueprint)
#> 5. │ └─hardhat:::forge.data.frame(new_data, blueprint = mold$blueprint)
#> 6. │ ├─hardhat::run_forge(blueprint, new_data = new_data, outcomes = outcomes)
#> 7. │ └─hardhat:::run_forge.default_recipe_blueprint(...)
#> 8. │ └─hardhat:::forge_recipe_default_process(...)
#> 9. │ ├─recipes::bake(object = rec, new_data = new_data)
#> 10. │ └─recipes:::bake.recipe(object = rec, new_data = new_data)
#> 11. │ ├─recipes::bake(step, new_data = new_data)
#> 12. │ └─recipes:::bake.step_mutate(step, new_data = new_data)
#> 13. │ ├─dplyr::mutate(new_data, !!!object$inputs)
#> 14. │ └─dplyr:::mutate.data.frame(new_data, !!!object$inputs)
#> 15. │ └─dplyr:::mutate_cols(.data, dplyr_quosures(...), by)
#> 16. │ ├─base::withCallingHandlers(...)
#> 17. │ └─dplyr:::mutate_col(dots[[i]], data, mask, new_columns)
#> 18. │ └─mask$eval_all_mutate(quo)
#> 19. │ └─dplyr (local) eval()
#> 20. ├─hardhat::importance_weights(abs(Sale_Price - mean(Sale_Price)))
#> 21. │ └─hardhat:::vec_cast_named(x, to = double(), x_arg = "x")
#> 22. │ └─vctrs::vec_cast(x, to, ..., call = call)
#> 23. └─base::.handleSimpleError(...)
#> 24. └─dplyr (local) h(simpleError(msg, call))
#> 25. └─rlang::abort(message, class = error_class, parent = parent, call = error_call)
Created on 2023-06-01 with reprex v2.0.2
Speculation
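It seems like the issue is that hardhat:::run_forge.default_recipe_blueprint() removes non-predictor variables (in this case, the outcome) before the recipe is baked, so the recipe no longer has the variables it needs for the computation. Note that the data passed to predict() above still contains Sale_Price, so the column must be getting dropped somewhere between predict() and bake(). Is there a workaround for this?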
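As a rough way to isolate the failure from workflows and hardhat, the sketch below bakes a similar recipe directly, once with the outcome present and once without it. The column name `dev` is made up for this example, and the case-weights machinery is left out on purpose. This only shows that the step fails whenever the outcome is missing at bake time; it does not confirm where hardhat drops the column.

```r
library(recipes)

data(ames, package = "modeldata")

# Same kind of step as in the reprex, minus the case-weights role.
# `dev` is a hypothetical column name used only for this sketch.
rec <- recipe(Sale_Price ~ Longitude + Latitude, data = ames) |>
  step_mutate(dev = abs(Sale_Price - mean(Sale_Price)))

prepped <- prep(rec)

# Baking data that still contains the outcome works.
with_outcome <- bake(prepped, new_data = ames)

# Baking data without the outcome should fail, because step_mutate()
# needs Sale_Price to evaluate its expression.
without_outcome <- try(
  bake(prepped, new_data = dplyr::select(ames, -Sale_Price)),
  silent = TRUE
)

inherits(without_outcome, "try-error")
```

If the second bake() fails the same way as the predict() call above, that is consistent with the outcome being removed from new_data before the recipe is applied.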