数据探索

  1. library(tidyverse)
  2. ## -- Attaching packages ----------------------------------------------------------- tidyverse 1.3.1 --
  3. ## v ggplot2 3.3.3 v purrr 0.3.4
  4. ## v tibble 3.1.1 v dplyr 1.0.6
  5. ## v tidyr 1.1.3 v stringr 1.4.0
  6. ## v readr 1.4.0 v forcats 0.5.1
  7. ## -- Conflicts -------------------------------------------------------------- tidyverse_conflicts() --
  8. ## x dplyr::filter() masks stats::filter()
  9. ## x dplyr::lag() masks stats::lag()
  10. sf_trees <- read_csv(file = "../datasets/tidytuesday/data/2020/2020-01-28/sf_trees.csv") # TidyTuesday 2020-01-28 `sf_trees` file; path is relative to the working directory
  11. ##
  12. ## -- Column specification ----------------------------------------------------------------------------
  13. ## cols(
  14. ## tree_id = col_double(),
  15. ## legal_status = col_character(),
  16. ## species = col_character(),
  17. ## address = col_character(),
  18. ## site_order = col_double(),
  19. ## site_info = col_character(),
  20. ## caretaker = col_character(),
  21. ## date = col_date(format = ""),
  22. ## dbh = col_double(),
  23. ## plot_size = col_character(),
  24. ## latitude = col_double(),
  25. ## longitude = col_double()
  26. ## )
  27. # Build the modelling table from the raw `sf_trees` data.
  28. trees_df <- sf_trees %>%
  29.   mutate(
  30.     # Two-class outcome: keep "DPW Maintained" and lump everything else
  31.     # (including a missing legal_status, which fails the comparison) into "Other".
  32.     legal_status = case_when(
  33.       legal_status == "DPW Maintained" ~ legal_status,
  34.       TRUE ~ "Other"
  35.     ),
  36.     # plot_size is read as character; keep only the first number found in it.
  37.     plot_size = parse_number(plot_size)
  38.   ) %>%
  39.   select(-address) %>%
  40.   # Drop every row that still has a missing value in any remaining column.
  41.   na.omit() %>%
  42.   # across() + where() is the current replacement for the superseded mutate_if().
  43.   mutate(across(where(is.character), factor))
  1. # Map every tree by position, coloured by its legal status.
  2. ggplot(trees_df, aes(x = longitude, y = latitude, colour = legal_status)) +
  3.   geom_point(size = 0.5, alpha = 0.4) +
  4.   labs(colour = NULL) +
  5.   theme_bw() +
  6.   theme(panel.border = element_blank())

tidymodels-exercise-05 - 图1

  1. # Share of each legal-status class that is looked after by each caretaker.
  2. trees_df %>%
  3.   count(legal_status, caretaker) %>%
  4.   # Total trees per caretaker across both classes; keep caretakers with > 50.
  5.   add_count(caretaker, wt = n, name = "caretaker_count") %>%
  6.   filter(caretaker_count > 50) %>%
  7.   # The proportion is computed within each legal_status class.
  8.   group_by(legal_status) %>%
  9.   mutate(percent_legal = n / sum(n)) %>%
  10.   # Drop the grouping so it does not leak into later steps.
  11.   ungroup() %>%
  12.   ggplot(aes(percent_legal, caretaker, fill = legal_status)) +
  13.   geom_col(position = "dodge") +
  14.   # percent_legal is a 0-1 proportion; format the axis as % to match its label.
  15.   scale_x_continuous(labels = scales::percent) +
  16.   labs(
  17.     fill = NULL,
  18.     x = "% of trees in each category"
  19.   )

tidymodels-exercise-05 - 图2

