数据探索

  1. library(tidyverse)
  2. ## Warning: package 'tidyverse' was built under R version 4.0.5
  3. ## -- Attaching packages ----------------------------------------------------------- tidyverse 1.3.1 --
  4. ## v ggplot2 3.3.3 v purrr 0.3.4
  5. ## v tibble 3.1.1 v dplyr 1.0.6
  6. ## v tidyr 1.1.3 v stringr 1.4.0
  7. ## v readr 1.4.0 v forcats 0.5.1
  8. ## Warning: package 'ggplot2' was built under R version 4.0.5
  9. ## Warning: package 'tibble' was built under R version 4.0.5
  10. ## Warning: package 'tidyr' was built under R version 4.0.5
  11. ## Warning: package 'dplyr' was built under R version 4.0.5
  12. ## Warning: package 'forcats' was built under R version 4.0.5
  13. ## -- Conflicts -------------------------------------------------------------- tidyverse_conflicts() --
  14. ## x dplyr::filter() masks stats::filter()
  15. ## x dplyr::lag() masks stats::lag()
  16. food_consumption <- readr::read_csv(file = "../datasets/tidytuesday/data/2020/2020-02-18/food_consumption.csv")
  17. ##
  18. ## -- Column specification ----------------------------------------------------------------------------
  19. ## cols(
  20. ## country = col_character(),
  21. ## food_category = col_character(),
  22. ## consumption = col_double(),
  23. ## co2_emmission = col_double()
  24. ## )
  25. food_consumption
  26. ## # A tibble: 1,430 x 4
  27. ## country food_category consumption co2_emmission
  28. ## <chr> <chr> <dbl> <dbl>
  29. ## 1 Argentina Pork 10.5 37.2
  30. ## 2 Argentina Poultry 38.7 41.5
  31. ## 3 Argentina Beef 55.5 1712
  32. ## 4 Argentina Lamb & Goat 1.56 54.6
  33. ## 5 Argentina Fish 4.36 6.96
  34. ## 6 Argentina Eggs 11.4 10.5
  35. ## 7 Argentina Milk - inc. cheese 195. 278.
  36. ## 8 Argentina Wheat and Wheat Products 103. 19.7
  37. ## 9 Argentina Rice 8.77 11.2
  38. ## 10 Argentina Soybeans 0 0
  39. ## # ... with 1,420 more rows
  1. library(countrycode)
  2. ## Warning: package 'countrycode' was built under R version 4.0.5
  3. library(janitor)
  4. ## Warning: package 'janitor' was built under R version 4.0.5
  5. ##
  6. ## Attaching package: 'janitor'
  7. ## The following objects are masked from 'package:stats':
  8. ##
  9. ## chisq.test, fisher.test
  10. food <- food_consumption %>% # reshape to one row per country
  11. select(-co2_emmission) %>% # "co2_emmission" (sic) is the CSV's own column name
  12. pivot_wider(
  13. names_from = food_category,
  14. values_from = consumption
  15. ) %>%
  16. clean_names() %>% # snake_case names, e.g. lamb_goat, milk_inc_cheese
  17. mutate(continent = countrycode(
  18. country,
  19. origin = "country.name",
  20. destination = "continent"
  21. )) %>%
  22. mutate(asia = case_when(
  23. continent == "Asia" ~ "Asia",
  24. TRUE ~ "Other" # an NA continent would also land here
  25. )) %>%
  26. select(-country, -continent) %>%
  27. mutate(across(where(is.character), factor)) # replaces superseded mutate_if()
  28. ## Warning in FUN(X[[i]], ...): strings not representable in native encoding will
  29. ## be translated to UTF-8
  30. ## Warning in FUN(X[[i]], ...): unable to translate '<U+00C4>' to native encoding
  31. ## Warning in FUN(X[[i]], ...): unable to translate '<U+00D6>' to native encoding
  32. ## Warning in FUN(X[[i]], ...): unable to translate '<U+00E4>' to native encoding
  33. ## Warning in FUN(X[[i]], ...): unable to translate '<U+00F6>' to native encoding
  34. ## Warning in FUN(X[[i]], ...): unable to translate '<U+00DF>' to native encoding
  35. ## Warning in FUN(X[[i]], ...): unable to translate '<U+00C6>' to native encoding
  36. ## Warning in FUN(X[[i]], ...): unable to translate '<U+00E6>' to native encoding
  37. ## Warning in FUN(X[[i]], ...): unable to translate '<U+00D8>' to native encoding
  38. ## Warning in FUN(X[[i]], ...): unable to translate '<U+00F8>' to native encoding
  39. ## Warning in FUN(X[[i]], ...): unable to translate '<U+00C5>' to native encoding
  40. ## Warning in FUN(X[[i]], ...): unable to translate '<U+00E5>' to native encoding
  41. food
  42. ## # A tibble: 130 x 12
  43. ## pork poultry beef lamb_goat fish eggs milk_inc_cheese wheat_and_wheat_pr~
  44. ## <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
  45. ## 1 10.5 38.7 55.5 1.56 4.36 11.4 195. 103.
  46. ## 2 24.1 46.1 33.9 9.87 17.7 8.51 234. 70.5
  47. ## 3 10.9 13.2 22.5 15.3 3.85 12.5 304. 139.
  48. ## 4 21.7 26.9 13.4 21.1 74.4 8.24 226. 72.9
  49. ## 5 22.3 35.0 22.5 18.9 20.4 9.91 137. 76.9
  50. ## 6 27.6 50.0 36.2 0.43 12.4 14.6 255. 80.4
  51. ## 7 16.8 27.4 29.1 8.23 6.53 13.1 211. 109.
  52. ## 8 43.6 21.4 29.9 1.67 23.1 14.6 255. 103.
  53. ## 9 12.6 45 39.2 0.62 10.0 8.98 149. 53
  54. ## 10 10.4 18.4 23.4 9.56 5.21 8.29 288. 92.3
  55. ## # ... with 120 more rows, and 4 more variables: rice <dbl>, soybeans <dbl>,
  56. ## # nuts_inc_peanut_butter <dbl>, asia <fct>
  1. library(GGally)
  2. ## Warning: package 'GGally' was built under R version 4.0.5
  3. ## Registered S3 method overwritten by 'GGally':
  4. ## method from
  5. ## +.gg ggplot2
  6. ggscatmat(data = food, columns = seq_len(11), color = "asia", alpha = 0.7)

