# 数据探索 ----
# Load the tidyverse and read the TidyTuesday (2020-02-18) food consumption
# data. Lines starting with `##` are the console output recorded when this
# chunk was rendered; they are kept as comments so the code stays runnable.
library(tidyverse)
## Warning: package 'tidyverse' was built under R version 4.0.5
## -- Attaching packages ----------------------------------------------------------- tidyverse 1.3.1 --
## v ggplot2 3.3.3     v purrr   0.3.4
## v tibble  3.1.1     v dplyr   1.0.6
## v tidyr   1.1.3     v stringr 1.4.0
## v readr   1.4.0     v forcats 0.5.1
## Warning: package 'ggplot2' was built under R version 4.0.5
## Warning: package 'tibble' was built under R version 4.0.5
## Warning: package 'tidyr' was built under R version 4.0.5
## Warning: package 'dplyr' was built under R version 4.0.5
## Warning: package 'forcats' was built under R version 4.0.5
## -- Conflicts -------------------------------------------------------------- tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()

# The path is relative, so it resolves against the current working directory.
food_consumption <- readr::read_csv(
  "../datasets/tidytuesday/data/2020/2020-02-18/food_consumption.csv"
)
##
## -- Column specification ----------------------------------------------------------------------------
## cols(
##   country = col_character(),
##   food_category = col_character(),
##   consumption = col_double(),
##   co2_emmission = col_double()
## )

# Preview the long-format data: one row per country / food category pair.
food_consumption
## # A tibble: 1,430 x 4
##    country   food_category            consumption co2_emmission
##    <chr>     <chr>                          <dbl>         <dbl>
##  1 Argentina Pork                           10.5          37.2
##  2 Argentina Poultry                        38.7          41.5
##  3 Argentina Beef                           55.5        1712
##  4 Argentina Lamb & Goat                     1.56         54.6
##  5 Argentina Fish                            4.36          6.96
##  6 Argentina Eggs                           11.4          10.5
##  7 Argentina Milk - inc. cheese            195.          278.
##  8 Argentina Wheat and Wheat Products      103.           19.7
##  9 Argentina Rice                            8.77         11.2
## 10 Argentina Soybeans                        0             0
## # ... with 1,420 more rows
# Build the modelling table `food`: one row per country, one numeric column
# per food category, plus a two-level factor `asia` ("Asia" / "Other").
# Lines starting with `##` are console output recorded at render time.
library(countrycode)
## Warning: package 'countrycode' was built under R version 4.0.5
library(janitor)
## Warning: package 'janitor' was built under R version 4.0.5
##
## Attaching package: 'janitor'
## The following objects are masked from 'package:stats':
##
##     chisq.test, fisher.test

food <- food_consumption %>%
  # Only consumption is modelled; drop the CO2 column before reshaping.
  select(-co2_emmission) %>%
  pivot_wider(
    names_from = food_category,
    values_from = consumption
  ) %>%
  # Turn labels such as "Lamb & Goat" into snake_case column names.
  clean_names() %>%
  mutate(continent = countrycode(
    country,
    origin = "country.name",
    destination = "continent"
  )) %>%
  # Everything that is not matched to "Asia" (including a missing continent)
  # falls through to "Other".
  mutate(asia = case_when(
    continent == "Asia" ~ "Asia",
    TRUE ~ "Other"
  )) %>%
  select(-country, -continent) %>%
  # across() + where() replaces the superseded mutate_if().
  mutate(across(where(is.character), factor))
## Warning in FUN(X[[i]], ...): strings not representable in native encoding will
## be translated to UTF-8
## Warning in FUN(X[[i]], ...): unable to translate '<U+00C4>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00D6>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00E4>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00F6>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00DF>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00C6>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00E6>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00D8>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00F8>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00C5>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00E5>' to native encoding

