数据探索（Data Exploration）

  1. library(tidyverse)
  2. ## -- Attaching packages ----------------------------------------------------------- tidyverse 1.3.1 --
  3. ## v ggplot2 3.3.3 v purrr 0.3.4
  4. ## v tibble 3.1.1 v dplyr 1.0.6
  5. ## v tidyr 1.1.3 v stringr 1.4.0
  6. ## v readr 1.4.0 v forcats 0.5.1
  7. ## -- Conflicts -------------------------------------------------------------- tidyverse_conflicts() --
  8. ## x dplyr::filter() masks stats::filter()
  9. ## x dplyr::lag() masks stats::lag()
  10. tuition_cost <- readr::read_csv("../datasets/tidytuesday/data/2020/2020-03-10/tuition_cost.csv")
  11. ##
  12. ## -- Column specification ----------------------------------------------------------------------------
  13. ## cols(
  14. ## name = col_character(),
  15. ## state = col_character(),
  16. ## state_code = col_character(),
  17. ## type = col_character(),
  18. ## degree_length = col_character(),
  19. ## room_and_board = col_double(),
  20. ## in_state_tuition = col_double(),
  21. ## in_state_total = col_double(),
  22. ## out_of_state_tuition = col_double(),
  23. ## out_of_state_total = col_double()
  24. ## )
  25. diversity_raw <- readr::read_csv("../datasets/tidytuesday/data/2020/2020-03-10/diversity_school.csv") %>%
  26. filter(category == "Total Minority") %>%
  27. mutate(TotalMinority = enrollment / total_enrollment)
  28. ##
  29. ## -- Column specification ----------------------------------------------------------------------------
  30. ## cols(
  31. ## name = col_character(),
  32. ## total_enrollment = col_double(),
  33. ## state = col_character(),
  34. ## category = col_character(),
  35. ## enrollment = col_double()
  36. ## )
  1. diversity_school <- diversity_raw %>%
  2. filter(category == "Total Minority") %>%
  3. mutate(TotalMinority = enrollment / total_enrollment)
  4. diversity_school %>%
  5. ggplot(aes(TotalMinority)) +
  6. geom_histogram(alpha = 0.7, fill = "midnightblue") +
  7. scale_x_continuous(labels = scales::percent_format()) +
  8. labs(x = "% of student population who identifies as any minority")
  9. ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

tidymodels-exercise-04 - Figure 1 (histogram of the share of students identifying as a minority)

  1. university_df <- diversity_school %>%
  2. filter(category == "Total Minority") %>%
  3. mutate(TotalMinority = enrollment / total_enrollment) %>%
  4. transmute(
  5. diversity = case_when(
  6. TotalMinority > 0.3 ~ "high",
  7. TRUE ~ "low"
  8. ),
  9. name, state,
  10. total_enrollment
  11. ) %>%
  12. inner_join(tuition_cost %>%
  13. select(
  14. name, type, degree_length,
  15. in_state_tuition:out_of_state_total
  16. )) %>%
  17. left_join(tibble(state = state.name, region = state.region)) %>%
  18. select(-state, -name) %>%
  19. mutate_if(is.character, factor)
  20. ## Joining, by = "name"
  21. ## Joining, by = "state"
  22. skimr::skim(university_df)

Table: Data summary

Name university_df
Number of rows 2159
Number of columns 9
_
Column type frequency:
factor 4
numeric 5
__
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
diversity 0 1 FALSE 2 low: 1241, hig: 918
type 0 1 FALSE 3 Pub: 1145, Pri: 955, For: 59
degree_length 0 1 FALSE 2 4 Y: 1296, 2 Y: 863
region 0 1 FALSE 4 Sou: 774, Nor: 543, Nor: 443, Wes: 399

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
total_enrollment 0 1 6183.76 8263.64 15 1352 3133 7644.5 81459 ▇▁▁▁▁
in_state_tuition 0 1 17044.02 15460.76 480 4695 10161 28780.0 59985 ▇▂▂▁▁
in_state_total 0 1 23544.64 19782.17 962 5552 17749 38519.0 75003 ▇▅▂▂▁
out_of_state_tuition 0 1 20797.98 13725.29 480 9298 17045 29865.0 59985 ▇▆▅▂▁
out_of_state_total 0 1 27298.60 18220.62 1376 11018 23036 40154.0 75003 ▇▅▅▂▁
  1. university_df %>%
  2. ggplot(aes(type, in_state_tuition, fill = diversity)) +
  3. geom_boxplot(alpha = 0.8) +
  4. scale_y_continuous(labels = scales::dollar_format()) +
  5. labs(x = NULL, y = "In-State Tuition", fill = "Diversity")

