title: "tidymodels-exercise-08"
author: "liyue"
date: "2021/7/31"
output: html_document

Let's predict the famous Titanic dataset with tidymodels!

Data exploration

First, load the data.

```r
rm(list = ls())
library(tidyverse)
## -- Attaching packages ------------------------------------------ tidyverse 1.3.1 --
## v ggplot2 3.3.5     v purrr   0.3.4
## v tibble  3.1.2     v dplyr   1.0.7
## v tidyr   1.1.3     v stringr 1.4.0
## v readr   1.4.0     v forcats 0.5.1
## -- Conflicts --------------------------------------------- tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
titanic <- read_csv('../datasets/titanic.train.csv')
##
## -- Column specification -----------------------------------------------------------
## cols(
##   .default = col_character(),
##   PassengerId = col_double(),
##   Survived = col_double(),
##   Pclass = col_double(),
##   Age = col_double(),
##   SibSp = col_double(),
##   Parch = col_double(),
##   Fare = col_double(),
##   WikiId = col_double(),
##   Age_wiki = col_double(),
##   Class = col_double()
## )
## i Use `spec()` for the full column specifications.
titanic
## # A tibble: 891 x 21
##    PassengerId Survived Pclass Name   Sex   Age SibSp Parch Ticket  Fare Cabin
##          <dbl>    <dbl>  <dbl> <chr>  <chr> <dbl> <dbl> <dbl> <chr> <dbl> <chr>
##  1           1        0      3 Braun~ male     22     1     0 A/5 2~  7.25 <NA>
##  2           2        1      1 Cumin~ fema~    38     1     0 PC 17~ 71.3  C85
##  3           3        1      3 Heikk~ fema~    26     0     0 STON/~  7.92 <NA>
##  4           4        1      1 Futre~ fema~    35     1     0 113803 53.1  C123
##  5           5        0      3 Allen~ male     35     0     0 373450  8.05 <NA>
##  6           6        0      3 Moran~ male     NA     0     0 330877  8.46 <NA>
##  7           7        0      1 McCar~ male     54     0     0 17463  51.9  E46
##  8           8        0      3 Palss~ male      2     3     1 349909 21.1  <NA>
##  9           9        1      3 Johns~ fema~    27     0     2 347742 11.1  <NA>
## 10          10        1      2 Nasse~ fema~    14     1     0 237736 30.1  <NA>
## # ... with 881 more rows, and 10 more variables: Embarked <chr>, WikiId <dbl>,
## #   Name_wiki <chr>, Age_wiki <dbl>, Hometown <chr>, Boarded <chr>,
## #   Destination <chr>, Lifeboat <chr>, Body <chr>, Class <dbl>
```

Take a look at the data.

```r
skimr::skim(titanic)
```

Table: Data summary

Name titanic
Number of rows 891
Number of columns 21
_
Column type frequency:
character 11
numeric 10
__
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
Name 0 1.00 12 82 0 891 0
Sex 0 1.00 4 6 0 2 0
Ticket 0 1.00 3 18 0 681 0
Cabin 687 0.23 1 15 0 147 0
Embarked 2 1.00 1 1 0 3 0
Name_wiki 2 1.00 12 69 0 889 0
Hometown 2 1.00 6 44 0 437 0
Boarded 2 1.00 7 11 0 4 0
Destination 2 1.00 10 39 0 234 0
Lifeboat 546 0.39 1 3 0 22 0
Body 804 0.10 3 8 0 87 0

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
PassengerId 0 1.0 446.00 257.35 1.00 223.50 446.00 668.5 891.00 ▇▇▇▇▇
Survived 0 1.0 0.38 0.49 0.00 0.00 0.00 1.0 1.00 ▇▁▁▁▅
Pclass 0 1.0 2.31 0.84 1.00 2.00 3.00 3.0 3.00 ▃▁▃▁▇
Age 177 0.8 29.70 14.53 0.42 20.12 28.00 38.0 80.00 ▂▇▅▂▁
SibSp 0 1.0 0.52 1.10 0.00 0.00 0.00 1.0 8.00 ▇▁▁▁▁
Parch 0 1.0 0.38 0.81 0.00 0.00 0.00 0.0 6.00 ▇▁▁▁▁
Fare 0 1.0 32.20 49.69 0.00 7.91 14.45 31.0 512.33 ▇▁▁▁▁
WikiId 2 1.0 665.47 380.80 1.00 336.00 672.00 996.0 1314.00 ▇▇▇▇▇
Age_wiki 4 1.0 29.32 13.93 0.42 20.00 28.00 38.0 74.00 ▂▇▅▂▁
Class 2 1.0 2.31 0.84 1.00 2.00 3.00 3.0 3.00 ▃▁▃▁▇

There are 891 rows and 21 variables, with many missing values and many character variables, so some preprocessing is needed. The main steps are:

  • handle missing values;
  • convert character variables to factors;
  • drop variables that contribute little to the outcome.

Here is the code:

```r
titanic_df <- titanic %>%
  select(-c(13:21), -Name, -Ticket, -Cabin, -PassengerId) %>%
  drop_na(Embarked) %>%
  mutate_if(is.character, factor) %>%
  mutate(Survived = factor(Survived))
skimr::skim(titanic_df)
```

Table: Data summary

Name titanic_df
Number of rows 889
Number of columns 8
_
Column type frequency:
factor 3
numeric 5
__
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
Survived 0 1 FALSE 2 0: 549, 1: 340
Sex 0 1 FALSE 2 mal: 577, fem: 312
Embarked 0 1 FALSE 3 S: 644, C: 168, Q: 77

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
Pclass 0 1.0 2.31 0.83 1.00 2.0 3.00 3 3.00 ▃▁▃▁▇
Age 177 0.8 29.64 14.49 0.42 20.0 28.00 38 80.00 ▂▇▅▂▁
SibSp 0 1.0 0.52 1.10 0.00 0.0 0.00 1 8.00 ▇▁▁▁▁
Parch 0 1.0 0.38 0.81 0.00 0.0 0.00 0 6.00 ▇▁▁▁▁
Fare 0 1.0 32.10 49.70 0.00 7.9 14.45 31 512.33 ▇▁▁▁▁

This leaves 7 predictors and 1 outcome variable. The outcome is a factor (survived or died); the predictors are a mix of factors and numerics. Many loosely related variables were simply dropped (admittedly, not a rigorous approach).
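Note that `Age` still contains missing values (177 of them, per the skim above), which is why the recipe later imputes it rather than dropping rows. A quick sketch for checking what missingness remains, assuming `titanic_df` as built above:

```r
library(dplyr)

# Count missing values in each column of titanic_df;
# after drop_na(Embarked), only Age should still have NAs
titanic_df %>%
  summarise(across(everything(), ~ sum(is.na(.x))))
```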

Let's take a quick look at the data.

```r
# Survival vs. sex
ggplot(titanic_df, aes(x = Sex, fill = Survived)) +
  geom_bar() +
  theme_minimal() +
  labs(x = '', y = 'number')
```

tidymodels-exercise-08 - Figure 1 (survival counts by sex)

We can see that most males did not survive, while most females did.
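The same pattern can be read off numerically; a small sketch cross-tabulating survival by sex:

```r
library(dplyr)

# Counts and within-sex survival proportions
titanic_df %>%
  count(Sex, Survived) %>%
  group_by(Sex) %>%
  mutate(prop = n / sum(n))
```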

```r
# Survival vs. port of embarkation
ggplot(titanic_df, aes(x = Embarked, fill = Survived)) +
  geom_bar() +
  theme_minimal() +
  labs(x = '', y = 'number')
```

tidymodels-exercise-08 - Figure 2 (survival counts by port of embarkation)

Modeling the data

```r
library(tidymodels)
## Registered S3 method overwritten by 'tune':
##   method                   from
##   required_pkgs.model_spec parsnip
## -- Attaching packages ----------------------------------------- tidymodels 0.1.3 --
## v broom     0.7.8     v rsample      0.1.0
## v dials     0.0.9     v tune         0.1.5
## v infer     0.5.4     v workflows    0.2.2
## v modeldata 0.1.0     v workflowsets 0.0.2
## v parsnip   0.1.6     v yardstick    0.0.8
## v recipes   0.1.16
## -- Conflicts -------------------------------------------- tidymodels_conflicts() --
## x scales::discard() masks purrr::discard()
## x dplyr::filter()   masks stats::filter()
## x recipes::fixed()  masks stringr::fixed()
## x dplyr::lag()      masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step()   masks stats::step()
## * Use tidymodels_prefer() to resolve common conflicts.
tidymodels_prefer()
titanic_split <- initial_split(titanic_df)
titanic_train <- training(titanic_split)
titanic_test <- testing(titanic_split)
titanic_boot <- bootstraps(titanic_train)
# library(usemodels)
# use_ranger(Survived ~ ., data = titanic_train)
ranger_rec <- recipe(Survived ~ ., data = titanic_train) %>%
  step_impute_mean(Age)
ranger_spec <- rand_forest(mode = "classification",
                           mtry = tune(),
                           trees = 1000,
                           min_n = tune()) %>%
  set_engine('ranger')
ranger_wf <- workflow() %>%
  add_recipe(ranger_rec) %>%
  add_model(ranger_spec)
set.seed(123)
doParallel::registerDoParallel()
ranger_tune <- tune_grid(
  ranger_wf,
  titanic_boot,
  grid = 10,
  control = control_grid(verbose = TRUE, save_pred = TRUE,
                         parallel_over = 'everything')
)
## i Creating pre-processing data to finalize unknown parameter: mtry
ranger_tune
## # Tuning results
## # Bootstrap sampling
## # A tibble: 25 x 5
##    splits          id         .metrics         .notes         .predictions
##    <list>          <chr>      <list>           <list>         <list>
##  1 <split [666/25~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,500 x ~
##  2 <split [666/24~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,480 x ~
##  3 <split [666/24~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,470 x ~
##  4 <split [666/24~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,490 x ~
##  5 <split [666/26~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,600 x ~
##  6 <split [666/23~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,340 x ~
##  7 <split [666/24~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,460 x ~
##  8 <split [666/25~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,590 x ~
##  9 <split [666/23~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,340 x ~
## 10 <split [666/24~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,490 x ~
## # ... with 15 more rows
```

Let's see how the results look.

```r
collect_metrics(ranger_tune)
## # A tibble: 20 x 8
##     mtry min_n .metric  .estimator  mean     n std_err .config
##    <int> <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>
##  1     2    30 accuracy binary     0.817    25 0.00393 Preprocessor1_Model01
##  2     2    30 roc_auc  binary     0.857    25 0.00468 Preprocessor1_Model01
##  3     6     3 accuracy binary     0.806    25 0.00408 Preprocessor1_Model02
##  4     6     3 roc_auc  binary     0.845    25 0.00510 Preprocessor1_Model02
##  5     2    39 accuracy binary     0.816    25 0.00383 Preprocessor1_Model03
##  6     2    39 roc_auc  binary     0.857    25 0.00462 Preprocessor1_Model03
##  7     3    27 accuracy binary     0.818    25 0.00386 Preprocessor1_Model04
##  8     3    27 roc_auc  binary     0.855    25 0.00478 Preprocessor1_Model04
##  9     6    36 accuracy binary     0.813    25 0.00424 Preprocessor1_Model05
## 10     6    36 roc_auc  binary     0.852    25 0.00469 Preprocessor1_Model05
## 11     3    20 accuracy binary     0.822    25 0.00434 Preprocessor1_Model06
## 12     3    20 roc_auc  binary     0.854    25 0.00488 Preprocessor1_Model06
## 13     4    24 accuracy binary     0.817    25 0.00428 Preprocessor1_Model07
## 14     4    24 roc_auc  binary     0.853    25 0.00475 Preprocessor1_Model07
## 15     6    12 accuracy binary     0.818    25 0.00366 Preprocessor1_Model08
## 16     6    12 roc_auc  binary     0.850    25 0.00496 Preprocessor1_Model08
## 17     5    15 accuracy binary     0.819    25 0.00425 Preprocessor1_Model09
## 18     5    15 roc_auc  binary     0.852    25 0.00483 Preprocessor1_Model09
## 19     1     7 accuracy binary     0.812    25 0.00447 Preprocessor1_Model10
## 20     1     7 roc_auc  binary     0.856    25 0.00437 Preprocessor1_Model10
autoplot(ranger_tune)
```

tidymodels-exercise-08 - Figure 3 (tuning metrics across the random grid)

Based on these results, build a regular grid around the most promising parameter values.

```r
ranger_grid <- grid_regular(
  mtry(range = c(1, 3)),
  min_n(range = c(5, 10)),
  levels = 5
)
ranger_grid
## # A tibble: 15 x 2
##     mtry min_n
##    <int> <int>
##  1     1     5
##  2     2     5
##  3     3     5
##  4     1     6
##  5     2     6
##  6     3     6
##  7     1     7
##  8     2     7
##  9     3     7
## 10     1     8
## 11     2     8
## 12     3     8
## 13     1    10
## 14     2    10
## 15     3    10
```

Tune again over this grid.

```r
set.seed(1234)
doParallel::registerDoParallel()
regular_res <- tune_grid(
  ranger_wf,
  titanic_boot,
  grid = ranger_grid,
  control = control_grid(verbose = TRUE, save_pred = TRUE,
                         parallel_over = 'everything')
)
regular_res
## # Tuning results
## # Bootstrap sampling
## # A tibble: 25 x 5
##    splits          id         .metrics         .notes         .predictions
##    <list>          <chr>      <list>           <list>         <list>
##  1 <split [666/25~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,750 x ~
##  2 <split [666/24~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,720 x ~
##  3 <split [666/24~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,705 x ~
##  4 <split [666/24~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,735 x ~
##  5 <split [666/26~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,900 x ~
##  6 <split [666/23~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,510 x ~
##  7 <split [666/24~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,690 x ~
##  8 <split [666/25~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,885 x ~
##  9 <split [666/23~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,510 x ~
## 10 <split [666/24~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,735 x ~
## # ... with 15 more rows
```

Look at the results again.

```r
collect_metrics(regular_res)
## # A tibble: 30 x 8
##     mtry min_n .metric  .estimator  mean     n std_err .config
##    <int> <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>
##  1     1     5 accuracy binary     0.810    25 0.00406 Preprocessor1_Model01
##  2     1     5 roc_auc  binary     0.856    25 0.00437 Preprocessor1_Model01
##  3     2     5 accuracy binary     0.823    25 0.00415 Preprocessor1_Model02
##  4     2     5 roc_auc  binary     0.856    25 0.00491 Preprocessor1_Model02
##  5     3     5 accuracy binary     0.818    25 0.00440 Preprocessor1_Model03
##  6     3     5 roc_auc  binary     0.851    25 0.00504 Preprocessor1_Model03
##  7     1     6 accuracy binary     0.811    25 0.00408 Preprocessor1_Model04
##  8     1     6 roc_auc  binary     0.856    25 0.00435 Preprocessor1_Model04
##  9     2     6 accuracy binary     0.824    25 0.00416 Preprocessor1_Model05
## 10     2     6 roc_auc  binary     0.856    25 0.00489 Preprocessor1_Model05
## # ... with 20 more rows
autoplot(regular_res)
```

tidymodels-exercise-08 - Figure 4 (tuning metrics across the regular grid)

```r
show_best(regular_res, metric = 'roc_auc')
## # A tibble: 5 x 8
##    mtry min_n .metric .estimator  mean     n std_err .config
##   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>
## 1     1     5 roc_auc binary     0.856    25 0.00437 Preprocessor1_Model01
## 2     2     6 roc_auc binary     0.856    25 0.00489 Preprocessor1_Model05
## 3     2     8 roc_auc binary     0.856    25 0.00491 Preprocessor1_Model11
## 4     2     5 roc_auc binary     0.856    25 0.00491 Preprocessor1_Model02
## 5     2    10 roc_auc binary     0.856    25 0.00491 Preprocessor1_Model14
select_best(regular_res, metric = 'roc_auc')
## # A tibble: 1 x 3
##    mtry min_n .config
##   <int> <int> <chr>
## 1     1     5 Preprocessor1_Model01
```

Look at the ROC curve and its AUC.

```r
collect_predictions(regular_res) %>% roc_auc(Survived, .pred_0)
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.848
collect_predictions(regular_res) %>% roc_curve(Survived, .pred_0) %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity)) +
  geom_line(size = 1.5) +
  geom_abline(lty = 2, size = 1.2, color = 'gray80') +
  theme_minimal()
```

tidymodels-exercise-08 - Figure 5 (ROC curve, resampled predictions)

Model evaluation

```r
final_wf <- finalize_workflow(ranger_wf, select_best(regular_res, 'roc_auc'))
final_wf
## == Workflow =======================================================================
## Preprocessor: Recipe
## Model: rand_forest()
##
## -- Preprocessor -------------------------------------------------------------------
## 1 Recipe Step
##
## * step_impute_mean()
##
## -- Model --------------------------------------------------------------------------
## Random Forest Model Specification (classification)
##
## Main Arguments:
##   mtry = 1
##   trees = 1000
##   min_n = 5
##
## Computational engine: ranger
final_res <- last_fit(final_wf, titanic_split)
collect_metrics(final_res)
## # A tibble: 2 x 4
##   .metric  .estimator .estimate .config
##   <chr>    <chr>          <dbl> <chr>
## 1 accuracy binary         0.825 Preprocessor1_Model1
## 2 roc_auc  binary         0.884 Preprocessor1_Model1
collect_predictions(final_res) %>% roc_curve(Survived, .pred_0) %>% autoplot()
```

tidymodels-exercise-08 - Figure 6 (ROC curve of the final model, autoplot)

```r
collect_predictions(final_res) %>% roc_auc(Survived, .pred_0)
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.884
collect_predictions(final_res) %>% roc_curve(Survived, .pred_0) %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity)) +
  geom_line(size = 1.5) +
  geom_abline(lty = 2, size = 1.2, color = 'gray80') +
  theme_minimal()
```

tidymodels-exercise-08 - Figure 7 (ROC curve of the final model, ggplot)
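Beyond accuracy and ROC AUC, a confusion matrix shows where the final model errs on the held-out test set; a short sketch using yardstick's `conf_mat()`:

```r
library(tidymodels)

# Confusion matrix of the final model's test-set predictions
collect_predictions(final_res) %>%
  conf_mat(truth = Survived, estimate = .pred_class) %>%
  autoplot(type = 'heatmap')
```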

Predicting new data

```r
test <- read_csv('../datasets/titanic.test.csv')
##
## -- Column specification -----------------------------------------------------------
## cols(
##   .default = col_character(),
##   PassengerId = col_double(),
##   Pclass = col_double(),
##   Age = col_double(),
##   SibSp = col_double(),
##   Parch = col_double(),
##   Fare = col_double(),
##   WikiId = col_double(),
##   Age_wiki = col_double(),
##   Class = col_double()
## )
## i Use `spec()` for the full column specifications.
test
## # A tibble: 418 x 20
##    PassengerId Pclass Name   Sex   Age SibSp Parch Ticket  Fare Cabin Embarked
##          <dbl>  <dbl> <chr>  <chr> <dbl> <dbl> <dbl> <chr> <dbl> <chr> <chr>
##  1         892      3 Kelly~ male   34.5     0     0 330911  7.83 <NA> Q
##  2         893      3 Wilke~ fema~  47       1     0 363272  7    <NA> S
##  3         894      2 Myles~ male   62       0     0 240276  9.69 <NA> Q
##  4         895      3 Wirz,~ male   27       0     0 315154  8.66 <NA> S
##  5         896      3 Hirvo~ fema~  22       1     1 31012~ 12.3  <NA> S
##  6         897      3 Svens~ male   14       0     0 7538    9.22 <NA> S
##  7         898      3 Conno~ fema~  30       0     0 330972  7.63 <NA> Q
##  8         899      2 Caldw~ male   26       1     1 248738 29    <NA> S
##  9         900      3 Abrah~ fema~  18       0     0 2657    7.23 <NA> C
## 10         901      3 Davie~ male   21       2     0 A/4 4~ 24.2  <NA> S
## # ... with 408 more rows, and 9 more variables: WikiId <dbl>, Name_wiki <chr>,
## #   Age_wiki <dbl>, Hometown <chr>, Boarded <chr>, Destination <chr>,
## #   Lifeboat <chr>, Body <chr>, Class <dbl>
test_df <- test %>%
  select(-c(12:20), -Name, -Ticket, -Cabin, -PassengerId) %>%
  drop_na() %>%
  mutate_if(is.character, factor)
test_res <- predict(final_res$.workflow[[1]], new_data = test_df)
test_res
## # A tibble: 331 x 1
##    .pred_class
##    <fct>
##  1 0
##  2 0
##  3 0
##  4 0
##  5 0
##  6 0
##  7 0
##  8 0
##  9 1
## 10 0
## # ... with 321 more rows
```

This result could be submitted to Kaggle.
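A Kaggle submission needs the PassengerId column alongside the predictions, and `drop_na()` above removed 87 of the 418 test rows, so the predictions as they stand would be incomplete. One alternative (a sketch, not the approach used above; `submission.csv` is a hypothetical file name) is to keep all rows, let the recipe's mean imputation handle missing Age, and fill the remaining gaps manually:

```r
library(tidyverse)

# Keep PassengerId and fill the few non-Age gaps instead of dropping rows;
# the workflow's step_impute_mean() will impute missing Age at predict time
test_keep <- test %>%
  select(PassengerId, Pclass, Sex, Age, SibSp, Parch, Fare, Embarked) %>%
  mutate(Fare = replace_na(Fare, median(Fare, na.rm = TRUE)),  # assumed fill: median fare
         Embarked = replace_na(Embarked, 'S')) %>%             # assumed fill: most common port
  mutate_if(is.character, factor)

submission <- test_keep %>%
  transmute(PassengerId,
            Survived = predict(final_res$.workflow[[1]],
                               new_data = select(test_keep, -PassengerId))$.pred_class)
write_csv(submission, 'submission.csv')
```

This keeps the submission at the full 418 rows Kaggle expects.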