建立模型

  1. library(tidymodels)
  2. ## -- Attaching packages ---------------------------------------------------------- tidymodels 0.1.3 --
  3. ## v broom 0.7.6 v rsample 0.0.9
  4. ## v dials 0.0.9 v tune 0.1.5
  5. ## v infer 0.5.4 v workflows 0.2.2
  6. ## v modeldata 0.1.0 v workflowsets 0.0.2
  7. ## v parsnip 0.1.5 v yardstick 0.0.8
  8. ## v recipes 0.1.16
  9. ## -- Conflicts ------------------------------------------------------------- tidymodels_conflicts() --
  10. ## x scales::discard() masks purrr::discard()
  11. ## x dplyr::filter() masks stats::filter()
  12. ## x recipes::fixed() masks stringr::fixed()
  13. ## x dplyr::lag() masks stats::lag()
  14. ## x yardstick::spec() masks readr::spec()
  15. ## x recipes::step() masks stats::step()
  16. ## * Use tidymodels_prefer() to resolve common conflicts.
  17. # Work on at most the first 2000 rows (presumably to keep the example fast).
  18. # slice_head() never indexes past the end of the data, unlike `trees_df[1:2000, ]`.
  19. # NOTE(review): these are the first rows in file order, not a random sample.
  20. trees_df <- slice_head(trees_df, n = 2000)
  21. # Seeded train/test split, stratified on the outcome so the class balance is
  22. # similar in both parts.
  23. set.seed(123)
  24. trees_split <- initial_split(trees_df, strata = legal_status)
  25. trees_train <- training(trees_split)
  26. trees_test <- testing(trees_split)
  1. # Preprocessing recipe: model legal_status from every other column.
  2. tree_rec <- recipe(legal_status ~ ., data = trees_train) %>%
  3.   # Keep tree_id in the data but take it out of the predictor set.
  4.   update_role(tree_id, new_role = "ID") %>%
  5.   # Pool infrequent factor levels into a single catch-all level.
  6.   step_other(species, caretaker, threshold = 0.01) %>%
  7.   step_other(site_info, threshold = 0.005) %>%
  8.   step_dummy(all_nominal(), -all_outcomes()) %>%
  9.   # Replace the raw date with its year component.
  10.   step_date(date, features = "year") %>%
  11.   step_rm(date) %>%
  12.   # NOTE(review): newer recipes releases moved step_downsample() to the themis
  13.   # package; confirm against the installed version before upgrading.
  14.   step_downsample(legal_status)
  15. # Estimate the recipe on its training data and extract the processed result.
  16. tree_prep <- prep(tree_rec)
  17. juiced <- juice(tree_prep)
  1. # Random forest classifier on the ranger engine; mtry and min_n are left as
  2. # tune() placeholders, the number of trees is fixed at 1000.
  3. tune_spec <- rand_forest(
  4.   mode = "classification",
  5.   mtry = tune(),
  6.   trees = 1000,
  7.   min_n = tune()
  8. ) %>%
  9.   set_engine("ranger")
  1. # Bundle the model spec and the recipe into one workflow for tuning.
  2. tune_wf <- workflow() %>%
  3.   add_model(tune_spec) %>%
  4.   add_recipe(tree_rec)

调整参数

  1. # Five cross-validation folds of the training set, seeded for reproducibility.
  2. set.seed(234)
  3. trees_folds <- vfold_cv(trees_train, v = 5)
  1. doParallel::registerDoParallel() # register a parallel backend before tuning
  2. set.seed(345)
  3. tune_res <- tune_wf %>%
  4.   tune_grid(
  5.     resamples = trees_folds,
  6.     grid = 5 # an integer asks tune to generate that many candidate combinations
  7.   )
  8. ## i Creating pre-processing data to finalize unknown parameter: mtry
  9. tune_res
  10. ## # Tuning results
  11. ## # 5-fold cross-validation
  12. ## # A tibble: 5 x 4
  13. ## splits id .metrics .notes
  14. ## <list> <chr> <list> <list>
  15. ## 1 <split [1200/301]> Fold1 <tibble[,6] [10 x 6]> <tibble[,1] [0 x 1]>
  16. ## 2 <split [1201/300]> Fold2 <tibble[,6] [10 x 6]> <tibble[,1] [0 x 1]>
  17. ## 3 <split [1201/300]> Fold3 <tibble[,6] [10 x 6]> <tibble[,1] [0 x 1]>
  18. ## 4 <split [1201/300]> Fold4 <tibble[,6] [10 x 6]> <tibble[,1] [0 x 1]>
  19. ## 5 <split [1201/300]> Fold5 <tibble[,6] [10 x 6]> <tibble[,1] [0 x 1]>
  1. # Mean ROC AUC against each tuned parameter from the initial search.
  2. tune_res %>%
  3.   collect_metrics() %>%
  4.   filter(.metric == "roc_auc") %>%
  5.   select(mean, min_n, mtry) %>%
  6.   # One row per (parameter, value) pair so both can share a faceted plot.
  7.   pivot_longer(
  8.     cols = min_n:mtry,
  9.     names_to = "parameter",
  10.     values_to = "value"
  11.   ) %>%
  12.   ggplot(aes(x = value, y = mean, colour = parameter)) +
  13.   geom_point(show.legend = FALSE) +
  14.   # Free x scales: each parameter gets its own axis range.
  15.   facet_wrap(~parameter, scales = "free_x") +
  16.   labs(x = NULL, y = "AUC")