tidymodels-exercise-04 - Figure 2 (boxplots of in-state tuition by institution type and diversity group)

建立模型（Build the Models）

  1. library(tidymodels)
  2. ## -- Attaching packages ---------------------------------------------------------- tidymodels 0.1.3 --
  3. ## v broom 0.7.6 v rsample 0.0.9
  4. ## v dials 0.0.9 v tune 0.1.5
  5. ## v infer 0.5.4 v workflows 0.2.2
  6. ## v modeldata 0.1.0 v workflowsets 0.0.2
  7. ## v parsnip 0.1.5 v yardstick 0.0.8
  8. ## v recipes 0.1.16
  9. ## -- Conflicts ------------------------------------------------------------- tidymodels_conflicts() --
  10. ## x scales::discard() masks purrr::discard()
  11. ## x dplyr::filter() masks stats::filter()
  12. ## x recipes::fixed() masks stringr::fixed()
  13. ## x dplyr::lag() masks stats::lag()
  14. ## x yardstick::spec() masks readr::spec()
  15. ## x recipes::step() masks stats::step()
  16. ## * Use tidymodels_prefer() to resolve common conflicts.
  17. set.seed(1234)
  18. uni_split <- initial_split(university_df, strata = diversity)
  19. uni_train <- training(uni_split)
  20. uni_test <- testing(uni_split)
  21. uni_rec <- recipe(diversity ~ ., data = uni_train) %>%
  22. step_corr(all_numeric()) %>%
  23. step_dummy(all_nominal(), -all_outcomes()) %>%
  24. step_zv(all_numeric()) %>%
  25. step_normalize(all_numeric())
  26. uni_prep <- uni_rec %>%
  27. prep()
  28. uni_prep
  29. ## Data Recipe
  30. ##
  31. ## Inputs:
  32. ##
  33. ## role #variables
  34. ## outcome 1
  35. ## predictor 8
  36. ##
  37. ## Training data contained 1620 data points and no missing data.
  38. ##
  39. ## Operations:
  40. ##
  41. ## Correlation filter removed in_state_tuition, ... [trained]
  42. ## Dummy variables from type, degree_length, region [trained]
  43. ## Zero variance filter removed no terms [trained]
  44. ## Centering and scaling for total_enrollment, ... [trained]
  1. uni_juiced <- juice(uni_prep)
  2. glm_spec <- logistic_reg() %>%
  3. set_engine("glm")
  4. glm_fit <- glm_spec %>%
  5. fit(diversity ~ ., data = uni_juiced)
  6. glm_fit
  7. ## parsnip model object
  8. ##
  9. ## Fit time: 20ms
  10. ##
  11. ## Call: stats::glm(formula = diversity ~ ., family = stats::binomial,
  12. ## data = data)
  13. ##
  14. ## Coefficients:
  15. ## (Intercept) total_enrollment out_of_state_total
  16. ## 0.3704 -0.4581 0.5074
  17. ## type_Private type_Public degree_length_X4.Year
  18. ## -0.1656 0.2058 0.2082
  19. ## region_South region_North.Central region_West
  20. ## -0.5175 0.3004 -0.5363
  21. ##
  22. ## Degrees of Freedom: 1619 Total (i.e. Null); 1611 Residual
  23. ## Null Deviance: 2210
  24. ## Residual Deviance: 1859 AIC: 1877
  1. knn_spec <- nearest_neighbor() %>%
  2. set_engine("kknn") %>%
  3. set_mode("classification")
  4. knn_fit <- knn_spec %>%
  5. fit(diversity ~ ., data = uni_juiced)
  6. knn_fit
  7. ## parsnip model object
  8. ##
  9. ## Fit time: 120ms
  10. ##
  11. ## Call:
  12. ## kknn::train.kknn(formula = diversity ~ ., data = data, ks = min_rows(5, data, 5))
  13. ##
  14. ## Type of response variable: nominal
  15. ## Minimal misclassification: 0.3277778
  16. ## Best kernel: optimal
  17. ## Best k: 5
  1. tree_spec <- decision_tree() %>%
  2. set_engine("rpart") %>%
  3. set_mode("classification")
  4. tree_fit <- tree_spec %>%
  5. fit(diversity ~ ., data = uni_juiced)
  6. tree_fit
  7. ## parsnip model object
  8. ##
  9. ## Fit time: 31ms
  10. ## n= 1620
  11. ##
  12. ## node), split, n, loss, yval, (yprob)
  13. ## * denotes terminal node
  14. ##
  15. ## 1) root 1620 689 low (0.4253086 0.5746914)
  16. ## 2) region_North.Central< 0.5346496 1192 586 high (0.5083893 0.4916107)
  17. ## 4) out_of_state_total< -0.7087237 418 130 high (0.6889952 0.3110048) *
  18. ## 5) out_of_state_total>=-0.7087237 774 318 low (0.4108527 0.5891473)
  19. ## 10) out_of_state_total< 0.35164 362 180 low (0.4972376 0.5027624)
  20. ## 20) region_South>=0.3002561 212 86 high (0.5943396 0.4056604)
  21. ## 40) degree_length_X4.Year>=-0.2001293 172 62 high (0.6395349 0.3604651) *
  22. ## 41) degree_length_X4.Year< -0.2001293 40 16 low (0.4000000 0.6000000) *
  23. ## 21) region_South< 0.3002561 150 54 low (0.3600000 0.6400000)
  24. ## 42) region_West>=0.8128302 64 28 high (0.5625000 0.4375000) *
  25. ## 43) region_West< 0.8128302 86 18 low (0.2093023 0.7906977) *
  26. ## 11) out_of_state_total>=0.35164 412 138 low (0.3349515 0.6650485)
  27. ## 22) region_West>=0.8128302 88 38 high (0.5681818 0.4318182)
  28. ## 44) out_of_state_total>=1.547681 30 5 high (0.8333333 0.1666667) *
  29. ## 45) out_of_state_total< 1.547681 58 25 low (0.4310345 0.5689655) *
  30. ## 23) region_West< 0.8128302 324 88 low (0.2716049 0.7283951) *
  31. ## 3) region_North.Central>=0.5346496 428 83 low (0.1939252 0.8060748)
  32. ## 6) out_of_state_total< -1.19287 17 5 high (0.7058824 0.2941176) *
  33. ## 7) out_of_state_total>=-1.19287 411 71 low (0.1727494 0.8272506) *

