数据探索
# Attach the tidyverse and read the two TidyTuesday (2020-03-10) CSV files.
# Lines starting with `##` are console output recorded when the chunk was run;
# they are kept on their own lines so they cannot swallow the code after them.
library(tidyverse)
## -- Attaching packages ----------------------------------------------------------- tidyverse 1.3.1 --
## v ggplot2 3.3.3 v purrr 0.3.4
## v tibble 3.1.1 v dplyr 1.0.6
## v tidyr 1.1.3 v stringr 1.4.0
## v readr 1.4.0 v forcats 0.5.1
## -- Conflicts -------------------------------------------------------------- tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()

# Tuition table: one row per school with the cost columns listed below.
tuition_cost <- readr::read_csv(
  "../datasets/tidytuesday/data/2020/2020-03-10/tuition_cost.csv"
)
##
## -- Column specification ----------------------------------------------------------------------------
## cols(
## name = col_character(),
## state = col_character(),
## state_code = col_character(),
## type = col_character(),
## degree_length = col_character(),
## room_and_board = col_double(),
## in_state_tuition = col_double(),
## in_state_total = col_double(),
## out_of_state_tuition = col_double(),
## out_of_state_total = col_double()
## )

# Enrollment table: keep only the "Total Minority" rows and express that
# enrollment as a share of each school's total enrollment.
diversity_raw <- readr::read_csv(
  "../datasets/tidytuesday/data/2020/2020-03-10/diversity_school.csv"
) %>%
  filter(category == "Total Minority") %>%
  mutate(TotalMinority = enrollment / total_enrollment)
##
## -- Column specification ----------------------------------------------------------------------------
## cols(
## name = col_character(),
## total_enrollment = col_double(),
## state = col_character(),
## category = col_character(),
## enrollment = col_double()
## )
# Per-school minority share. The filter/mutate are repeated here so that
# `diversity_school` has `TotalMinority` no matter how `diversity_raw` was
# prepared; re-applying them to already-filtered data changes nothing.
diversity_school <- diversity_raw %>%
  filter(category == "Total Minority") %>%
  mutate(TotalMinority = enrollment / total_enrollment)

# Histogram of the minority share across schools, x axis shown as percent.
diversity_school %>%
  ggplot(aes(TotalMinority)) +
  geom_histogram(alpha = 0.7, fill = "midnightblue") +
  scale_x_continuous(labels = scales::percent_format()) +
  labs(x = "% of student population who identifies as any minority")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

# Build the modelling table: one row per school with a binary `diversity`
# label, enrollment, tuition columns and the census region of its state.
# `diversity_school` already holds only "Total Minority" rows and the
# `TotalMinority` share, so that filter/mutate is not repeated here.
university_df <- diversity_school %>%
  transmute(
    # Schools with more than 30% minority enrollment are labelled "high".
    diversity = case_when(
      TotalMinority > 0.3 ~ "high",
      TRUE ~ "low"
    ),
    name,
    state,
    total_enrollment
  ) %>%
  # Natural join on the shared column (`name`, see the recorded message).
  inner_join(
    tuition_cost %>%
      select(
        name, type, degree_length,
        in_state_tuition:out_of_state_total
      )
  ) %>%
  # Map each state name to its region using the built-in `state.*` vectors.
  left_join(tibble(state = state.name, region = state.region)) %>%
  # The identifiers are only needed for the joins, not as predictors.
  select(-state, -name) %>%
  # `across(where())` is the current replacement for superseded `mutate_if()`.
  mutate(across(where(is.character), factor))
## Joining, by = "name"
## Joining, by = "state"