tidymodels-exercise-05 - 图3

  1. rf_grid <- grid_regular(
  2.   mtry(range = c(20, 25)), # explicit search range for the second pass
  3.   min_n(range = c(2, 5)),
  4.   levels = 5 # the printed grid has 20 rows: min_n yields only 4 distinct values
  5. )
  6. rf_grid
  7. ## # A tibble: 20 x 2
  8. ## mtry min_n
  9. ## <int> <int>
  10. ## 1 20 2
  11. ## 2 21 2
  12. ## 3 22 2
  13. ## 4 23 2
  14. ## 5 25 2
  15. ## 6 20 3
  16. ## 7 21 3
  17. ## 8 22 3
  18. ## 9 23 3
  19. ## 10 25 3
  20. ## 11 20 4
  21. ## 12 21 4
  22. ## 13 22 4
  23. ## 14 23 4
  24. ## 15 25 4
  25. ## 16 20 5
  26. ## 17 21 5
  27. ## 18 22 5
  28. ## 19 23 5
  29. ## 20 25 5
  1. set.seed(456)
  2. regular_res <- tune_wf %>%
  3.   tune_grid(
  4.     resamples = trees_folds,
  5.     grid = rf_grid # evaluate exactly the combinations listed in rf_grid
  6.   )
  7. regular_res
  8. ## # Tuning results
  9. ## # 5-fold cross-validation
  10. ## # A tibble: 5 x 4
  11. ## splits id .metrics .notes
  12. ## <list> <chr> <list> <list>
  13. ## 1 <split [1200/301]> Fold1 <tibble[,6] [40 x 6]> <tibble[,1] [0 x 1]>
  14. ## 2 <split [1201/300]> Fold2 <tibble[,6] [40 x 6]> <tibble[,1] [0 x 1]>
  15. ## 3 <split [1201/300]> Fold3 <tibble[,6] [40 x 6]> <tibble[,1] [0 x 1]>
  16. ## 4 <split [1201/300]> Fold4 <tibble[,6] [40 x 6]> <tibble[,1] [0 x 1]>
  17. ## 5 <split [1201/300]> Fold5 <tibble[,6] [40 x 6]> <tibble[,1] [0 x 1]>

评价模型

  1. # Mean ROC AUC across mtry, with one line per min_n value.
  2. regular_res %>%
  3.   collect_metrics() %>%
  4.   filter(.metric == "roc_auc") %>%
  5.   # Convert to a factor so min_n gets discrete colours rather than a gradient.
  6.   mutate(min_n = factor(min_n)) %>%
  7.   ggplot(aes(x = mtry, y = mean, colour = min_n)) +
  8.   geom_line(alpha = 0.5, size = 1.5) +
  9.   geom_point() +
  10.   labs(y = "AUC")

tidymodels-exercise-05 - 图4

  1. # Pick the candidate with the best mean cross-validated ROC AUC.
  2. # `metric` is passed by name: newer tune releases no longer accept it positionally.
  3. best_auc <- select_best(regular_res, metric = "roc_auc")
  4. # Substitute the selected mtry / min_n for the tune() placeholders in the spec.
  5. final_rf <- finalize_model(tune_spec, best_auc)
  6. final_rf
  7. ## Random Forest Model Specification (classification)
  8. ##
  9. ## Main Arguments:
  10. ## mtry = 23
  11. ## trees = 1000
  12. ## min_n = 3
  13. ##
  14. ## Computational engine: ranger
  1. library(vip)
  2. ##
  3. ## Attaching package: 'vip'
  4. ## The following object is masked from 'package:utils':
  5. ##
  6. ## vi
  7. # Refit the finalized model with permutation importance enabled and plot it.
  8. # The fit uses the prepped (juiced) training data with the tree_id column removed.
  9. final_rf %>%
  10.   set_engine("ranger", importance = "permutation") %>%
  11.   fit(
  12.     formula = legal_status ~ .,
  13.     data = select(juice(tree_prep), -tree_id)
  14.   ) %>%
  15.   vip(geom = "point")