tidymodels-exercise-03 - 图1:11 类食物消费量的散点图矩阵,按是否为亚洲国家(asia)着色

建立模型

  1. library(tidymodels)
  2. ## Warning: package 'tidymodels' was built under R version 4.0.5
  3. ## -- Attaching packages ---------------------------------------------------------- tidymodels 0.1.3 --
  4. ## v broom 0.7.6 v rsample 0.0.9
  5. ## v dials 0.0.9 v tune 0.1.5
  6. ## v infer 0.5.4 v workflows 0.2.2
  7. ## v modeldata 0.1.0 v workflowsets 0.0.2
  8. ## v parsnip 0.1.5 v yardstick 0.0.8
  9. ## v recipes 0.1.16
  10. ## Warning: package 'broom' was built under R version 4.0.5
  11. ## Warning: package 'scales' was built under R version 4.0.5
  12. ## Warning: package 'recipes' was built under R version 4.0.5
  13. ## Warning: package 'rsample' was built under R version 4.0.5
  14. ## Warning: package 'tune' was built under R version 4.0.5
  15. ## Warning: package 'workflows' was built under R version 4.0.5
  16. ## Warning: package 'workflowsets' was built under R version 4.0.5
  17. ## Warning: package 'yardstick' was built under R version 4.0.5
  18. ## -- Conflicts ------------------------------------------------------------- tidymodels_conflicts() --
  19. ## x scales::discard() masks purrr::discard()
  20. ## x dplyr::filter() masks stats::filter()
  21. ## x recipes::fixed() masks stringr::fixed()
  22. ## x dplyr::lag() masks stats::lag()
  23. ## x yardstick::spec() masks readr::spec()
  24. ## x recipes::step() masks stats::step()
  25. ## * Use tidymodels_prefer() to resolve common conflicts.
  26. set.seed(1234) # fix the RNG state before drawing the resamples
  27. food_boot <- food %>% bootstraps(times = 30)
  28. food_boot
  29. ## # Bootstrap sampling
  30. ## # A tibble: 30 x 2
  31. ## splits id
  32. ## <list> <chr>
  33. ## 1 <split [130/48]> Bootstrap01
  34. ## 2 <split [130/49]> Bootstrap02
  35. ## 3 <split [130/49]> Bootstrap03
  36. ## 4 <split [130/51]> Bootstrap04
  37. ## 5 <split [130/47]> Bootstrap05
  38. ## 6 <split [130/51]> Bootstrap06
  39. ## 7 <split [130/57]> Bootstrap07
  40. ## 8 <split [130/51]> Bootstrap08
  41. ## 9 <split [130/44]> Bootstrap09
  42. ## 10 <split [130/53]> Bootstrap10
  43. ## # ... with 20 more rows
  1. rf_spec <- rand_forest(
  2. mtry = tune(),
  3. min_n = tune(),
  4. trees = 1000
  5. ) %>%
  6. set_mode("classification") %>%
  7. set_engine("ranger")
  8. rf_spec
  9. ## Random Forest Model Specification (classification)
  10. ##
  11. ## Main Arguments:
  12. ## mtry = tune()
  13. ## trees = 1000
  14. ## min_n = tune()
  15. ##
  16. ## Computational engine: ranger

