数据探索
# Attach the tidyverse; read_csv(), the dplyr verbs and ggplot() below need it.
library(tidyverse)
## -- Attaching packages ----------------------------------------------------------- tidyverse 1.3.1 --
## v ggplot2 3.3.3     v purrr   0.3.4
## v tibble  3.1.1     v dplyr   1.0.6
## v tidyr   1.1.3     v stringr 1.4.0
## v readr   1.4.0     v forcats 0.5.1
## -- Conflicts -------------------------------------------------------------- tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()

# Read the raw tree records from the local CSV copy.
sf_trees <- read_csv("../datasets/tidytuesday/data/2020/2020-01-28/sf_trees.csv")
## 
## -- Column specification ----------------------------------------------------------------------------
## cols(
##   tree_id = col_double(),
##   legal_status = col_character(),
##   species = col_character(),
##   address = col_character(),
##   site_order = col_double(),
##   site_info = col_character(),
##   caretaker = col_character(),
##   date = col_date(format = ""),
##   dbh = col_double(),
##   plot_size = col_character(),
##   latitude = col_double(),
##   longitude = col_double()
## )

# Build the modelling table:
#   * collapse legal_status to two levels ("DPW Maintained" vs "Other");
#   * turn the character plot_size column into a number;
#   * drop the address column and every row containing a missing value;
#   * convert the remaining character columns to factors.
trees_df <- sf_trees %>%
  mutate(
    legal_status = case_when(
      legal_status == "DPW Maintained" ~ legal_status,
      TRUE ~ "Other"
    ),
    plot_size = parse_number(plot_size)
  ) %>%
  select(-address) %>%
  na.omit() %>%
  # across() replaces the superseded mutate_if(); same columns, same result.
  mutate(across(where(is.character), factor))
# Scatter plot of tree coordinates, colored by legal status, with the
# legend title and the panel border removed.
ggplot(
  trees_df,
  aes(x = longitude, y = latitude, color = legal_status)
) +
  geom_point(size = 0.5, alpha = 0.4) +
  labs(color = NULL) +
  theme_bw() +
  theme(panel.border = element_blank())

# Bar chart: for each legal status, the share of its trees handled by each
# caretaker. Only caretakers with more than 50 trees in total are shown.
trees_df %>%
  count(legal_status, caretaker) %>%
  # Total trees per caretaker across both legal-status levels.
  add_count(caretaker, wt = n, name = "caretaker_count") %>%
  filter(caretaker_count > 50) %>%
  group_by(legal_status) %>%
  mutate(percent_legal = n / sum(n)) %>%
  # Drop the grouping once the within-status share has been computed.
  ungroup() %>%
  ggplot(aes(percent_legal, caretaker, fill = legal_status)) +
  geom_col(position = "dodge") +
  # percent_legal is a 0-1 proportion; format the ticks as percentages so
  # they agree with the "%" in the axis title.
  scale_x_continuous(labels = scales::percent) +
  labs(
    fill = NULL,
    x = "% of trees in each category"
  )

建立模型
# Attach tidymodels (rsample, recipes, parsnip, tune, workflows, yardstick, ...).
library(tidymodels)
## -- Attaching packages ---------------------------------------------------------- tidymodels 0.1.3 --
## v broom        0.7.6      v rsample      0.0.9 
## v dials        0.0.9      v tune         0.1.5 
## v infer        0.5.4      v workflows    0.2.2 
## v modeldata    0.1.0      v workflowsets 0.0.2 
## v parsnip      0.1.5      v yardstick    0.0.8 
## v recipes      0.1.16
## -- Conflicts ------------------------------------------------------------- tidymodels_conflicts() --
## x scales::discard() masks purrr::discard()
## x dplyr::filter()   masks stats::filter()
## x recipes::fixed()  masks stringr::fixed()
## x dplyr::lag()      masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step()   masks stats::step()
## * Use tidymodels_prefer() to resolve common conflicts.

# Keep only the first 2000 rows. slice_head() never pads with NA rows when the
# table is shorter than that, unlike indexing with 1:2000.
trees_df <- slice_head(trees_df, n = 2000)

# Seeded train/test split, stratified on the outcome.
set.seed(123)
trees_split <- initial_split(trees_df, strata = legal_status)
trees_train <- training(trees_split)
trees_test <- testing(trees_split)
# Preprocessing recipe for predicting legal_status from all other columns.
# NOTE(review): step_downsample() appears to have moved to the themis package
# in later recipes releases -- confirm before upgrading recipes.
tree_rec <- recipe(legal_status ~ ., data = trees_train) %>%
  # tree_id is kept in the data but given the "ID" role instead of predictor.
  update_role(tree_id, new_role = "ID") %>%
  # Pool infrequent factor levels (per-step frequency thresholds).
  step_other(species, caretaker, threshold = 0.01) %>%
  step_other(site_info, threshold = 0.005) %>%
  # Dummy-encode every nominal column except the outcome.
  step_dummy(all_nominal(), -all_outcomes()) %>%
  # Derive the year feature from date, then drop the raw date column.
  step_date(date, features = c("year")) %>%
  step_rm(date) %>%
  # Down-sample on the outcome column.
  step_downsample(legal_status)

