title: "tidymodels-exercise-08"
author: "liyue"
date: "2021/7/31"
output: html_document
Using tidymodels to predict survival on the famous Titanic dataset!
Data exploration
First, load the data.
rm(list = ls())
library(tidyverse)
## -- Attaching packages ------------------------------------------ tidyverse 1.3.1 --
## v ggplot2 3.3.5 v purrr 0.3.4
## v tibble 3.1.2 v dplyr 1.0.7
## v tidyr 1.1.3 v stringr 1.4.0
## v readr 1.4.0 v forcats 0.5.1
## -- Conflicts --------------------------------------------- tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
titanic <- read_csv('../datasets/titanic.train.csv')
##
## -- Column specification -----------------------------------------------------------
## cols(
## .default = col_character(),
## PassengerId = col_double(),
## Survived = col_double(),
## Pclass = col_double(),
## Age = col_double(),
## SibSp = col_double(),
## Parch = col_double(),
## Fare = col_double(),
## WikiId = col_double(),
## Age_wiki = col_double(),
## Class = col_double()
## )
## i Use `spec()` for the full column specifications.
titanic
## # A tibble: 891 x 21
## PassengerId Survived Pclass Name Sex Age SibSp Parch Ticket Fare Cabin
## <dbl> <dbl> <dbl> <chr> <chr> <dbl> <dbl> <dbl> <chr> <dbl> <chr>
## 1 1 0 3 Braun~ male 22 1 0 A/5 2~ 7.25 <NA>
## 2 2 1 1 Cumin~ fema~ 38 1 0 PC 17~ 71.3 C85
## 3 3 1 3 Heikk~ fema~ 26 0 0 STON/~ 7.92 <NA>
## 4 4 1 1 Futre~ fema~ 35 1 0 113803 53.1 C123
## 5 5 0 3 Allen~ male 35 0 0 373450 8.05 <NA>
## 6 6 0 3 Moran~ male NA 0 0 330877 8.46 <NA>
## 7 7 0 1 McCar~ male 54 0 0 17463 51.9 E46
## 8 8 0 3 Palss~ male 2 3 1 349909 21.1 <NA>
## 9 9 1 3 Johns~ fema~ 27 0 2 347742 11.1 <NA>
## 10 10 1 2 Nasse~ fema~ 14 1 0 237736 30.1 <NA>
## # ... with 881 more rows, and 10 more variables: Embarked <chr>, WikiId <dbl>,
## # Name_wiki <chr>, Age_wiki <dbl>, Hometown <chr>, Boarded <chr>,
## # Destination <chr>, Lifeboat <chr>, Body <chr>, Class <dbl>
Take a look at the data:
skimr::skim(titanic)
Table: Data summary
Name | titanic |
Number of rows | 891 |
Number of columns | 21 |
_ | |
Column type frequency: | |
character | 11 |
numeric | 10 |
__ | |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
Name | 0 | 1.00 | 12 | 82 | 0 | 891 | 0 |
Sex | 0 | 1.00 | 4 | 6 | 0 | 2 | 0 |
Ticket | 0 | 1.00 | 3 | 18 | 0 | 681 | 0 |
Cabin | 687 | 0.23 | 1 | 15 | 0 | 147 | 0 |
Embarked | 2 | 1.00 | 1 | 1 | 0 | 3 | 0 |
Name_wiki | 2 | 1.00 | 12 | 69 | 0 | 889 | 0 |
Hometown | 2 | 1.00 | 6 | 44 | 0 | 437 | 0 |
Boarded | 2 | 1.00 | 7 | 11 | 0 | 4 | 0 |
Destination | 2 | 1.00 | 10 | 39 | 0 | 234 | 0 |
Lifeboat | 546 | 0.39 | 1 | 3 | 0 | 22 | 0 |
Body | 804 | 0.10 | 3 | 8 | 0 | 87 | 0 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
PassengerId | 0 | 1.0 | 446.00 | 257.35 | 1.00 | 223.50 | 446.00 | 668.5 | 891.00 | ▇▇▇▇▇ |
Survived | 0 | 1.0 | 0.38 | 0.49 | 0.00 | 0.00 | 0.00 | 1.0 | 1.00 | ▇▁▁▁▅ |
Pclass | 0 | 1.0 | 2.31 | 0.84 | 1.00 | 2.00 | 3.00 | 3.0 | 3.00 | ▃▁▃▁▇ |
Age | 177 | 0.8 | 29.70 | 14.53 | 0.42 | 20.12 | 28.00 | 38.0 | 80.00 | ▂▇▅▂▁ |
SibSp | 0 | 1.0 | 0.52 | 1.10 | 0.00 | 0.00 | 0.00 | 1.0 | 8.00 | ▇▁▁▁▁ |
Parch | 0 | 1.0 | 0.38 | 0.81 | 0.00 | 0.00 | 0.00 | 0.0 | 6.00 | ▇▁▁▁▁ |
Fare | 0 | 1.0 | 32.20 | 49.69 | 0.00 | 7.91 | 14.45 | 31.0 | 512.33 | ▇▁▁▁▁ |
WikiId | 2 | 1.0 | 665.47 | 380.80 | 1.00 | 336.00 | 672.00 | 996.0 | 1314.00 | ▇▇▇▇▇ |
Age_wiki | 4 | 1.0 | 29.32 | 13.93 | 0.42 | 20.00 | 28.00 | 38.0 | 74.00 | ▂▇▅▂▁ |
Class | 2 | 1.0 | 2.31 | 0.84 | 1.00 | 2.00 | 3.00 | 3.0 | 3.00 | ▃▁▃▁▇ |
The data has 891 rows and 21 variables, with many missing values and many character variables, so some preprocessing is needed. The main steps are:
- handle missing values;
- convert character variables to factors;
- drop variables with little bearing on the outcome.
The code:
titanic_df <- titanic %>%
select(-c(13:21),-Name,-Ticket,-Cabin,-PassengerId) %>%
drop_na(Embarked) %>%
mutate_if(is.character, factor) %>%
mutate(Survived = factor(Survived))
skimr::skim(titanic_df)
Table: Data summary
Name | titanic_df |
Number of rows | 889 |
Number of columns | 8 |
_ | |
Column type frequency: | |
factor | 3 |
numeric | 5 |
__ | |
Group variables | None |
Variable type: factor
skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
---|---|---|---|---|---|
Survived | 0 | 1 | FALSE | 2 | 0: 549, 1: 340 |
Sex | 0 | 1 | FALSE | 2 | mal: 577, fem: 312 |
Embarked | 0 | 1 | FALSE | 3 | S: 644, C: 168, Q: 77 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
Pclass | 0 | 1.0 | 2.31 | 0.83 | 1.00 | 2.0 | 3.00 | 3 | 3.00 | ▃▁▃▁▇ |
Age | 177 | 0.8 | 29.64 | 14.49 | 0.42 | 20.0 | 28.00 | 38 | 80.00 | ▂▇▅▂▁ |
SibSp | 0 | 1.0 | 0.52 | 1.10 | 0.00 | 0.0 | 0.00 | 1 | 8.00 | ▇▁▁▁▁ |
Parch | 0 | 1.0 | 0.38 | 0.81 | 0.00 | 0.0 | 0.00 | 0 | 6.00 | ▇▁▁▁▁ |
Fare | 0 | 1.0 | 32.10 | 49.70 | 0.00 | 7.9 | 14.45 | 31 | 512.33 | ▇▁▁▁▁ |
This leaves 7 predictor variables and 1 outcome variable. The outcome is a factor with two levels, survived or died; the predictors are factors and numerics. Many loosely related variables were simply dropped (admittedly, this is not a rigorous way to do it).
A quick further look at the data:
# survival by sex
ggplot(titanic_df, aes(x=Sex, fill=Survived))+
geom_bar()+
theme_minimal()+
labs(x='',y='number')
The plot shows that most males did not survive, while most females did.
# survival by port of embarkation
ggplot(titanic_df, aes(x=Embarked, fill=Survived))+
geom_bar()+
theme_minimal()+
labs(x='',y='number')
Modeling
library(tidymodels)
## Registered S3 method overwritten by 'tune':
## method from
## required_pkgs.model_spec parsnip
## -- Attaching packages ----------------------------------------- tidymodels 0.1.3 --
## v broom 0.7.8 v rsample 0.1.0
## v dials 0.0.9 v tune 0.1.5
## v infer 0.5.4 v workflows 0.2.2
## v modeldata 0.1.0 v workflowsets 0.0.2
## v parsnip 0.1.6 v yardstick 0.0.8
## v recipes 0.1.16
## -- Conflicts -------------------------------------------- tidymodels_conflicts() --
## x scales::discard() masks purrr::discard()
## x dplyr::filter() masks stats::filter()
## x recipes::fixed() masks stringr::fixed()
## x dplyr::lag() masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step() masks stats::step()
## * Use tidymodels_prefer() to resolve common conflicts.
tidymodels_prefer()
titanic_split <- initial_split(titanic_df)
titanic_train <- training(titanic_split)
titanic_test <- testing(titanic_split)
titanic_boot <- bootstraps(titanic_train)
#library(usemodels)
#use_ranger(Survived ~ ., data = titanic_train)
ranger_rec <- recipe(Survived ~ ., data = titanic_train) %>%
step_impute_mean(Age)
ranger_spec <- rand_forest(mode = "classification", mtry = tune(),
trees = 1000,
min_n = tune()) %>%
set_engine('ranger')
ranger_wf <- workflow() %>%
add_recipe(ranger_rec) %>%
add_model(ranger_spec)
set.seed(123)
doParallel::registerDoParallel()
ranger_tune <- tune_grid(
ranger_wf,
titanic_boot,
grid = 10,
control = control_grid(verbose = TRUE,save_pred = TRUE,
parallel_over = 'everything')
)
## i Creating pre-processing data to finalize unknown parameter: mtry
ranger_tune
## # Tuning results
## # Bootstrap sampling
## # A tibble: 25 x 5
## splits id .metrics .notes .predictions
## <list> <chr> <list> <list> <list>
## 1 <split [666/25~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,500 x ~
## 2 <split [666/24~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,480 x ~
## 3 <split [666/24~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,470 x ~
## 4 <split [666/24~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,490 x ~
## 5 <split [666/26~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,600 x ~
## 6 <split [666/23~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,340 x ~
## 7 <split [666/24~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,460 x ~
## 8 <split [666/25~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,590 x ~
## 9 <split [666/23~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,340 x ~
## 10 <split [666/24~ Bootstrap~ <tibble [20 x 6~ <tibble [0 x ~ <tibble [2,490 x ~
## # ... with 15 more rows
Let's see the results:
collect_metrics(ranger_tune)
## # A tibble: 20 x 8
## mtry min_n .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 2 30 accuracy binary 0.817 25 0.00393 Preprocessor1_Model01
## 2 2 30 roc_auc binary 0.857 25 0.00468 Preprocessor1_Model01
## 3 6 3 accuracy binary 0.806 25 0.00408 Preprocessor1_Model02
## 4 6 3 roc_auc binary 0.845 25 0.00510 Preprocessor1_Model02
## 5 2 39 accuracy binary 0.816 25 0.00383 Preprocessor1_Model03
## 6 2 39 roc_auc binary 0.857 25 0.00462 Preprocessor1_Model03
## 7 3 27 accuracy binary 0.818 25 0.00386 Preprocessor1_Model04
## 8 3 27 roc_auc binary 0.855 25 0.00478 Preprocessor1_Model04
## 9 6 36 accuracy binary 0.813 25 0.00424 Preprocessor1_Model05
## 10 6 36 roc_auc binary 0.852 25 0.00469 Preprocessor1_Model05
## 11 3 20 accuracy binary 0.822 25 0.00434 Preprocessor1_Model06
## 12 3 20 roc_auc binary 0.854 25 0.00488 Preprocessor1_Model06
## 13 4 24 accuracy binary 0.817 25 0.00428 Preprocessor1_Model07
## 14 4 24 roc_auc binary 0.853 25 0.00475 Preprocessor1_Model07
## 15 6 12 accuracy binary 0.818 25 0.00366 Preprocessor1_Model08
## 16 6 12 roc_auc binary 0.850 25 0.00496 Preprocessor1_Model08
## 17 5 15 accuracy binary 0.819 25 0.00425 Preprocessor1_Model09
## 18 5 15 roc_auc binary 0.852 25 0.00483 Preprocessor1_Model09
## 19 1 7 accuracy binary 0.812 25 0.00447 Preprocessor1_Model10
## 20 1 7 roc_auc binary 0.856 25 0.00437 Preprocessor1_Model10
autoplot(ranger_tune)
Based on these results, build a regular grid over the most promising parameter ranges:
ranger_grid <- grid_regular(
mtry(range = c(1,3)),
min_n(range = c(5,10)),
levels = 5
)
ranger_grid
## # A tibble: 15 x 2
## mtry min_n
## <int> <int>
## 1 1 5
## 2 2 5
## 3 3 5
## 4 1 6
## 5 2 6
## 6 3 6
## 7 1 7
## 8 2 7
## 9 3 7
## 10 1 8
## 11 2 8
## 12 3 8
## 13 1 10
## 14 2 10
## 15 3 10
Tune again over this grid:
set.seed(1234)
doParallel::registerDoParallel()
regular_res <- tune_grid(
ranger_wf,
titanic_boot,
grid = ranger_grid,
control = control_grid(verbose = TRUE,save_pred = TRUE,
parallel_over = 'everything')
)
regular_res
## # Tuning results
## # Bootstrap sampling
## # A tibble: 25 x 5
## splits id .metrics .notes .predictions
## <list> <chr> <list> <list> <list>
## 1 <split [666/25~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,750 x ~
## 2 <split [666/24~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,720 x ~
## 3 <split [666/24~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,705 x ~
## 4 <split [666/24~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,735 x ~
## 5 <split [666/26~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,900 x ~
## 6 <split [666/23~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,510 x ~
## 7 <split [666/24~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,690 x ~
## 8 <split [666/25~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,885 x ~
## 9 <split [666/23~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,510 x ~
## 10 <split [666/24~ Bootstrap~ <tibble [30 x 6~ <tibble [0 x ~ <tibble [3,735 x ~
## # ... with 15 more rows
Check the results again:
collect_metrics(regular_res)
## # A tibble: 30 x 8
## mtry min_n .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1 5 accuracy binary 0.810 25 0.00406 Preprocessor1_Model01
## 2 1 5 roc_auc binary 0.856 25 0.00437 Preprocessor1_Model01
## 3 2 5 accuracy binary 0.823 25 0.00415 Preprocessor1_Model02
## 4 2 5 roc_auc binary 0.856 25 0.00491 Preprocessor1_Model02
## 5 3 5 accuracy binary 0.818 25 0.00440 Preprocessor1_Model03
## 6 3 5 roc_auc binary 0.851 25 0.00504 Preprocessor1_Model03
## 7 1 6 accuracy binary 0.811 25 0.00408 Preprocessor1_Model04
## 8 1 6 roc_auc binary 0.856 25 0.00435 Preprocessor1_Model04
## 9 2 6 accuracy binary 0.824 25 0.00416 Preprocessor1_Model05
## 10 2 6 roc_auc binary 0.856 25 0.00489 Preprocessor1_Model05
## # ... with 20 more rows
autoplot(regular_res)
show_best(regular_res,metric = 'roc_auc')
## # A tibble: 5 x 8
## mtry min_n .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1 5 roc_auc binary 0.856 25 0.00437 Preprocessor1_Model01
## 2 2 6 roc_auc binary 0.856 25 0.00489 Preprocessor1_Model05
## 3 2 8 roc_auc binary 0.856 25 0.00491 Preprocessor1_Model11
## 4 2 5 roc_auc binary 0.856 25 0.00491 Preprocessor1_Model02
## 5 2 10 roc_auc binary 0.856 25 0.00491 Preprocessor1_Model14
select_best(regular_res,metric = 'roc_auc')
## # A tibble: 1 x 3
## mtry min_n .config
## <int> <int> <chr>
## 1 1 5 Preprocessor1_Model01
Look at the ROC curve:
collect_predictions(regular_res) %>% roc_auc(Survived,.pred_0)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 roc_auc binary 0.848
collect_predictions(regular_res) %>% roc_curve(Survived,.pred_0) %>%
ggplot(aes(x = 1-specificity, y=sensitivity))+
geom_line(size=1.5)+
  geom_abline(lty = 2, size = 1.2, color = 'gray80') +
theme_minimal()
Evaluating the final model
final_wf <- finalize_workflow(ranger_wf, select_best(regular_res, metric = 'roc_auc'))
final_wf
## == Workflow =======================================================================
## Preprocessor: Recipe
## Model: rand_forest()
##
## -- Preprocessor -------------------------------------------------------------------
## 1 Recipe Step
##
## * step_impute_mean()
##
## -- Model --------------------------------------------------------------------------
## Random Forest Model Specification (classification)
##
## Main Arguments:
## mtry = 1
## trees = 1000
## min_n = 5
##
## Computational engine: ranger
final_res <- last_fit(final_wf, titanic_split)
collect_metrics(final_res)
## # A tibble: 2 x 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.825 Preprocessor1_Model1
## 2 roc_auc binary 0.884 Preprocessor1_Model1
collect_predictions(final_res) %>% roc_curve(Survived,.pred_0) %>% autoplot()
collect_predictions(final_res) %>% roc_auc(Survived, .pred_0)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 roc_auc binary 0.884
collect_predictions(final_res) %>% roc_curve(Survived,.pred_0) %>%
ggplot(aes(x = 1-specificity, y=sensitivity))+
geom_line(size=1.5)+
  geom_abline(lty = 2, size = 1.2, color = 'gray80') +
theme_minimal()
Predicting new data
test <- read_csv('../datasets/titanic.test.csv')
##
## -- Column specification -----------------------------------------------------------
## cols(
## .default = col_character(),
## PassengerId = col_double(),
## Pclass = col_double(),
## Age = col_double(),
## SibSp = col_double(),
## Parch = col_double(),
## Fare = col_double(),
## WikiId = col_double(),
## Age_wiki = col_double(),
## Class = col_double()
## )
## i Use `spec()` for the full column specifications.
test
## # A tibble: 418 x 20
## PassengerId Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked
## <dbl> <dbl> <chr> <chr> <dbl> <dbl> <dbl> <chr> <dbl> <chr> <chr>
## 1 892 3 Kelly~ male 34.5 0 0 330911 7.83 <NA> Q
## 2 893 3 Wilke~ fema~ 47 1 0 363272 7 <NA> S
## 3 894 2 Myles~ male 62 0 0 240276 9.69 <NA> Q
## 4 895 3 Wirz,~ male 27 0 0 315154 8.66 <NA> S
## 5 896 3 Hirvo~ fema~ 22 1 1 31012~ 12.3 <NA> S
## 6 897 3 Svens~ male 14 0 0 7538 9.22 <NA> S
## 7 898 3 Conno~ fema~ 30 0 0 330972 7.63 <NA> Q
## 8 899 2 Caldw~ male 26 1 1 248738 29 <NA> S
## 9 900 3 Abrah~ fema~ 18 0 0 2657 7.23 <NA> C
## 10 901 3 Davie~ male 21 2 0 A/4 4~ 24.2 <NA> S
## # ... with 408 more rows, and 9 more variables: WikiId <dbl>, Name_wiki <chr>,
## # Age_wiki <dbl>, Hometown <chr>, Boarded <chr>, Destination <chr>,
## # Lifeboat <chr>, Body <chr>, Class <dbl>
test_df <- test %>%
select(-c(12:20),-Name,-Ticket,-Cabin,-PassengerId) %>%
drop_na() %>%
mutate_if(is.character, factor)
test_res <- predict(final_res$.workflow[[1]],new_data = test_df)
test_res
## # A tibble: 331 x 1
## .pred_class
## <fct>
## 1 0
## 2 0
## 3 0
## 4 0
## 5 0
## 6 0
## 7 0
## 8 0
## 9 1
## 10 0
## # ... with 321 more rows
These predictions could be submitted to Kaggle.
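Note that `drop_na()` above removed 87 of the 418 test passengers, even though the recipe's `step_impute_mean(Age)` already handles missing `Age` at predict time, so a submission built from `test_df` would be incomplete. Below is a minimal sketch of building a full 418-row submission file; the file name `titanic_submission.csv` and the median-fill for the single missing `Fare` are my own assumptions, not part of the original analysis.

```r
# Sketch: build a complete Kaggle submission (418 rows).
# Assumes `test` and `final_res` exist as above. The recipe imputes
# missing Age, so only the remaining NAs (e.g. one missing Fare in the
# Kaggle test set) need filling by hand.
submission <- test %>%
  select(PassengerId, Pclass, Sex, Age, SibSp, Parch, Fare, Embarked) %>%
  mutate(Fare = replace_na(Fare, median(Fare, na.rm = TRUE))) %>%
  mutate_if(is.character, factor)

predict(final_res$.workflow[[1]], new_data = submission) %>%
  bind_cols(submission %>% select(PassengerId)) %>%
  transmute(PassengerId, Survived = .pred_class) %>%
  write_csv('titanic_submission.csv')  # hypothetical file name
```

Extra columns such as `PassengerId` in `new_data` are fine here: the workflow only forges the predictors the recipe knows about.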