数据探索
library(tidyverse)
## Warning: package 'tidyverse' was built under R version 4.0.5
## -- Attaching packages ----------------------------------------------------------- tidyverse 1.3.1 --
## v ggplot2 3.3.3 v purrr 0.3.4
## v tibble 3.1.1 v dplyr 1.0.6
## v tidyr 1.1.3 v stringr 1.4.0
## v readr 1.4.0 v forcats 0.5.1
## Warning: package 'ggplot2' was built under R version 4.0.5
## Warning: package 'tibble' was built under R version 4.0.5
## Warning: package 'tidyr' was built under R version 4.0.5
## Warning: package 'dplyr' was built under R version 4.0.5
## Warning: package 'forcats' was built under R version 4.0.5
## -- Conflicts -------------------------------------------------------------- tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
# Load the TidyTuesday (2020-02-18) food consumption data; column types are
# guessed by readr, which is why a column specification message is printed.
food_consumption <- readr::read_csv(
  file = "../datasets/tidytuesday/data/2020/2020-02-18/food_consumption.csv"
)
##
## -- Column specification ----------------------------------------------------------------------------
## cols(
## country = col_character(),
## food_category = col_character(),
## consumption = col_double(),
## co2_emmission = col_double()
## )
food_consumption
## # A tibble: 1,430 x 4
## country food_category consumption co2_emmission
## <chr> <chr> <dbl> <dbl>
## 1 Argentina Pork 10.5 37.2
## 2 Argentina Poultry 38.7 41.5
## 3 Argentina Beef 55.5 1712
## 4 Argentina Lamb & Goat 1.56 54.6
## 5 Argentina Fish 4.36 6.96
## 6 Argentina Eggs 11.4 10.5
## 7 Argentina Milk - inc. cheese 195. 278.
## 8 Argentina Wheat and Wheat Products 103. 19.7
## 9 Argentina Rice 8.77 11.2
## 10 Argentina Soybeans 0 0
## # ... with 1,420 more rows
library(countrycode)
## Warning: package 'countrycode' was built under R version 4.0.5
library(janitor)
## Warning: package 'janitor' was built under R version 4.0.5
##
## Attaching package: 'janitor'
## The following objects are masked from 'package:stats':
##
## chisq.test, fisher.test
# Reshape the long consumption table into one row per country with one
# numeric column per food category, then derive the binary outcome `asia`.
food <- food_consumption %>%
  # The CO2 column is dropped so only consumption values are spread wide.
  select(-co2_emmission) %>%
  pivot_wider(
    names_from = food_category,
    values_from = consumption
  ) %>%
  # Turn category labels such as "Lamb & Goat" into snake_case column names.
  clean_names() %>%
  mutate(continent = countrycode(
    country,
    origin = "country.name",
    destination = "continent"
  )) %>%
  # Every row whose continent is not exactly "Asia" falls through to the
  # TRUE branch, so a missing continent is also labelled "Other".
  mutate(asia = case_when(
    continent == "Asia" ~ "Asia",
    TRUE ~ "Other"
  )) %>%
  # Identifier columns are removed so they cannot be used as predictors.
  select(-country, -continent) %>%
  # across() + where() replaces the superseded mutate_if(); at this point
  # `asia` is the only character column, so it becomes the factor outcome.
  mutate(across(where(is.character), factor))