# Estimate the recipe on the training data and extract the processed set.
tree_prep <- prep(tree_rec)
juiced <- juice(tree_prep)
# Random forest specification: 1000 trees, with mtry and min_n left as
# tune() placeholders to be filled in by the tuning step.
tune_spec <- rand_forest(
  mtry = tune(),
  trees = 1000,
  min_n = tune()
) %>%
  set_engine("ranger") %>%
  set_mode("classification")
# Workflow that pairs the tunable model specification with the recipe.
tune_wf <- workflow() %>%
  add_model(tune_spec) %>%
  add_recipe(tree_rec)
调整参数
# Seeded 5-fold cross-validation resamples of the training set.
set.seed(234)
trees_folds <- vfold_cv(trees_train, v = 5)
# Register a parallel backend before tuning.
doParallel::registerDoParallel()

# First pass: let tune_grid() pick 5 candidate (mtry, min_n) combinations and
# evaluate them on the cross-validation folds.
set.seed(345)
tune_res <- tune_grid(
  tune_wf,
  resamples = trees_folds,
  grid = 5
)
## i Creating pre-processing data to finalize unknown parameter: mtry

tune_res
## # Tuning results
## # 5-fold cross-validation 
## # A tibble: 5 x 4
##   splits             id    .metrics              .notes              
##   <list>             <chr> <list>                <list>              
## 1 <split [1200/301]> Fold1 <tibble[,6] [10 x 6]> <tibble[,1] [0 x 1]>
## 2 <split [1201/300]> Fold2 <tibble[,6] [10 x 6]> <tibble[,1] [0 x 1]>
## 3 <split [1201/300]> Fold3 <tibble[,6] [10 x 6]> <tibble[,1] [0 x 1]>
## 4 <split [1201/300]> Fold4 <tibble[,6] [10 x 6]> <tibble[,1] [0 x 1]>
## 5 <split [1201/300]> Fold5 <tibble[,6] [10 x 6]> <tibble[,1] [0 x 1]>
# Mean ROC AUC against each tuned parameter, one facet per parameter with
# its own x scale.
tune_res %>%
  collect_metrics() %>%
  filter(.metric == "roc_auc") %>%
  select(mean, min_n, mtry) %>%
  # Stack the two parameter columns so they can be facetted.
  pivot_longer(
    c(min_n, mtry),
    names_to = "parameter",
    values_to = "value"
  ) %>%
  ggplot(aes(x = value, y = mean, color = parameter)) +
  geom_point(show.legend = FALSE) +
  facet_wrap(~parameter, scales = "free_x") +
  labs(x = NULL, y = "AUC")

# Regular grid over narrowed ranges: mtry in [20, 25], min_n in [2, 5].
# levels = 5 is requested for both, but the integer min_n range only yields
# four distinct values, hence the 20 rows printed below.
rf_grid <- grid_regular(
  mtry(range = c(20, 25)),
  min_n(range = c(2, 5)),
  levels = 5
)

rf_grid
## # A tibble: 20 x 2
##     mtry min_n
##    <int> <int>
##  1    20     2
##  2    21     2
##  3    22     2
##  4    23     2
##  5    25     2
##  6    20     3
##  7    21     3
##  8    22     3
##  9    23     3
## 10    25     3
## 11    20     4
## 12    21     4
## 13    22     4
## 14    23     4
## 15    25     4
## 16    20     5
## 17    21     5
## 18    22     5
## 19    23     5
## 20    25     5
# Second pass: evaluate every row of rf_grid on the same folds.
set.seed(456)
regular_res <- tune_grid(
  tune_wf,
  resamples = trees_folds,
  grid = rf_grid
)

regular_res
## # Tuning results
## # 5-fold cross-validation 
## # A tibble: 5 x 4
##   splits             id    .metrics              .notes              
##   <list>             <chr> <list>                <list>              
## 1 <split [1200/301]> Fold1 <tibble[,6] [40 x 6]> <tibble[,1] [0 x 1]>
## 2 <split [1201/300]> Fold2 <tibble[,6] [40 x 6]> <tibble[,1] [0 x 1]>
## 3 <split [1201/300]> Fold3 <tibble[,6] [40 x 6]> <tibble[,1] [0 x 1]>
## 4 <split [1201/300]> Fold4 <tibble[,6] [40 x 6]> <tibble[,1] [0 x 1]>
## 5 <split [1201/300]> Fold5 <tibble[,6] [40 x 6]> <tibble[,1] [0 x 1]>
评价模型
# Mean ROC AUC as a function of mtry, one line per min_n value.
regular_res %>%
  collect_metrics() %>%
  filter(.metric == "roc_auc") %>%
  # Treat min_n as a factor so each value gets its own discrete color.
  mutate(min_n = factor(min_n)) %>%
  ggplot(aes(x = mtry, y = mean, color = min_n)) +
  geom_line(alpha = 0.5, size = 1.5) +
  geom_point() +
  labs(y = "AUC")

