---
title: "tidymodels-exercise-08"
author: "liyue"
date: "2021/7/31"
output: html_document
---
Predicting the famous Titanic dataset with tidymodels!

## Data exploration

First, load the data:
```r
rm(list = ls())
library(tidyverse)

titanic <- read_csv('../datasets/titanic.train.csv')
titanic
```

```
## # A tibble: 891 x 21
##    PassengerId Survived Pclass Name   Sex     Age SibSp Parch Ticket  Fare Cabin
##          <dbl>    <dbl>  <dbl> <chr>  <chr> <dbl> <dbl> <dbl> <chr>  <dbl> <chr>
##  1           1        0      3 Braun~ male     22     1     0 A/5 2~  7.25 <NA> 
##  2           2        1      1 Cumin~ fema~    38     1     0 PC 17~ 71.3  C85  
##  3           3        1      3 Heikk~ fema~    26     0     0 STON/~  7.92 <NA> 
##  4           4        1      1 Futre~ fema~    35     1     0 113803 53.1  C123 
##  5           5        0      3 Allen~ male     35     0     0 373450  8.05 <NA> 
##  6           6        0      3 Moran~ male     NA     0     0 330877  8.46 <NA> 
##  7           7        0      1 McCar~ male     54     0     0 17463  51.9  E46  
##  8           8        0      3 Palss~ male      2     3     1 349909 21.1  <NA> 
##  9           9        1      3 Johns~ fema~    27     0     2 347742 11.1  <NA> 
## 10          10        1      2 Nasse~ fema~    14     1     0 237736 30.1  <NA> 
## # ... with 881 more rows, and 10 more variables: Embarked <chr>, WikiId <dbl>,
## #   Name_wiki <chr>, Age_wiki <dbl>, Hometown <chr>, Boarded <chr>,
## #   Destination <chr>, Lifeboat <chr>, Body <chr>, Class <dbl>
```
Take a quick look at the data:
```r
skimr::skim(titanic)
```
Table: Data summary

|                        |         |
|:-----------------------|:--------|
| Name                   | titanic |
| Number of rows         | 891     |
| Number of columns      | 21      |
| _                      |         |
| Column type frequency: |         |
| character              | 11      |
| numeric                | 10      |
| __                     |         |
| Group variables        | None    |
Variable type: character
| skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
|---|---|---|---|---|---|---|---|
| Name | 0 | 1.00 | 12 | 82 | 0 | 891 | 0 |
| Sex | 0 | 1.00 | 4 | 6 | 0 | 2 | 0 |
| Ticket | 0 | 1.00 | 3 | 18 | 0 | 681 | 0 |
| Cabin | 687 | 0.23 | 1 | 15 | 0 | 147 | 0 |
| Embarked | 2 | 1.00 | 1 | 1 | 0 | 3 | 0 |
| Name_wiki | 2 | 1.00 | 12 | 69 | 0 | 889 | 0 |
| Hometown | 2 | 1.00 | 6 | 44 | 0 | 437 | 0 |
| Boarded | 2 | 1.00 | 7 | 11 | 0 | 4 | 0 |
| Destination | 2 | 1.00 | 10 | 39 | 0 | 234 | 0 |
| Lifeboat | 546 | 0.39 | 1 | 3 | 0 | 22 | 0 |
| Body | 804 | 0.10 | 3 | 8 | 0 | 87 | 0 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| PassengerId | 0 | 1.0 | 446.00 | 257.35 | 1.00 | 223.50 | 446.00 | 668.5 | 891.00 | ▇▇▇▇▇ |
| Survived | 0 | 1.0 | 0.38 | 0.49 | 0.00 | 0.00 | 0.00 | 1.0 | 1.00 | ▇▁▁▁▅ |
| Pclass | 0 | 1.0 | 2.31 | 0.84 | 1.00 | 2.00 | 3.00 | 3.0 | 3.00 | ▃▁▃▁▇ |
| Age | 177 | 0.8 | 29.70 | 14.53 | 0.42 | 20.12 | 28.00 | 38.0 | 80.00 | ▂▇▅▂▁ |
| SibSp | 0 | 1.0 | 0.52 | 1.10 | 0.00 | 0.00 | 0.00 | 1.0 | 8.00 | ▇▁▁▁▁ |
| Parch | 0 | 1.0 | 0.38 | 0.81 | 0.00 | 0.00 | 0.00 | 0.0 | 6.00 | ▇▁▁▁▁ |
| Fare | 0 | 1.0 | 32.20 | 49.69 | 0.00 | 7.91 | 14.45 | 31.0 | 512.33 | ▇▁▁▁▁ |
| WikiId | 2 | 1.0 | 665.47 | 380.80 | 1.00 | 336.00 | 672.00 | 996.0 | 1314.00 | ▇▇▇▇▇ |
| Age_wiki | 4 | 1.0 | 29.32 | 13.93 | 0.42 | 20.00 | 28.00 | 38.0 | 74.00 | ▂▇▅▂▁ |
| Class | 2 | 1.0 | 2.31 | 0.84 | 1.00 | 2.00 | 3.00 | 3.0 | 3.00 | ▃▁▃▁▇ |
The data has 891 rows and 21 variables, with many missing values and many character variables, so it needs preprocessing before modeling. The main steps:

- handle the missing values;
- convert character variables to factors;
- drop variables that contribute little to the outcome.

Here is the code:
```r
titanic_df <- titanic %>%
  select(-c(13:21), -Name, -Ticket, -Cabin, -PassengerId) %>%
  drop_na(Embarked) %>%
  mutate_if(is.character, factor) %>%
  mutate(Survived = factor(Survived))

skimr::skim(titanic_df)
```
Table: Data summary

|                        |            |
|:-----------------------|:-----------|
| Name                   | titanic_df |
| Number of rows         | 889        |
| Number of columns      | 8          |
| _                      |            |
| Column type frequency: |            |
| factor                 | 3          |
| numeric                | 5          |
| __                     |            |
| Group variables        | None       |
Variable type: factor
| skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
|---|---|---|---|---|---|
| Survived | 0 | 1 | FALSE | 2 | 0: 549, 1: 340 |
| Sex | 0 | 1 | FALSE | 2 | mal: 577, fem: 312 |
| Embarked | 0 | 1 | FALSE | 3 | S: 644, C: 168, Q: 77 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| Pclass | 0 | 1.0 | 2.31 | 0.83 | 1.00 | 2.0 | 3.00 | 3 | 3.00 | ▃▁▃▁▇ |
| Age | 177 | 0.8 | 29.64 | 14.49 | 0.42 | 20.0 | 28.00 | 38 | 80.00 | ▂▇▅▂▁ |
| SibSp | 0 | 1.0 | 0.52 | 1.10 | 0.00 | 0.0 | 0.00 | 1 | 8.00 | ▇▁▁▁▁ |
| Parch | 0 | 1.0 | 0.38 | 0.81 | 0.00 | 0.0 | 0.00 | 0 | 6.00 | ▇▁▁▁▁ |
| Fare | 0 | 1.0 | 32.10 | 49.70 | 0.00 | 7.9 | 14.45 | 31 | 512.33 | ▇▁▁▁▁ |
This leaves 7 predictor variables and 1 outcome variable. The outcome is a factor (survived or died); the predictors are a mix of factors and numeric variables. Many of the less relevant variables were simply dropped (admittedly, not a rigorous way to do feature selection).

Let's take another quick look at the data.
```r
# survival vs. sex
ggplot(titanic_df, aes(x = Sex, fill = Survived)) +
  geom_bar() +
  theme_minimal() +
  labs(x = '', y = 'number')
```

As the chart shows, most of the men did not survive, while most of the women did.
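The chart's impression can be quantified with a quick cross-tabulation. This is a small sketch I am adding (not part of the original analysis), using the `titanic_df` built above:

```r
# Sketch: survival counts and within-sex proportions from titanic_df
titanic_df %>%
  count(Sex, Survived) %>%
  group_by(Sex) %>%
  mutate(prop = n / sum(n)) %>%
  ungroup()
```

The `prop` column gives the survival rate within each sex, which makes the stacked bars easier to compare despite the different group sizes.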
```r
# survival vs. port of embarkation
ggplot(titanic_df, aes(x = Embarked, fill = Survived)) +
  geom_bar() +
  theme_minimal() +
  labs(x = '', y = 'number')
```

## Modeling
```r
library(tidymodels)
tidymodels_prefer()

# train/test split and bootstrap resamples of the training set
titanic_split <- initial_split(titanic_df)
titanic_train <- training(titanic_split)
titanic_test  <- testing(titanic_split)
titanic_boot  <- bootstraps(titanic_train)

# library(usemodels)
# use_ranger(Survived ~ ., data = titanic_train)

# recipe: impute missing Age values with the training-set mean
ranger_rec <- recipe(Survived ~ ., data = titanic_train) %>%
  step_impute_mean(Age)

# random forest with tunable mtry and min_n
ranger_spec <- rand_forest(mode = "classification",
                           mtry = tune(),
                           trees = 1000,
                           min_n = tune()) %>%
  set_engine('ranger')

ranger_wf <- workflow() %>%
  add_recipe(ranger_rec) %>%
  add_model(ranger_spec)

set.seed(123)
doParallel::registerDoParallel()

ranger_tune <- tune_grid(
  ranger_wf,
  titanic_boot,
  grid = 10,
  control = control_grid(verbose = TRUE,
                         save_pred = TRUE,
                         parallel_over = 'everything')
)
```

```
## i Creating pre-processing data to finalize unknown parameter: mtry
```

```r
ranger_tune
```

```
## # Tuning results
## # Bootstrap sampling
## # A tibble: 25 x 5
##    splits          id         .metrics         .notes         .predictions      
##    <list>          <chr>      <list>           <list>         <list>            
##  1 <split [666/25~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,500 x ~
##  2 <split [666/24~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,480 x ~
##  3 <split [666/24~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,470 x ~
##  4 <split [666/24~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,490 x ~
##  5 <split [666/26~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,600 x ~
##  6 <split [666/23~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,340 x ~
##  7 <split [666/24~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,460 x ~
##  8 <split [666/25~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,590 x ~
##  9 <split [666/23~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,340 x ~
## 10 <split [666/24~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,490 x ~
## # ... with 15 more rows
```
Let's see how it did:
```r
collect_metrics(ranger_tune)
```

```
## # A tibble: 20 x 8
##     mtry min_n .metric  .estimator  mean     n std_err .config              
##    <int> <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>                
##  1     2    30 accuracy binary     0.817    25 0.00393 Preprocessor1_Model01
##  2     2    30 roc_auc  binary     0.857    25 0.00468 Preprocessor1_Model01
##  3     6     3 accuracy binary     0.806    25 0.00408 Preprocessor1_Model02
##  4     6     3 roc_auc  binary     0.845    25 0.00510 Preprocessor1_Model02
##  5     2    39 accuracy binary     0.816    25 0.00383 Preprocessor1_Model03
##  6     2    39 roc_auc  binary     0.857    25 0.00462 Preprocessor1_Model03
##  7     3    27 accuracy binary     0.818    25 0.00386 Preprocessor1_Model04
##  8     3    27 roc_auc  binary     0.855    25 0.00478 Preprocessor1_Model04
##  9     6    36 accuracy binary     0.813    25 0.00424 Preprocessor1_Model05
## 10     6    36 roc_auc  binary     0.852    25 0.00469 Preprocessor1_Model05
## 11     3    20 accuracy binary     0.822    25 0.00434 Preprocessor1_Model06
## 12     3    20 roc_auc  binary     0.854    25 0.00488 Preprocessor1_Model06
## 13     4    24 accuracy binary     0.817    25 0.00428 Preprocessor1_Model07
## 14     4    24 roc_auc  binary     0.853    25 0.00475 Preprocessor1_Model07
## 15     6    12 accuracy binary     0.818    25 0.00366 Preprocessor1_Model08
## 16     6    12 roc_auc  binary     0.850    25 0.00496 Preprocessor1_Model08
## 17     5    15 accuracy binary     0.819    25 0.00425 Preprocessor1_Model09
## 18     5    15 roc_auc  binary     0.852    25 0.00483 Preprocessor1_Model09
## 19     1     7 accuracy binary     0.812    25 0.00447 Preprocessor1_Model10
## 20     1     7 roc_auc  binary     0.856    25 0.00437 Preprocessor1_Model10
```

```r
autoplot(ranger_tune)
```

Based on these results, pick suitable parameter ranges for a regular grid:
```r
ranger_grid <- grid_regular(mtry(range = c(1, 3)),
                            min_n(range = c(5, 10)),
                            levels = 5)
ranger_grid
```

```
## # A tibble: 15 x 2
##     mtry min_n
##    <int> <int>
##  1     1     5
##  2     2     5
##  3     3     5
##  4     1     6
##  5     2     6
##  6     3     6
##  7     1     7
##  8     2     7
##  9     3     7
## 10     1     8
## 11     2     8
## 12     3     8
## 13     1    10
## 14     2    10
## 15     3    10
```
Tune again on this grid:
```r
set.seed(1234)
doParallel::registerDoParallel()

regular_res <- tune_grid(
  ranger_wf,
  titanic_boot,
  grid = ranger_grid,
  control = control_grid(verbose = TRUE,
                         save_pred = TRUE,
                         parallel_over = 'everything')
)
regular_res
```

```
## # Tuning results
## # Bootstrap sampling
## # A tibble: 25 x 5
##    splits          id         .metrics         .notes         .predictions      
##    <list>          <chr>      <list>           <list>         <list>            
##  1 <split [666/25~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,750 x ~
##  2 <split [666/24~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,720 x ~
##  3 <split [666/24~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,705 x ~
##  4 <split [666/24~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,735 x ~
##  5 <split [666/26~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,900 x ~
##  6 <split [666/23~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,510 x ~
##  7 <split [666/24~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,690 x ~
##  8 <split [666/25~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,885 x ~
##  9 <split [666/23~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,510 x ~
## 10 <split [666/24~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,735 x ~
## # ... with 15 more rows
```
Check the results again:
```r
collect_metrics(regular_res)
```

```
## # A tibble: 30 x 8
##     mtry min_n .metric  .estimator  mean     n std_err .config              
##    <int> <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>                
##  1     1     5 accuracy binary     0.810    25 0.00406 Preprocessor1_Model01
##  2     1     5 roc_auc  binary     0.856    25 0.00437 Preprocessor1_Model01
##  3     2     5 accuracy binary     0.823    25 0.00415 Preprocessor1_Model02
##  4     2     5 roc_auc  binary     0.856    25 0.00491 Preprocessor1_Model02
##  5     3     5 accuracy binary     0.818    25 0.00440 Preprocessor1_Model03
##  6     3     5 roc_auc  binary     0.851    25 0.00504 Preprocessor1_Model03
##  7     1     6 accuracy binary     0.811    25 0.00408 Preprocessor1_Model04
##  8     1     6 roc_auc  binary     0.856    25 0.00435 Preprocessor1_Model04
##  9     2     6 accuracy binary     0.824    25 0.00416 Preprocessor1_Model05
## 10     2     6 roc_auc  binary     0.856    25 0.00489 Preprocessor1_Model05
## # ... with 20 more rows
```

```r
autoplot(regular_res)
```

```r
show_best(regular_res, metric = 'roc_auc')
```

```
## # A tibble: 5 x 8
##    mtry min_n .metric .estimator  mean     n std_err .config              
##   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
## 1     1     5 roc_auc binary     0.856    25 0.00437 Preprocessor1_Model01
## 2     2     6 roc_auc binary     0.856    25 0.00489 Preprocessor1_Model05
## 3     2     8 roc_auc binary     0.856    25 0.00491 Preprocessor1_Model11
## 4     2     5 roc_auc binary     0.856    25 0.00491 Preprocessor1_Model02
## 5     2    10 roc_auc binary     0.856    25 0.00491 Preprocessor1_Model14
```

```r
select_best(regular_res, metric = 'roc_auc')
```

```
## # A tibble: 1 x 3
##    mtry min_n .config              
##   <int> <int> <chr>                
## 1     1     5 Preprocessor1_Model01
```
Plot the ROC curve:
```r
collect_predictions(regular_res) %>% roc_auc(Survived, .pred_0)
```

```
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.848
```

```r
collect_predictions(regular_res) %>%
  roc_curve(Survived, .pred_0) %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity)) +
  geom_line(size = 1.5) +
  geom_abline(lty = 2, size = 1.2, color = 'gray80') +
  theme_minimal()
```

## Evaluating the final model
```r
final_wf <- finalize_workflow(ranger_wf, select_best(regular_res, metric = 'roc_auc'))
final_wf
```

```
## == Workflow =======================================================================
## Preprocessor: Recipe
## Model: rand_forest()
## 
## -- Preprocessor -------------------------------------------------------------------
## 1 Recipe Step
## 
## * step_impute_mean()
## 
## -- Model --------------------------------------------------------------------------
## Random Forest Model Specification (classification)
## 
## Main Arguments:
##   mtry = 1
##   trees = 1000
##   min_n = 5
## 
## Computational engine: ranger
```

```r
# fit on the full training set, evaluate once on the held-out test set
final_res <- last_fit(final_wf, titanic_split)
collect_metrics(final_res)
```

```
## # A tibble: 2 x 4
##   .metric  .estimator .estimate .config             
##   <chr>    <chr>          <dbl> <chr>               
## 1 accuracy binary         0.825 Preprocessor1_Model1
## 2 roc_auc  binary         0.884 Preprocessor1_Model1
```

```r
collect_predictions(final_res) %>% roc_curve(Survived, .pred_0) %>% autoplot()
```

```r
collect_predictions(final_res) %>% roc_auc(Survived, .pred_0)
```

```
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.884
```

```r
collect_predictions(final_res) %>%
  roc_curve(Survived, .pred_0) %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity)) +
  geom_line(size = 1.5) +
  geom_abline(lty = 2, size = 1.2, color = 'gray80') +
  theme_minimal()
```

## Predicting new data
```r
test <- read_csv('../datasets/titanic.test.csv')
test
```

```
## # A tibble: 418 x 20
##    PassengerId Pclass Name   Sex     Age SibSp Parch Ticket  Fare Cabin Embarked
##          <dbl>  <dbl> <chr>  <chr> <dbl> <dbl> <dbl> <chr>  <dbl> <chr> <chr>   
##  1         892      3 Kelly~ male   34.5     0     0 330911  7.83 <NA>  Q       
##  2         893      3 Wilke~ fema~  47       1     0 363272  7    <NA>  S       
##  3         894      2 Myles~ male   62       0     0 240276  9.69 <NA>  Q       
##  4         895      3 Wirz,~ male   27       0     0 315154  8.66 <NA>  S       
##  5         896      3 Hirvo~ fema~  22       1     1 31012~ 12.3  <NA>  S       
##  6         897      3 Svens~ male   14       0     0 7538    9.22 <NA>  S       
##  7         898      3 Conno~ fema~  30       0     0 330972  7.63 <NA>  Q       
##  8         899      2 Caldw~ male   26       1     1 248738 29    <NA>  S       
##  9         900      3 Abrah~ fema~  18       0     0 2657    7.23 <NA>  C       
## 10         901      3 Davie~ male   21       2     0 A/4 4~ 24.2  <NA>  S       
## # ... with 408 more rows, and 9 more variables: WikiId <dbl>, Name_wiki <chr>,
## #   Age_wiki <dbl>, Hometown <chr>, Boarded <chr>, Destination <chr>,
## #   Lifeboat <chr>, Body <chr>, Class <dbl>
```

```r
# same preprocessing as for the training data
test_df <- test %>%
  select(-c(12:20), -Name, -Ticket, -Cabin, -PassengerId) %>%
  drop_na() %>%
  mutate_if(is.character, factor)

# pull the fitted workflow out of the last_fit() result and predict
test_res <- predict(final_res$.workflow[[1]], new_data = test_df)
test_res
```

```
## # A tibble: 331 x 1
##    .pred_class
##    <fct>      
##  1 0          
##  2 0          
##  3 0          
##  4 0          
##  5 0          
##  6 0          
##  7 0          
##  8 0          
##  9 1          
## 10 0          
## # ... with 321 more rows
```
These predictions could be submitted to Kaggle. Note, though, that `drop_na()` dropped test rows with missing values (only 331 of the 418 remain), while a Kaggle submission needs a prediction for every passenger, so the remaining missing values should be imputed rather than dropped.
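As a sketch of how a complete submission file could be built (this is my addition, not part of the original analysis): the fitted recipe already imputes missing `Age`, so only the remaining gaps (e.g. a missing `Fare`) need filling before predicting, and no rows are dropped. The names `submission_df`, `submission`, the output filename, and the median imputation are all my assumptions.

```r
# Sketch: keep every test passenger, fill remaining NAs, predict, write CSV
submission_df <- test %>%
  select(Pclass, Sex, Age, SibSp, Parch, Fare, Embarked) %>%
  mutate(Fare = replace_na(Fare, median(Fare, na.rm = TRUE))) %>%  # assumed: median imputation
  mutate_if(is.character, factor)

submission <- tibble(
  PassengerId = test$PassengerId,
  Survived    = predict(final_res$.workflow[[1]], new_data = submission_df)$.pred_class
)

write_csv(submission, 'titanic_submission.csv')  # assumed filename
```

This produces one prediction per `PassengerId`, which is the format Kaggle's Titanic competition expects.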