tidymodels-exercise-05 - 图5

  1. # Lock the selected hyperparameters into the tuning workflow.
  2. final_wf <- finalize_workflow(tune_wf, best_auc)
  3. # Fit on the training portion of the split and evaluate on the test portion.
  4. final_res <- last_fit(final_wf, trees_split)
  5. # Test-set metrics from that single final fit.
  6. collect_metrics(final_res)
  7. ## # A tibble: 2 x 4
  8. ## .metric .estimator .estimate .config
  9. ## <chr> <chr> <dbl> <chr>
  10. ## 1 accuracy binary 0.780 Preprocessor1_Model1
  11. ## 2 roc_auc binary 0.901 Preprocessor1_Model1
  1. predict(final_res$.workflow[[1]], new_data = trees_train[1:10, ]) # workflow fitted by last_fit(), applied to the first 10 training rows
  2. ## # A tibble: 10 x 1
  3. ## .pred_class
  4. ## <fct>
  5. ## 1 DPW Maintained
  6. ## 2 DPW Maintained
  7. ## 3 DPW Maintained
  8. ## 4 DPW Maintained
  9. ## 5 DPW Maintained
  10. ## 6 DPW Maintained
  11. ## 7 DPW Maintained
  12. ## 8 Other
  13. ## 9 Other
  14. ## 10 DPW Maintained
  1. final_res1 <- last_fit(final_rf, preprocessor = legal_status ~ ., split = trees_split) # formula preprocessor instead of tree_rec
  2. predict(final_res1$.workflow[[1]], new_data = trees_train[1:10, ])
  3. ## # A tibble: 10 x 1
  4. ## .pred_class
  5. ## <fct>
  6. ## 1 DPW Maintained
  7. ## 2 DPW Maintained
  8. ## 3 DPW Maintained
  9. ## 4 DPW Maintained
  10. ## 5 DPW Maintained
  11. ## 6 DPW Maintained
  12. ## 7 DPW Maintained
  13. ## 8 Other
  14. ## 9 Other
  15. ## 10 DPW Maintained
  1. final_res2 <- fit(final_rf, formula = legal_status ~ ., data = trees_train) # model spec fitted directly on trees_train, no workflow
  2. predict(final_res2, new_data = trees_train[1:10, ])
  3. ## # A tibble: 10 x 1
  4. ## .pred_class
  5. ## <fct>
  6. ## 1 DPW Maintained
  7. ## 2 DPW Maintained
  8. ## 3 DPW Maintained
  9. ## 4 DPW Maintained
  10. ## 5 DPW Maintained
  11. ## 6 DPW Maintained
  12. ## 7 DPW Maintained
  13. ## 8 Other
  14. ## 9 Other
  15. ## 10 DPW Maintained
  1. # Map the test-set trees, coloured by whether the final model classified them correctly.
  2. final_res %>%
  3.   collect_predictions() %>%
  4.   mutate(
  5.     correct = case_when(legal_status == .pred_class ~ "Correct", TRUE ~ "Incorrect")
  6.   ) %>%
  7.   bind_cols(trees_test) %>% # positional bind; the duplicated legal_status columns get renamed
  8.   ggplot(aes(longitude, latitude, color = correct)) +
  9.   geom_point(size = 0.5, alpha = 0.5) +
  10.   labs(color = NULL) +
  11.   scale_color_manual(values = c(Correct = "gray80", Incorrect = "darkred")) # named so colours stay tied to their category
  12. ## New names:
  13. ## * legal_status -> legal_status...6
  14. ## * legal_status -> legal_status...10

tidymodels-exercise-05 - 图6