## Warning in FUN(X[[i]], ...): strings not representable in native encoding will
## be translated to UTF-8
## Warning in FUN(X[[i]], ...): unable to translate '<U+00C4>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00D6>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00E4>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00F6>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00DF>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00C6>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00E6>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00D8>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00F8>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00C5>' to native encoding
## Warning in FUN(X[[i]], ...): unable to translate '<U+00E5>' to native encoding
food
## # A tibble: 130 x 12
## pork poultry beef lamb_goat fish eggs milk_inc_cheese wheat_and_wheat_pr~
## <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 10.5 38.7 55.5 1.56 4.36 11.4 195. 103.
## 2 24.1 46.1 33.9 9.87 17.7 8.51 234. 70.5
## 3 10.9 13.2 22.5 15.3 3.85 12.5 304. 139.
## 4 21.7 26.9 13.4 21.1 74.4 8.24 226. 72.9
## 5 22.3 35.0 22.5 18.9 20.4 9.91 137. 76.9
## 6 27.6 50.0 36.2 0.43 12.4 14.6 255. 80.4
## 7 16.8 27.4 29.1 8.23 6.53 13.1 211. 109.
## 8 43.6 21.4 29.9 1.67 23.1 14.6 255. 103.
## 9 12.6 45 39.2 0.62 10.0 8.98 149. 53
## 10 10.4 18.4 23.4 9.56 5.21 8.29 288. 92.3
## # ... with 120 more rows, and 4 more variables: rice <dbl>, soybeans <dbl>,
## # nuts_inc_peanut_butter <dbl>, asia <fct>
library(GGally)
## Warning: package 'GGally' was built under R version 4.0.5
## Registered S3 method overwritten by 'GGally':
## method from
## +.gg ggplot2
# Scatterplot matrix of the 11 food-category columns, coloured by the
# `asia` factor (column 12) to eyeball how separable the two groups are.
ggscatmat(
  data = food,
  columns = 1:11,
  color = "asia",
  alpha = 0.7
)
建立模型
library(tidymodels)
## Warning: package 'tidymodels' was built under R version 4.0.5
## -- Attaching packages ---------------------------------------------------------- tidymodels 0.1.3 --
## v broom 0.7.6 v rsample 0.0.9
## v dials 0.0.9 v tune 0.1.5
## v infer 0.5.4 v workflows 0.2.2
## v modeldata 0.1.0 v workflowsets 0.0.2
## v parsnip 0.1.5 v yardstick 0.0.8
## v recipes 0.1.16
## Warning: package 'broom' was built under R version 4.0.5
## Warning: package 'scales' was built under R version 4.0.5
## Warning: package 'recipes' was built under R version 4.0.5
## Warning: package 'rsample' was built under R version 4.0.5
## Warning: package 'tune' was built under R version 4.0.5
## Warning: package 'workflows' was built under R version 4.0.5
## Warning: package 'workflowsets' was built under R version 4.0.5
## Warning: package 'yardstick' was built under R version 4.0.5
## -- Conflicts ------------------------------------------------------------- tidymodels_conflicts() --
## x scales::discard() masks purrr::discard()
## x dplyr::filter() masks stats::filter()
## x recipes::fixed() masks stringr::fixed()
## x dplyr::lag() masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step() masks stats::step()
## * Use tidymodels_prefer() to resolve common conflicts.
# Fix the RNG state so the same 30 bootstrap resamples are drawn on each run.
set.seed(seed = 1234)
food_boot <- bootstraps(
  data = food,
  times = 30
)
# Print the resample set for inspection.
food_boot
## # Bootstrap sampling
## # A tibble: 30 x 2
## splits id
## <list> <chr>
## 1 <split [130/48]> Bootstrap01
## 2 <split [130/49]> Bootstrap02
## 3 <split [130/49]> Bootstrap03
## 4 <split [130/51]> Bootstrap04
## 5 <split [130/47]> Bootstrap05
## 6 <split [130/51]> Bootstrap06
## 7 <split [130/57]> Bootstrap07
## 8 <split [130/51]> Bootstrap08
## 9 <split [130/44]> Bootstrap09
## 10 <split [130/53]> Bootstrap10
## # ... with 20 more rows
# Random forest classifier with a fixed 1000 trees; `mtry` and `min_n` are
# left as tune() placeholders to be filled in during grid search.
rf_spec <- rand_forest(
  mtry = tune(),
  trees = 1000,
  min_n = tune()
) %>%
  set_mode("classification") %>%
  set_engine("ranger")
# Print the specification for inspection.
rf_spec
## Random Forest Model Specification (classification)
##
## Main Arguments:
## mtry = tune()
## trees = 1000
## min_n = tune()
##
## Computational engine: ranger
调参
# Register a parallel backend with doParallel's default settings
# (presumably so tune can evaluate resamples in parallel; verify locally).
doParallel::registerDoParallel()