skimr::skim(university_df)
Table: Data summary
| Name | university_df |
| Number of rows | 2159 |
| Number of columns | 9 |
| _ | |
| Column type frequency: | |
| factor | 4 |
| numeric | 5 |
| __ | |
| Group variables | None |
Variable type: factor
| skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
|---|---|---|---|---|---|
| diversity | 0 | 1 | FALSE | 2 | low: 1241, hig: 918 |
| type | 0 | 1 | FALSE | 3 | Pub: 1145, Pri: 955, For: 59 |
| degree_length | 0 | 1 | FALSE | 2 | 4 Y: 1296, 2 Y: 863 |
| region | 0 | 1 | FALSE | 4 | Sou: 774, Nor: 543, Nor: 443, Wes: 399 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| total_enrollment | 0 | 1 | 6183.76 | 8263.64 | 15 | 1352 | 3133 | 7644.5 | 81459 | ▇▁▁▁▁ |
| in_state_tuition | 0 | 1 | 17044.02 | 15460.76 | 480 | 4695 | 10161 | 28780.0 | 59985 | ▇▂▂▁▁ |
| in_state_total | 0 | 1 | 23544.64 | 19782.17 | 962 | 5552 | 17749 | 38519.0 | 75003 | ▇▅▂▂▁ |
| out_of_state_tuition | 0 | 1 | 20797.98 | 13725.29 | 480 | 9298 | 17045 | 29865.0 | 59985 | ▇▆▅▂▁ |
| out_of_state_total | 0 | 1 | 27298.60 | 18220.62 | 1376 | 11018 | 23036 | 40154.0 | 75003 | ▇▅▅▂▁ |
# Box plots of in-state tuition for each school type, split by diversity
# label; the y axis is formatted as dollars.
ggplot(
  university_df,
  aes(x = type, y = in_state_tuition, fill = diversity)
) +
  geom_boxplot(alpha = 0.8) +
  scale_y_continuous(labels = scales::dollar_format()) +
  labs(x = NULL, y = "In-State Tuition", fill = "Diversity")

建立模型
# Attach tidymodels, split the data and define the preprocessing recipe.
# Lines starting with `##` are console output recorded when the chunk was run;
# they are kept on their own lines so they cannot swallow the code after them.
library(tidymodels)
## -- Attaching packages ---------------------------------------------------------- tidymodels 0.1.3 --
## v broom 0.7.6 v rsample 0.0.9
## v dials 0.0.9 v tune 0.1.5
## v infer 0.5.4 v workflows 0.2.2
## v modeldata 0.1.0 v workflowsets 0.0.2
## v parsnip 0.1.5 v yardstick 0.0.8
## v recipes 0.1.16
## -- Conflicts ------------------------------------------------------------- tidymodels_conflicts() --
## x scales::discard() masks purrr::discard()
## x dplyr::filter() masks stats::filter()
## x recipes::fixed() masks stringr::fixed()
## x dplyr::lag() masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step() masks stats::step()
## * Use tidymodels_prefer() to resolve common conflicts.

# Train/test split, stratified on the outcome; seeded for reproducibility.
set.seed(1234)
uni_split <- initial_split(university_df, strata = diversity)
uni_train <- training(uni_split)
uni_test <- testing(uni_split)

# Preprocessing: correlation filter on numeric columns, dummy-encode the
# nominal predictors, drop zero-variance columns, then centre and scale.
uni_rec <- recipe(diversity ~ ., data = uni_train) %>%
  step_corr(all_numeric()) %>%
  step_dummy(all_nominal(), -all_outcomes()) %>%
  step_zv(all_numeric()) %>%
  step_normalize(all_numeric())

# Estimate the recipe's parameters from the training data.
uni_prep <- uni_rec %>%
  prep()

uni_prep
## Data Recipe
##
## Inputs:
##
## role #variables
## outcome 1
## predictor 8
##
## Training data contained 1620 data points and no missing data.
##
## Operations:
##
## Correlation filter removed in_state_tuition, ... [trained]
## Dummy variables from type, degree_length, region [trained]
## Zero variance filter removed no terms [trained]
## Centering and scaling for total_enrollment, ... [trained]
# Preprocessed training data. `bake(new_data = NULL)` returns the training
# set retained by `prep()` and is the current replacement for `juice()`.
uni_juiced <- bake(uni_prep, new_data = NULL)

# Logistic regression fitted with the "glm" engine.
glm_spec <- logistic_reg() %>%
  set_engine("glm")

glm_fit <- glm_spec %>%
  fit(diversity ~ ., data = uni_juiced)