food
## # A tibble: 130 x 12
##     pork poultry  beef lamb_goat  fish  eggs milk_inc_cheese wheat_and_wheat_pr~
##    <dbl>   <dbl> <dbl>     <dbl> <dbl> <dbl>           <dbl>               <dbl>
##  1  10.5    38.7  55.5      1.56  4.36 11.4             195.               103.
##  2  24.1    46.1  33.9      9.87 17.7   8.51            234.                70.5
##  3  10.9    13.2  22.5     15.3   3.85 12.5             304.               139.
##  4  21.7    26.9  13.4     21.1  74.4   8.24            226.                72.9
##  5  22.3    35.0  22.5     18.9  20.4   9.91            137.                76.9
##  6  27.6    50.0  36.2      0.43 12.4  14.6             255.                80.4
##  7  16.8    27.4  29.1      8.23  6.53 13.1             211.               109.
##  8  43.6    21.4  29.9      1.67 23.1  14.6             255.               103.
##  9  12.6    45    39.2      0.62 10.0   8.98            149.                53
## 10  10.4    18.4  23.4      9.56  5.21  8.29            288.                92.3
## # ... with 120 more rows, and 4 more variables: rice <dbl>, soybeans <dbl>,
## #   nuts_inc_peanut_butter <dbl>, asia <fct>
# Scatterplot matrix of the food columns, coloured by the `asia` factor.
# Lines starting with `##` are console output recorded at render time.
library(GGally)
## Warning: package 'GGally' was built under R version 4.0.5
## Registered S3 method overwritten by 'GGally':
##   method from
##   +.gg   ggplot2

# Columns 1:11 are the numeric food categories; column 12 is `asia`.
ggscatmat(food, columns = 1:11, color = "asia", alpha = 0.7)
# 建立模型 ----
# Load tidymodels and draw 30 bootstrap resamples of `food`.
# Lines starting with `##` are console output recorded at render time.
library(tidymodels)
## Warning: package 'tidymodels' was built under R version 4.0.5
## -- Attaching packages ---------------------------------------------------------- tidymodels 0.1.3 --
## v broom        0.7.6      v rsample      0.0.9
## v dials        0.0.9      v tune         0.1.5
## v infer        0.5.4      v workflows    0.2.2
## v modeldata    0.1.0      v workflowsets 0.0.2
## v parsnip      0.1.5      v yardstick    0.0.8
## v recipes      0.1.16
## Warning: package 'broom' was built under R version 4.0.5
## Warning: package 'scales' was built under R version 4.0.5
## Warning: package 'recipes' was built under R version 4.0.5
## Warning: package 'rsample' was built under R version 4.0.5
## Warning: package 'tune' was built under R version 4.0.5
## Warning: package 'workflows' was built under R version 4.0.5
## Warning: package 'workflowsets' was built under R version 4.0.5
## Warning: package 'yardstick' was built under R version 4.0.5
## -- Conflicts ------------------------------------------------------------- tidymodels_conflicts() --
## x scales::discard() masks purrr::discard()
## x dplyr::filter()   masks stats::filter()
## x recipes::fixed()  masks stringr::fixed()
## x dplyr::lag()      masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step()   masks stats::step()
## * Use tidymodels_prefer() to resolve common conflicts.

# Fix the RNG seed so the same bootstrap splits are drawn on every run.
set.seed(1234)
food_boot <- bootstraps(food, times = 30)

food_boot
## # Bootstrap sampling
## # A tibble: 30 x 2
##    splits           id
##    <list>           <chr>
##  1 <split [130/48]> Bootstrap01
##  2 <split [130/49]> Bootstrap02
##  3 <split [130/49]> Bootstrap03
##  4 <split [130/51]> Bootstrap04
##  5 <split [130/47]> Bootstrap05
##  6 <split [130/51]> Bootstrap06
##  7 <split [130/57]> Bootstrap07
##  8 <split [130/51]> Bootstrap08
##  9 <split [130/44]> Bootstrap09
## 10 <split [130/53]> Bootstrap10
## # ... with 20 more rows
# Random-forest classification spec on the ranger engine. `mtry` and `min_n`
# are placeholders to be tuned; the number of trees is fixed at 1000.
# Lines starting with `##` are console output recorded at render time.
rf_spec <- rand_forest(
  mode = "classification",
  mtry = tune(),
  trees = 1000,
  min_n = tune()
) %>%
  set_engine("ranger")

rf_spec
## Random Forest Model Specification (classification)
##
## Main Arguments:
##   mtry = tune()
##   trees = 1000
##   min_n = tune()
##
## Computational engine: ranger
# 调参 ----
# Tune `mtry` and `min_n` of `rf_spec` over the bootstrap resamples.
# Lines starting with `##` are console output recorded at render time.