# Pick the parameter combination with the best ROC AUC. The metric is passed
# by name so the call does not depend on argument position.
best_auc <- select_best(regular_res, metric = "roc_auc")

# Substitute the selected values for the tune() placeholders in the spec.
final_rf <- finalize_model(
  tune_spec,
  best_auc
)

final_rf
## Random Forest Model Specification (classification)
## 
## Main Arguments:
##   mtry = 23
##   trees = 1000
##   min_n = 3
## 
## Computational engine: ranger
# vip provides the variable-importance plot used below.
library(vip)
## 
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
## 
##     vi

# Refit the finalized model on the preprocessed training data with
# permutation importance switched on, then plot the importances as points.
final_rf %>%
  set_engine("ranger", importance = "permutation") %>%
  fit(
    legal_status ~ .,
    # tree_id is dropped so it is not used as a predictor in this fit.
    data = juice(tree_prep) %>% select(-tree_id)
  ) %>%
  vip(geom = "point")

# Plug the selected parameters into the tuning workflow.
final_wf <- tune_wf %>%
  finalize_workflow(best_auc)

# Fit on the training portion of the split and evaluate on the test portion.
final_res <- final_wf %>%
  last_fit(trees_split)

final_res %>%
  collect_metrics()
## # A tibble: 2 x 4
##   .metric  .estimator .estimate .config             
##   <chr>    <chr>          <dbl> <chr>               
## 1 accuracy binary         0.780 Preprocessor1_Model1
## 2 roc_auc  binary         0.901 Preprocessor1_Model1
# Class predictions for the first ten training rows, using the fitted
# workflow stored in the last_fit() result.
final_res$.workflow[[1]] %>%
  predict(new_data = trees_train[1:10, ])
## # A tibble: 10 x 1
##    .pred_class   
##    <fct>         
##  1 DPW Maintained
##  2 DPW Maintained
##  3 DPW Maintained
##  4 DPW Maintained
##  5 DPW Maintained
##  6 DPW Maintained
##  7 DPW Maintained
##  8 Other         
##  9 Other         
## 10 DPW Maintained
# Same finalized model, but fitted through last_fit() with a plain formula
# as the preprocessor instead of the recipe.
final_res1 <- last_fit(final_rf, legal_status ~ ., trees_split)

# Class predictions for the first ten training rows.
predict(final_res1$.workflow[[1]], trees_train[1:10, ])
## # A tibble: 10 x 1
##    .pred_class   
##    <fct>         
##  1 DPW Maintained
##  2 DPW Maintained
##  3 DPW Maintained
##  4 DPW Maintained
##  5 DPW Maintained
##  6 DPW Maintained
##  7 DPW Maintained
##  8 Other         
##  9 Other         
## 10 DPW Maintained
# Fit the finalized model specification directly on the training set with a
# formula, without a workflow or recipe.
final_res2 <- fit(final_rf, legal_status ~ ., data = trees_train)

# Class predictions for the first ten training rows.
predict(final_res2, trees_train[1:10, ])
## # A tibble: 10 x 1
##    .pred_class   
##    <fct>         
##  1 DPW Maintained
##  2 DPW Maintained
##  3 DPW Maintained
##  4 DPW Maintained
##  5 DPW Maintained
##  6 DPW Maintained
##  7 DPW Maintained
##  8 Other         
##  9 Other         
## 10 DPW Maintained
# Map of the test-set trees, colored by whether the final model classified
# them correctly.
final_res %>%
  collect_predictions() %>%
  mutate(correct = case_when(
    legal_status == .pred_class ~ "Correct",
    TRUE ~ "Incorrect"
  )) %>%
  # Attach only the coordinates. Binding all of trees_test would add a second
  # legal_status column and trigger a "New names" repair message.
  # NOTE(review): relies on the predictions being in the same row order as
  # trees_test, exactly as the original bind_cols(trees_test) did.
  bind_cols(select(trees_test, longitude, latitude)) %>%
  ggplot(aes(longitude, latitude, color = correct)) +
  geom_point(size = 0.5, alpha = 0.5) +
  labs(color = NULL) +
  scale_color_manual(values = c("gray80", "darkred"))