glm_fit
## parsnip model object
##
## Fit time: 20ms
##
## Call: stats::glm(formula = diversity ~ ., family = stats::binomial,
## data = data)
##
## Coefficients:
## (Intercept) total_enrollment out_of_state_total
## 0.3704 -0.4581 0.5074
## type_Private type_Public degree_length_X4.Year
## -0.1656 0.2058 0.2082
## region_South region_North.Central region_West
## -0.5175 0.3004 -0.5363
##
## Degrees of Freedom: 1619 Total (i.e. Null); 1611 Residual
## Null Deviance: 2210
## Residual Deviance: 1859 AIC: 1877
# k-nearest-neighbours classifier fitted with the "kknn" engine on the
# preprocessed training data.
knn_spec <- nearest_neighbor() %>%
  set_engine("kknn") %>%
  set_mode("classification")

knn_fit <- knn_spec %>%
  fit(diversity ~ ., data = uni_juiced)

knn_fit
## parsnip model object
##
## Fit time: 120ms
##
## Call:
## kknn::train.kknn(formula = diversity ~ ., data = data, ks = min_rows(5, data, 5))
##
## Type of response variable: nominal
## Minimal misclassification: 0.3277778
## Best kernel: optimal
## Best k: 5
# Classification tree fitted with the "rpart" engine on the preprocessed
# training data. Split thresholds in the output are on the normalised scale
# produced by the recipe.
tree_spec <- decision_tree() %>%
  set_engine("rpart") %>%
  set_mode("classification")

tree_fit <- tree_spec %>%
  fit(diversity ~ ., data = uni_juiced)

tree_fit
## parsnip model object
##
## Fit time: 31ms
## n= 1620
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 1620 689 low (0.4253086 0.5746914)
## 2) region_North.Central< 0.5346496 1192 586 high (0.5083893 0.4916107)
## 4) out_of_state_total< -0.7087237 418 130 high (0.6889952 0.3110048) *
## 5) out_of_state_total>=-0.7087237 774 318 low (0.4108527 0.5891473)
## 10) out_of_state_total< 0.35164 362 180 low (0.4972376 0.5027624)
## 20) region_South>=0.3002561 212 86 high (0.5943396 0.4056604)
## 40) degree_length_X4.Year>=-0.2001293 172 62 high (0.6395349 0.3604651) *
## 41) degree_length_X4.Year< -0.2001293 40 16 low (0.4000000 0.6000000) *
## 21) region_South< 0.3002561 150 54 low (0.3600000 0.6400000)
## 42) region_West>=0.8128302 64 28 high (0.5625000 0.4375000) *
## 43) region_West< 0.8128302 86 18 low (0.2093023 0.7906977) *
## 11) out_of_state_total>=0.35164 412 138 low (0.3349515 0.6650485)
## 22) region_West>=0.8128302 88 38 high (0.5681818 0.4318182)
## 44) out_of_state_total>=1.547681 30 5 high (0.8333333 0.1666667) *
## 45) out_of_state_total< 1.547681 58 25 low (0.4310345 0.5689655) *
## 23) region_West< 0.8128302 324 88 low (0.2716049 0.7283951) *
## 3) region_North.Central>=0.5346496 428 83 low (0.1939252 0.8060748)
## 6) out_of_state_total< -1.19287 17 5 high (0.7058824 0.2941176) *
## 7) out_of_state_total>=-1.19287 411 71 low (0.1727494 0.8272506) *
评价模型
# Cross-validation folds of the training set, stratified on the outcome;
# seeded so the fold assignment is reproducible.
set.seed(123)
folds <- vfold_cv(uni_train, strata = diversity)
# Evaluate each model specification on the same folds with the same
# (unprepped) recipe, so preprocessing is re-estimated inside every fold.
# The seed is reset before each call so the three runs are comparable, and
# predictions are saved for the ROC curves drawn later.
set.seed(234)
glm_rs <- glm_spec %>%
  fit_resamples(
    uni_rec,
    folds,
    metrics = metric_set(roc_auc, sens, spec),
    control = control_resamples(save_pred = TRUE)
  )

set.seed(234)
knn_rs <- knn_spec %>%
  fit_resamples(
    uni_rec,
    folds,
    metrics = metric_set(roc_auc, sens, spec),
    control = control_resamples(save_pred = TRUE)
  )

set.seed(234)
tree_rs <- tree_spec %>%
  fit_resamples(
    uni_rec,
    folds,
    metrics = metric_set(roc_auc, sens, spec),
    control = control_resamples(save_pred = TRUE)
  )