评价模型（Evaluate the Models）

  1. set.seed(123)
  2. folds <- vfold_cv(uni_train, strata = diversity)
  1. set.seed(234)
  2. glm_rs <- glm_spec %>%
  3. fit_resamples(
  4. uni_rec,
  5. folds,
  6. metrics = metric_set(roc_auc, sens, spec),
  7. control = control_resamples(save_pred = TRUE)
  8. )
  9. set.seed(234)
  10. knn_rs <- knn_spec %>%
  11. fit_resamples(
  12. uni_rec,
  13. folds,
  14. metrics = metric_set(roc_auc, sens, spec),
  15. control = control_resamples(save_pred = TRUE)
  16. )
  17. set.seed(234)
  18. tree_rs <- tree_spec %>%
  19. fit_resamples(
  20. uni_rec,
  21. folds,
  22. metrics = metric_set(roc_auc, sens, spec),
  23. control = control_resamples(save_pred = TRUE)
  24. )
  1. tree_rs
  2. ## # Resampling results
  3. ## # 10-fold cross-validation using stratification
  4. ## # A tibble: 10 x 5
  5. ## splits id .metrics .notes .predictions
  6. ## <list> <chr> <list> <list> <list>
  7. ## 1 <split [1457/1~ Fold01 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [163 x~
  8. ## 2 <split [1458/1~ Fold02 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
  9. ## 3 <split [1458/1~ Fold03 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
  10. ## 4 <split [1458/1~ Fold04 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
  11. ## 5 <split [1458/1~ Fold05 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
  12. ## 6 <split [1458/1~ Fold06 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
  13. ## 7 <split [1458/1~ Fold07 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
  14. ## 8 <split [1458/1~ Fold08 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
  15. ## 9 <split [1458/1~ Fold09 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
  16. ## 10 <split [1459/1~ Fold10 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [161 x~
  1. glm_rs %>%
  2. collect_metrics()
  3. ## # A tibble: 3 x 6
  4. ## .metric .estimator mean n std_err .config
  5. ## <chr> <chr> <dbl> <int> <dbl> <chr>
  6. ## 1 roc_auc binary 0.758 10 0.00891 Preprocessor1_Model1
  7. ## 2 sens binary 0.617 10 0.0179 Preprocessor1_Model1
  8. ## 3 spec binary 0.737 10 0.00677 Preprocessor1_Model1
  1. knn_rs %>%
  2. collect_metrics()
  3. ## # A tibble: 3 x 6
  4. ## .metric .estimator mean n std_err .config
  5. ## <chr> <chr> <dbl> <int> <dbl> <chr>
  6. ## 1 roc_auc binary 0.728 10 0.00652 Preprocessor1_Model1
  7. ## 2 sens binary 0.595 10 0.00978 Preprocessor1_Model1
  8. ## 3 spec binary 0.733 10 0.0121 Preprocessor1_Model1
  1. tree_rs %>%
  2. collect_metrics()
  3. ## # A tibble: 3 x 6
  4. ## .metric .estimator mean n std_err .config
  5. ## <chr> <chr> <dbl> <int> <dbl> <chr>
  6. ## 1 roc_auc binary 0.723 10 0.00578 Preprocessor1_Model1
  7. ## 2 sens binary 0.642 10 0.0182 Preprocessor1_Model1
  8. ## 3 spec binary 0.745 10 0.00941 Preprocessor1_Model1
  1. glm_rs %>%
  2. unnest(.predictions) %>%
  3. mutate(model = "glm") %>%
  4. bind_rows(knn_rs %>%
  5. unnest(.predictions) %>%
  6. mutate(model = "knn")) %>%
  7. bind_rows(tree_rs %>%
  8. unnest(.predictions) %>%
  9. mutate(model = "rpart")) %>%
  10. group_by(model) %>%
  11. roc_curve(diversity, .pred_high) %>%
  12. ggplot(aes(x = 1 - specificity, y = sensitivity, color = model)) +
  13. geom_line(size = 1.5) +
  14. geom_abline(
  15. lty = 2, alpha = 0.5,
  16. color = "gray50",
  17. size = 1.2
  18. )