# Register doParallel's default parallel backend before tuning.
doParallel::registerDoParallel()

# No `grid` is supplied, so tune_grid() chooses the candidate values itself;
# the recorded metrics show 10 candidate (mtry, min_n) pairs.
rf_grid <- tune_grid(
  rf_spec,
  asia ~ .,
  resamples = food_boot
)
## i Creating pre-processing data to finalize unknown parameter: mtry

rf_grid
## # Tuning results
## # Bootstrap sampling
## # A tibble: 30 x 4
##    splits           id          .metrics              .notes
##    <list>           <chr>       <list>                <list>
##  1 <split [130/48]> Bootstrap01 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
##  2 <split [130/49]> Bootstrap02 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
##  3 <split [130/49]> Bootstrap03 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
##  4 <split [130/51]> Bootstrap04 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
##  5 <split [130/47]> Bootstrap05 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
##  6 <split [130/51]> Bootstrap06 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
##  7 <split [130/57]> Bootstrap07 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
##  8 <split [130/51]> Bootstrap08 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
##  9 <split [130/44]> Bootstrap09 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
## 10 <split [130/53]> Bootstrap10 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
## # ... with 20 more rows
# 评价模型 ----
# Summarise the tuning results held in `rf_grid`.
# Lines starting with `##` are console output recorded at render time.

# Accuracy and ROC AUC for every candidate, averaged over the 30 resamples.
rf_grid %>%
  collect_metrics()
## # A tibble: 20 x 8
##     mtry min_n .metric  .estimator  mean     n std_err .config
##    <int> <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>
##  1    11    15 accuracy binary     0.812    30 0.0113  Preprocessor1_Model01
##  2    11    15 roc_auc  binary     0.823    30 0.0106  Preprocessor1_Model01
##  3     4    33 accuracy binary     0.813    30 0.00910 Preprocessor1_Model02
##  4     4    33 roc_auc  binary     0.821    30 0.00995 Preprocessor1_Model02
##  5     5    31 accuracy binary     0.816    30 0.00837 Preprocessor1_Model03
##  6     5    31 roc_auc  binary     0.820    30 0.0103  Preprocessor1_Model03
##  7     4    37 accuracy binary     0.817    30 0.00863 Preprocessor1_Model04
##  8     4    37 roc_auc  binary     0.819    30 0.0105  Preprocessor1_Model04
##  9     6     9 accuracy binary     0.825    30 0.00908 Preprocessor1_Model05
## 10     6     9 roc_auc  binary     0.833    30 0.00922 Preprocessor1_Model05
## 11     2     4 accuracy binary     0.830    30 0.00816 Preprocessor1_Model06
## 12     2     4 roc_auc  binary     0.844    30 0.00975 Preprocessor1_Model06
## 13     2    12 accuracy binary     0.830    30 0.00774 Preprocessor1_Model07
## 14     2    12 roc_auc  binary     0.836    30 0.00916 Preprocessor1_Model07
## 15     7    21 accuracy binary     0.816    30 0.00911 Preprocessor1_Model08
## 16     7    21 roc_auc  binary     0.824    30 0.0102  Preprocessor1_Model08
## 17     8    18 accuracy binary     0.815    30 0.0102  Preprocessor1_Model09
## 18     8    18 roc_auc  binary     0.825    30 0.0102  Preprocessor1_Model09
## 19     9    26 accuracy binary     0.813    30 0.00961 Preprocessor1_Model10
## 20     9    26 roc_auc  binary     0.821    30 0.0108  Preprocessor1_Model10

# Best candidates ranked by mean ROC AUC. `metric` is passed by name because
# later tune releases no longer accept it positionally.
rf_grid %>%
  show_best(metric = "roc_auc")
## # A tibble: 5 x 8
##    mtry min_n .metric .estimator  mean     n std_err .config
##   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>
## 1     2     4 roc_auc binary     0.844    30 0.00975 Preprocessor1_Model06
## 2     2    12 roc_auc binary     0.836    30 0.00916 Preprocessor1_Model07
## 3     6     9 roc_auc binary     0.833    30 0.00922 Preprocessor1_Model05
## 4     8    18 roc_auc binary     0.825    30 0.0102  Preprocessor1_Model09
## 5     7    21 roc_auc binary     0.824    30 0.0102  Preprocessor1_Model08