调参

  1. doParallel::registerDoParallel() # register a parallel backend before tuning
  2. rf_grid <- tune_grid(
  3. object = rf_spec,
  4. preprocessor = asia ~ .,
  5. resamples = food_boot
  6. )
  7. ## i Creating pre-processing data to finalize unknown parameter: mtry
  8. rf_grid
  9. ## # Tuning results
  10. ## # Bootstrap sampling
  11. ## # A tibble: 30 x 4
  12. ## splits id .metrics .notes
  13. ## <list> <chr> <list> <list>
  14. ## 1 <split [130/48]> Bootstrap01 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
  15. ## 2 <split [130/49]> Bootstrap02 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
  16. ## 3 <split [130/49]> Bootstrap03 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
  17. ## 4 <split [130/51]> Bootstrap04 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
  18. ## 5 <split [130/47]> Bootstrap05 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
  19. ## 6 <split [130/51]> Bootstrap06 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
  20. ## 7 <split [130/57]> Bootstrap07 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
  21. ## 8 <split [130/51]> Bootstrap08 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
  22. ## 9 <split [130/44]> Bootstrap09 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
  23. ## 10 <split [130/53]> Bootstrap10 <tibble[,6] [20 x 6]> <tibble[,1] [0 x 1]>
  24. ## # ... with 20 more rows

评价模型

  1. collect_metrics(rf_grid)
  2. ## # A tibble: 20 x 8
  3. ## mtry min_n .metric .estimator mean n std_err .config
  4. ## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
  5. ## 1 11 15 accuracy binary 0.812 30 0.0113 Preprocessor1_Model01
  6. ## 2 11 15 roc_auc binary 0.823 30 0.0106 Preprocessor1_Model01
  7. ## 3 4 33 accuracy binary 0.813 30 0.00910 Preprocessor1_Model02
  8. ## 4 4 33 roc_auc binary 0.821 30 0.00995 Preprocessor1_Model02
  9. ## 5 5 31 accuracy binary 0.816 30 0.00837 Preprocessor1_Model03
  10. ## 6 5 31 roc_auc binary 0.820 30 0.0103 Preprocessor1_Model03
  11. ## 7 4 37 accuracy binary 0.817 30 0.00863 Preprocessor1_Model04
  12. ## 8 4 37 roc_auc binary 0.819 30 0.0105 Preprocessor1_Model04
  13. ## 9 6 9 accuracy binary 0.825 30 0.00908 Preprocessor1_Model05
  14. ## 10 6 9 roc_auc binary 0.833 30 0.00922 Preprocessor1_Model05
  15. ## 11 2 4 accuracy binary 0.830 30 0.00816 Preprocessor1_Model06
  16. ## 12 2 4 roc_auc binary 0.844 30 0.00975 Preprocessor1_Model06
  17. ## 13 2 12 accuracy binary 0.830 30 0.00774 Preprocessor1_Model07
  18. ## 14 2 12 roc_auc binary 0.836 30 0.00916 Preprocessor1_Model07
  19. ## 15 7 21 accuracy binary 0.816 30 0.00911 Preprocessor1_Model08
  20. ## 16 7 21 roc_auc binary 0.824 30 0.0102 Preprocessor1_Model08
  21. ## 17 8 18 accuracy binary 0.815 30 0.0102 Preprocessor1_Model09
  22. ## 18 8 18 roc_auc binary 0.825 30 0.0102 Preprocessor1_Model09
  23. ## 19 9 26 accuracy binary 0.813 30 0.00961 Preprocessor1_Model10
  24. ## 20 9 26 roc_auc binary 0.821 30 0.0108 Preprocessor1_Model10
  25. rf_grid %>% show_best(metric = "roc_auc") # name `metric`; positional use is deprecated in newer tune
  26. ## # A tibble: 5 x 8
  27. ## mtry min_n .metric .estimator mean n std_err .config
  28. ## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
  29. ## 1 2 4 roc_auc binary 0.844 30 0.00975 Preprocessor1_Model06
  30. ## 2 2 12 roc_auc binary 0.836 30 0.00916 Preprocessor1_Model07
  31. ## 3 6 9 roc_auc binary 0.833 30 0.00922 Preprocessor1_Model05
  32. ## 4 8 18 roc_auc binary 0.825 30 0.0102 Preprocessor1_Model09
  33. ## 5 7 21 roc_auc binary 0.824 30 0.0102 Preprocessor1_Model08