# Tune `mtry` and `min_n` over the bootstrap resamples using the formula
# `asia ~ .` (all remaining columns as predictors). `grid = 10` is the
# tune_grid() default, spelled out here: ten candidate combinations.
rf_grid <- tune_grid(
  rf_spec,
  asia ~ .,
  resamples = food_boot,
  grid = 10
)
## i Creating pre-processing data to finalize unknown parameter: mtry
rf_grid
## # Tuning results
## # Bootstrap sampling
## # A tibble: 30 x 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [130/48]> Bootstrap01 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
## 2 <split [130/49]> Bootstrap02 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
## 3 <split [130/49]> Bootstrap03 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
## 4 <split [130/51]> Bootstrap04 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
## 5 <split [130/47]> Bootstrap05 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
## 6 <split [130/51]> Bootstrap06 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
## 7 <split [130/57]> Bootstrap07 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
## 8 <split [130/51]> Bootstrap08 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
## 9 <split [130/44]> Bootstrap09 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
## 10 <split [130/53]> Bootstrap10 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
## # ... with 20 more rows
评价模型
# Summarise each candidate's metrics averaged across the resamples.
collect_metrics(rf_grid)
## # A tibble: 20 x 8
## mtry min_n .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 11 15 accuracy binary 0.812 30 0.0113 Preprocessor1_Model01
## 2 11 15 roc_auc binary 0.823 30 0.0106 Preprocessor1_Model01
## 3 4 33 accuracy binary 0.813 30 0.00910 Preprocessor1_Model02
## 4 4 33 roc_auc binary 0.821 30 0.00995 Preprocessor1_Model02
## 5 5 31 accuracy binary 0.816 30 0.00837 Preprocessor1_Model03
## 6 5 31 roc_auc binary 0.820 30 0.0103 Preprocessor1_Model03
## 7 4 37 accuracy binary 0.817 30 0.00863 Preprocessor1_Model04
## 8 4 37 roc_auc binary 0.819 30 0.0105 Preprocessor1_Model04
## 9 6 9 accuracy binary 0.825 30 0.00908 Preprocessor1_Model05
## 10 6 9 roc_auc binary 0.833 30 0.00922 Preprocessor1_Model05
## 11 2 4 accuracy binary 0.830 30 0.00816 Preprocessor1_Model06
## 12 2 4 roc_auc binary 0.844 30 0.00975 Preprocessor1_Model06
## 13 2 12 accuracy binary 0.830 30 0.00774 Preprocessor1_Model07
## 14 2 12 roc_auc binary 0.836 30 0.00916 Preprocessor1_Model07
## 15 7 21 accuracy binary 0.816 30 0.00911 Preprocessor1_Model08
## 16 7 21 roc_auc binary 0.824 30 0.0102 Preprocessor1_Model08
## 17 8 18 accuracy binary 0.815 30 0.0102 Preprocessor1_Model09
## 18 8 18 roc_auc binary 0.825 30 0.0102 Preprocessor1_Model09
## 19 9 26 accuracy binary 0.813 30 0.00961 Preprocessor1_Model10
## 20 9 26 roc_auc binary 0.821 30 0.0108 Preprocessor1_Model10
# Show the top candidates ranked by ROC AUC. `metric` is passed by name
# because newer tune releases deprecate supplying it positionally.
rf_grid %>% show_best(metric = "roc_auc")
## # A tibble: 5 x 8
## mtry min_n .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 2 4 roc_auc binary 0.844 30 0.00975 Preprocessor1_Model06
## 2 2 12 roc_auc binary 0.836 30 0.00916 Preprocessor1_Model07
## 3 6 9 roc_auc binary 0.833 30 0.00922 Preprocessor1_Model05
## 4 8 18 roc_auc binary 0.825 30 0.0102 Preprocessor1_Model09
## 5 7 21 roc_auc binary 0.824 30 0.0102 Preprocessor1_Model08