tidymodels-exercise-04 - Figure 3 (ROC curves comparing the glm, knn, and rpart resampled models)

  1. glm_fit %>%
  2. tidy() %>%
  3. arrange(-estimate)
  4. ## # A tibble: 9 x 5
  5. ## term estimate std.error statistic p.value
  6. ## <chr> <dbl> <dbl> <dbl> <dbl>
  7. ## 1 out_of_state_total 0.507 0.0934 5.43 5.58e- 8
  8. ## 2 (Intercept) 0.370 0.0572 6.47 9.72e-11
  9. ## 3 region_North.Central 0.300 0.0825 3.64 2.72e- 4
  10. ## 4 degree_length_X4.Year 0.208 0.0863 2.41 1.58e- 2
  11. ## 5 type_Public 0.206 0.193 1.07 2.86e- 1
  12. ## 6 type_Private -0.166 0.197 -0.838 4.02e- 1
  13. ## 7 total_enrollment -0.458 0.0737 -6.22 5.07e-10
  14. ## 8 region_South -0.517 0.0782 -6.62 3.64e-11
  15. ## 9 region_West -0.536 0.0724 -7.41 1.27e-13
  1. glm_fit %>%
  2. predict(
  3. new_data = bake(uni_prep, uni_test),
  4. type = "prob"
  5. ) %>%
  6. mutate(truth = uni_test$diversity) %>%
  7. roc_auc(truth, .pred_high)
  8. ## # A tibble: 1 x 3
  9. ## .metric .estimator .estimate
  10. ## <chr> <chr> <dbl>
  11. ## 1 roc_auc binary 0.756
  1. glm_fit %>%
  2. predict(
  3. new_data = bake(uni_prep, new_data = uni_test),
  4. type = "class"
  5. ) %>%
  6. mutate(truth = uni_test$diversity) %>%
  7. spec(truth, .pred_class)
  8. ## # A tibble: 1 x 3
  9. ## .metric .estimator .estimate
  10. ## <chr> <chr> <dbl>
  11. ## 1 spec binary 0.719