Screening many models at once
library(tidymodels)
tidymodels_prefer()
data(concrete, package = 'modeldata')
glimpse(concrete)
Registered S3 method overwritten by 'tune':
  method                   from   
  required_pkgs.model_spec parsnip
-- Attaching packages --------------------------------------------- tidymodels 0.1.3 --
√ broom        0.7.9      √ recipes      0.1.16
√ dials        0.0.9      √ rsample      0.1.0 
√ dplyr        1.0.7      √ tibble       3.1.3 
√ ggplot2      3.3.5      √ tidyr        1.1.3 
√ infer        1.0.0      √ tune         0.1.6 
√ modeldata    0.1.1      √ workflows    0.2.3 
√ parsnip      0.1.7      √ workflowsets 0.1.0 
√ purrr        0.3.4      √ yardstick    0.0.8 
-- Conflicts ------------------------------------------------ tidymodels_conflicts() --
x purrr::discard() masks scales::discard()
x dplyr::filter()  masks stats::filter()
x dplyr::lag()     masks stats::lag()
x recipes::step()  masks stats::step()
* Use tidymodels_prefer() to resolve common conflicts.
Rows: 1,030
Columns: 9
$ cement               <dbl> 540.0, 540.0, 332.5, 332.5, 198.6, 266.0, 380.0, 380.0, ~
$ blast_furnace_slag   <dbl> 0.0, 0.0, 142.5, 142.5, 132.4, 114.0, 95.0, 95.0, 114.0,~
$ fly_ash              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,~
$ water                <dbl> 162, 162, 228, 228, 192, 228, 228, 228, 228, 228, 192, 1~
$ superplasticizer     <dbl> 2.5, 2.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0~
$ coarse_aggregate     <dbl> 1040.0, 1055.0, 932.0, 932.0, 978.4, 932.0, 932.0, 932.0~
$ fine_aggregate       <dbl> 676.0, 676.0, 594.0, 594.0, 825.5, 670.0, 594.0, 594.0, ~
$ age                  <int> 28, 28, 270, 365, 360, 90, 365, 28, 28, 28, 90, 28, 270,~
$ compressive_strength <dbl> 79.99, 61.89, 40.27, 41.05, 44.30, 47.03, 43.70, 36.45, ~
concrete <- 
  concrete %>% 
  group_by(cement, blast_furnace_slag, fly_ash, water, superplasticizer, 
           coarse_aggregate, fine_aggregate, age) %>% 
  summarize(compressive_strength = mean(compressive_strength), 
            .groups = "drop")
nrow(concrete)
[1] 992
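Since duplicate mixtures were averaged away above, a quick sanity check (a sketch that is not part of the original code) is to count the mixtures again and confirm that none appears more than once:

# Optional check: after averaging, every mixture should appear exactly once,
# so this should return zero rows.
concrete %>% 
  count(cement, blast_furnace_slag, fly_ash, water, superplasticizer, 
        coarse_aggregate, fine_aggregate, age) %>% 
  filter(n > 1)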
Data splitting and cross-validation
set.seed(1501)
concrete_split <- initial_split(concrete, strata = compressive_strength)
concrete_train <- training(concrete_split)
concrete_test  <- testing(concrete_split)

set.seed(1502)
concrete_folds <- 
  vfold_cv(concrete_train, strata = compressive_strength, repeats = 5)
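As a quick sanity check (vfold_cv() defaults to 10 folds), printing the resampling object should list 10 folds × 5 repeats = 50 resamples:

# Should show 50 resamples: 10-fold CV repeated 5 times.
concrete_folds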
Simple preprocessing
normalized_rec <- 
  recipe(compressive_strength ~ ., data = concrete_train) %>% 
  step_normalize(all_predictors())

poly_recipe <- 
  normalized_rec %>% 
  step_poly(all_predictors()) %>% 
  step_interact(~ all_predictors():all_predictors())
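The quadratic recipe expands the predictor set considerably. As a rough illustration (a sketch not in the original post, assuming a recipes version that supports bake(new_data = NULL)), you can prep it on the training data and count the resulting columns:

# Count the columns produced by the full quadratic expansion
# (polynomial terms plus all two-way interactions, plus the outcome).
poly_recipe %>% 
  prep() %>% 
  bake(new_data = NULL) %>% 
  ncol()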
Building multiple model specifications
library(rules)
library(baguette)

linear_reg_spec <- 
  linear_reg(penalty = tune(), mixture = tune()) %>% 
  set_engine("glmnet")

nnet_spec <- 
  mlp(hidden_units = tune(), penalty = tune(), epochs = tune()) %>% 
  set_engine("nnet", MaxNWts = 2600) %>% 
  set_mode("regression")

mars_spec <- 
  mars(prod_degree = tune()) %>%  # <- use GCV to choose terms
  set_engine("earth") %>% 
  set_mode("regression")

svm_r_spec <- 
  svm_rbf(cost = tune(), rbf_sigma = tune()) %>% 
  set_engine("kernlab") %>% 
  set_mode("regression")

svm_p_spec <- 
  svm_poly(cost = tune(), degree = tune()) %>% 
  set_engine("kernlab") %>% 
  set_mode("regression")

knn_spec <- 
  nearest_neighbor(neighbors = tune(), dist_power = tune(), weight_func = tune()) %>% 
  set_engine("kknn") %>% 
  set_mode("regression")

cart_spec <- 
  decision_tree(cost_complexity = tune(), min_n = tune()) %>% 
  set_engine("rpart") %>% 
  set_mode("regression")

bag_cart_spec <- 
  bag_tree() %>% 
  set_engine("rpart", times = 50L) %>% 
  set_mode("regression")

rf_spec <- 
  rand_forest(mtry = tune(), min_n = tune(), trees = 1000) %>% 
  set_engine("ranger") %>% 
  set_mode("regression")

xgb_spec <- 
  boost_tree(tree_depth = tune(), learn_rate = tune(), loss_reduction = tune(), 
             min_n = tune(), sample_size = tune(), trees = tune()) %>% 
  set_engine("xgboost") %>% 
  set_mode("regression")

cubist_spec <- 
  cubist_rules(committees = tune(), neighbors = tune()) %>% 
  set_engine("Cubist")
nnet_param <- 
  nnet_spec %>% 
  parameters() %>% 
  update(hidden_units = hidden_units(c(1, 27)))
Different preprocessing steps
normalized <- 
  workflow_set(
    preproc = list(normalized = normalized_rec), 
    models = list(SVM_radial = svm_r_spec, SVM_poly = svm_p_spec, 
                  KNN = knn_spec, neural_network = nnet_spec)
  )
normalized

Pick one at random to take a look:
normalized %>% pull_workflow(id = "normalized_KNN")
== Workflow ===========================================================================
Preprocessor: Recipe
Model: nearest_neighbor()

-- Preprocessor -----------------------------------------------------------------------
1 Recipe Step

* step_normalize()

-- Model ------------------------------------------------------------------------------
K-Nearest Neighbor Model Specification (regression)

Main Arguments:
  neighbors = tune()
  weight_func = tune()
  dist_power = tune()

Computational engine: kknn 
normalized <- 
  normalized %>% 
  option_add(param_info = nnet_param, id = "normalized_neural_network")
normalized

model_vars <- 
  workflow_variables(outcomes = compressive_strength, 
                     predictors = everything())

no_pre_proc <- 
  workflow_set(
    preproc = list(simple = model_vars), 
    models = list(MARS = mars_spec, CART = cart_spec, CART_bagged = bag_cart_spec, 
                  RF = rf_spec, boosting = xgb_spec, Cubist = cubist_spec)
  )
no_pre_proc

with_features <- 
  workflow_set(
    preproc = list(full_quad = poly_recipe), 
    models = list(linear_reg = linear_reg_spec, KNN = knn_spec)
  )
all_workflows <- 
  bind_rows(no_pre_proc, normalized, with_features) %>% 
  # Make the workflow IDs a little simpler:
  mutate(wflow_id = gsub("(simple_)|(normalized_)", "", wflow_id))
all_workflows
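Before tuning, it can help to confirm the cleaned-up IDs; the twelve entries should match the "1 of 12" through "12 of 12" progress messages in the training log below. A small sketch, not in the original code:

# List the workflow IDs in the combined set (12 in total).
all_workflows %>% pull(wflow_id)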

Now let's train the models
This step takes a very long time. For reference, my machine is an AMD 5900X with 2 × 32 GB of Fury 3600 MHz RAM.
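One way to shorten the wall-clock time is to register a parallel backend before calling workflow_map(); the parallel_over = "everything" option in the control object only helps when such a backend is available. A minimal sketch, assuming the doParallel package is installed (the core count is just an example):

# Register a parallel backend so tune can spread the work across cores.
library(doParallel)
cl <- makePSOCKcluster(parallel::detectCores(logical = FALSE))
registerDoParallel(cl)
# ... run workflow_map() as below ...
# stopCluster(cl)  # shut the workers down afterwards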
grid_ctrl <- 
  control_grid(
    save_pred = TRUE, 
    parallel_over = "everything", 
    save_workflow = TRUE
  )

grid_results <- 
  all_workflows %>% 
  workflow_map(
    seed = 1503, 
    resamples = concrete_folds, 
    grid = 25, 
    control = grid_ctrl
  )
i  1 of 12 tuning:     MARS
√  1 of 12 tuning:     MARS (11.2s)
i  2 of 12 tuning:     CART
√  2 of 12 tuning:     CART (1m 28.1s)
i  No tuning parameters. `fit_resamples()` will be attempted
i  3 of 12 resampling: CART_bagged
√  3 of 12 resampling: CART_bagged (2m 22.4s)
i  4 of 12 tuning:     RF
i  Creating pre-processing data to finalize unknown parameter: mtry
√  4 of 12 tuning:     RF (3m 7.6s)
i  5 of 12 tuning:     boosting
√  5 of 12 tuning:     boosting (4m 50.9s)
i  6 of 12 tuning:     Cubist
√  6 of 12 tuning:     Cubist (5m 6.5s)
i  7 of 12 tuning:     SVM_radial
√  7 of 12 tuning:     SVM_radial (2m 4.4s)
i  8 of 12 tuning:     SVM_poly
√  8 of 12 tuning:     SVM_poly (14m 4.8s)
i  9 of 12 tuning:     KNN
√  9 of 12 tuning:     KNN (2m 57.5s)
i 10 of 12 tuning:     neural_network
√ 10 of 12 tuning:     neural_network (2m 48.1s)
i 11 of 12 tuning:     full_quad_linear_reg
√ 11 of 12 tuning:     full_quad_linear_reg (2m 40.8s)
i 12 of 12 tuning:     full_quad_KNN
√ 12 of 12 tuning:     full_quad_KNN (23m 7.7s)
Results
grid_results

Rank the results by a chosen metric (RMSE):
grid_results %>% 
  rank_results() %>% 
  filter(.metric == "rmse") %>% 
  select(model, .config, rmse = mean, rank)

Visualize the results
autoplot(
  grid_results, 
  rank_metric = "rmse",  # <- how to order models
  metric = "rmse",       # <- which metric to visualize
  select_best = TRUE     # <- one point per workflow
)

autoplot(grid_results, id = "Cubist", metric = "rmse")
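As a tabular companion to this plot (a sketch not in the original post), you can also pull the Cubist results out of the workflow set and list its best submodels by RMSE:

# Show the top Cubist tuning-parameter combinations, ranked by mean RMSE.
grid_results %>% 
  extract_workflow_set_result("Cubist") %>% 
  show_best(metric = "rmse")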

A faster way to screen many models
library(finetune)

race_ctrl <- 
  control_race(
    save_pred = TRUE, 
    parallel_over = "everything", 
    save_workflow = TRUE
  )

race_results <- 
  all_workflows %>% 
  workflow_map(
    "tune_race_anova",  # this racing method is much faster
    seed = 1503, 
    resamples = concrete_folds, 
    grid = 25, 
    control = race_ctrl
  )
Visualize the results:
autoplot(
  race_results, 
  rank_metric = "rmse", 
  metric = "rmse", 
  select_best = TRUE
)

Compare the two approaches to see how much they differ
matched_results <- 
  rank_results(race_results, select_best = TRUE) %>% 
  select(wflow_id, .metric, race = mean, config_race = .config) %>% 
  inner_join(
    rank_results(grid_results, select_best = TRUE) %>% 
      select(wflow_id, .metric, complete = mean, 
             config_complete = .config, model), 
    by = c("wflow_id", ".metric")
  ) %>% 
  filter(.metric == "rmse")

matched_results %>% 
  ggplot(aes(x = complete, y = race)) + 
  geom_abline(lty = 3) + 
  geom_point(aes(col = model)) + 
  coord_obs_pred() + 
  labs(x = "Complete Grid RMSE", y = "Racing RMSE")
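Besides comparing the RMSE estimates, it is worth checking how much work racing actually saved. A rough sketch (not in the original post): counting the rows returned by collect_metrics() shows how many configuration-metric combinations each approach evaluated, and the racing total should be much smaller.

# Number of evaluated configuration-metric rows for each approach.
nrow(collect_metrics(grid_results))
nrow(collect_metrics(race_results))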

Selecting the final model
best_results <- 
  race_results %>% 
  extract_workflow_set_result("boosting") %>% 
  select_best(metric = "rmse")
best_results

boosting_test_results <- 
  race_results %>% 
  extract_workflow("boosting") %>% 
  finalize_workflow(best_results) %>% 
  last_fit(split = concrete_split)
collect_metrics(boosting_test_results)

boosting_test_results %>% 
  collect_predictions() %>% 
  ggplot(aes(x = compressive_strength, y = .pred)) + 
  geom_abline(col = "green", lty = 2) + 
  geom_point(alpha = 0.5) + 
  coord_obs_pred() + 
  labs(x = "observed", y = "predicted")
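Finally, if the model is going to be reused, the fitted workflow produced by last_fit() can be pulled out and saved. A minimal sketch (the object and file names here are just examples; for some engines such as xgboost a dedicated serialization approach may be preferable):

# The .workflow column of a last_fit() result holds the finalized workflow
# trained on the training set; save it for later prediction.
final_boosting_wflow <- boosting_test_results$.workflow[[1]]
saveRDS(final_boosting_wflow, "boosting_final_workflow.rds")
# predict(final_boosting_wflow, new_data = concrete_test)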