# Print the resampling results for the decision tree: one row per fold with
# list-columns for the split, metrics, notes and saved predictions.
# The `##` lines are the console output recorded when the chunk was run.
tree_rs
## # Resampling results
## # 10-fold cross-validation using stratification
## # A tibble: 10 x 5
## splits id .metrics .notes .predictions
## <list> <chr> <list> <list> <list>
## 1 <split [1457/1~ Fold01 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [163 x~
## 2 <split [1458/1~ Fold02 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 3 <split [1458/1~ Fold03 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 4 <split [1458/1~ Fold04 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 5 <split [1458/1~ Fold05 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 6 <split [1458/1~ Fold06 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 7 <split [1458/1~ Fold07 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 8 <split [1458/1~ Fold08 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 9 <split [1458/1~ Fold09 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 10 <split [1459/1~ Fold10 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [161 x~
# Fold-averaged ROC AUC, sensitivity and specificity for the logistic model.
collect_metrics(glm_rs)
## # A tibble: 3 x 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 roc_auc binary 0.758 10 0.00891 Preprocessor1_Model1
## 2 sens binary 0.617 10 0.0179 Preprocessor1_Model1
## 3 spec binary 0.737 10 0.00677 Preprocessor1_Model1
# Fold-averaged ROC AUC, sensitivity and specificity for the k-NN model.
collect_metrics(knn_rs)
## # A tibble: 3 x 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 roc_auc binary 0.728 10 0.00652 Preprocessor1_Model1
## 2 sens binary 0.595 10 0.00978 Preprocessor1_Model1
## 3 spec binary 0.733 10 0.0121 Preprocessor1_Model1
# Fold-averaged ROC AUC, sensitivity and specificity for the decision tree.
collect_metrics(tree_rs)
## # A tibble: 3 x 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 roc_auc binary 0.723 10 0.00578 Preprocessor1_Model1
## 2 sens binary 0.642 10 0.0182 Preprocessor1_Model1
## 3 spec binary 0.745 10 0.00941 Preprocessor1_Model1
# Stack the saved out-of-fold predictions of the three models, tag each row
# with its model name, and draw one ROC curve per model. The dashed diagonal
# is the reference line where sensitivity equals 1 - specificity.
bind_rows(
  glm_rs %>%
    unnest(.predictions) %>%
    mutate(model = "glm"),
  knn_rs %>%
    unnest(.predictions) %>%
    mutate(model = "knn"),
  tree_rs %>%
    unnest(.predictions) %>%
    mutate(model = "rpart")
) %>%
  group_by(model) %>%
  roc_curve(diversity, .pred_high) %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity, color = model)) +
  geom_line(size = 1.5) +
  geom_abline(
    lty = 2,
    alpha = 0.5,
    color = "gray50",
    size = 1.2
  )

# Coefficient table of the logistic model, largest estimate first.
arrange(tidy(glm_fit), desc(estimate))
## # A tibble: 9 x 5
## term estimate std.error statistic p.value
## <chr> <dbl> <dbl> <dbl> <dbl>
## 1 out_of_state_total 0.507 0.0934 5.43 5.58e- 8
## 2 (Intercept) 0.370 0.0572 6.47 9.72e-11
## 3 region_North.Central 0.300 0.0825 3.64 2.72e- 4
## 4 degree_length_X4.Year 0.208 0.0863 2.41 1.58e- 2
## 5 type_Public 0.206 0.193 1.07 2.86e- 1
## 6 type_Private -0.166 0.197 -0.838 4.02e- 1
## 7 total_enrollment -0.458 0.0737 -6.22 5.07e-10
## 8 region_South -0.517 0.0782 -6.62 3.64e-11
## 9 region_West -0.536 0.0724 -7.41 1.27e-13
# ROC AUC of the logistic model on the held-out test set. The test rows are
# passed through the trained recipe before prediction.
predict(
  glm_fit,
  new_data = bake(uni_prep, new_data = uni_test),
  type = "prob"
) %>%
  mutate(truth = uni_test$diversity) %>%
  roc_auc(truth, .pred_high)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 roc_auc binary 0.756
# Specificity of the logistic model's class predictions on the held-out test
# set. The test rows are passed through the trained recipe before prediction.
predict(
  glm_fit,
  new_data = bake(uni_prep, new_data = uni_test),
  type = "class"
) %>%
  mutate(truth = uni_test$diversity) %>%
  spec(truth, .pred_class)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 spec binary 0.719
