数据探索 (Data Exploration)
library(tidyverse)
## -- Attaching packages ----------------------------------------------------------- tidyverse 1.3.1 --
## v ggplot2 3.3.3 v purrr 0.3.4
## v tibble 3.1.1 v dplyr 1.0.6
## v tidyr 1.1.3 v stringr 1.4.0
## v readr 1.4.0 v forcats 0.5.1
## -- Conflicts -------------------------------------------------------------- tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
# Tuition data (TidyTuesday 2020-03-10): one row per school, with school
# name/state/type, degree length, and in-state / out-of-state tuition and
# total cost columns.
tuition_cost <- readr::read_csv("../datasets/tidytuesday/data/2020/2020-03-10/tuition_cost.csv")
##
## -- Column specification ----------------------------------------------------------------------------
## cols(
## name = col_character(),
## state = col_character(),
## state_code = col_character(),
## type = col_character(),
## degree_length = col_character(),
## room_and_board = col_double(),
## in_state_tuition = col_double(),
## in_state_total = col_double(),
## out_of_state_tuition = col_double(),
## out_of_state_total = col_double()
## )
# Diversity data: keep only the "Total Minority" category rows and compute
# TotalMinority = the share of total enrollment that identifies as any
# minority (a proportion in [0, 1]).
diversity_raw <- readr::read_csv("../datasets/tidytuesday/data/2020/2020-03-10/diversity_school.csv") %>%
filter(category == "Total Minority") %>%
mutate(TotalMinority = enrollment / total_enrollment)
##
## -- Column specification ----------------------------------------------------------------------------
## cols(
## name = col_character(),
## total_enrollment = col_double(),
## state = col_character(),
## category = col_character(),
## enrollment = col_double()
## )
# `diversity_raw` was already filtered to the "Total Minority" rows and given
# the `TotalMinority` column above, so repeating the identical filter()/
# mutate() pipeline here is redundant (it is a no-op). Assign directly.
diversity_school <- diversity_raw
# Distribution of the minority-share variable across schools, with the
# x axis formatted as percentages.
diversity_school %>%
ggplot(aes(TotalMinority)) +
geom_histogram(alpha = 0.7, fill = "midnightblue") +
scale_x_continuous(labels = scales::percent_format()) +
labs(x = "% of student population who identifies as any minority")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# Build the modeling data set: label each school as "high"/"low" diversity,
# attach tuition information, and map states onto census regions.
university_df <- diversity_school %>%
  # `diversity_school` already holds only "Total Minority" rows with the
  # `TotalMinority` proportion computed, so no re-filter/re-mutate is needed.
  transmute(
    # binary outcome: more than 30% minority enrollment counts as "high"
    diversity = case_when(
      TotalMinority > 0.3 ~ "high",
      TRUE ~ "low"
    ),
    name, state,
    total_enrollment
  ) %>%
  # Explicit `by` instead of letting dplyr guess the join columns (the
  # original emitted 'Joining, by = "name"' / 'Joining, by = "state"').
  inner_join(
    tuition_cost %>%
      select(
        name, type, degree_length,
        in_state_tuition:out_of_state_total
      ),
    by = "name"
  ) %>%
  # state.name / state.region are base R constants; adds a `region` factor.
  left_join(tibble(state = state.name, region = state.region), by = "state") %>%
  select(-state, -name) %>%
  # mutate_if() is superseded; across() + where() is the current idiom.
  mutate(across(where(is.character), factor))
# Quick overview of the modeling data: column types, missingness, and
# per-column summaries (rendered as the table below).
skimr::skim(university_df)
Table: Data summary
Name | university_df |
Number of rows | 2159 |
Number of columns | 9 |
_ | |
Column type frequency: | |
factor | 4 |
numeric | 5 |
__ | |
Group variables | None |
Variable type: factor
skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
---|---|---|---|---|---|
diversity | 0 | 1 | FALSE | 2 | low: 1241, hig: 918 |
type | 0 | 1 | FALSE | 3 | Pub: 1145, Pri: 955, For: 59 |
degree_length | 0 | 1 | FALSE | 2 | 4 Y: 1296, 2 Y: 863 |
region | 0 | 1 | FALSE | 4 | Sou: 774, Nor: 543, Nor: 443, Wes: 399 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
total_enrollment | 0 | 1 | 6183.76 | 8263.64 | 15 | 1352 | 3133 | 7644.5 | 81459 | ▇▁▁▁▁ |
in_state_tuition | 0 | 1 | 17044.02 | 15460.76 | 480 | 4695 | 10161 | 28780.0 | 59985 | ▇▂▂▁▁ |
in_state_total | 0 | 1 | 23544.64 | 19782.17 | 962 | 5552 | 17749 | 38519.0 | 75003 | ▇▅▂▂▁ |
out_of_state_tuition | 0 | 1 | 20797.98 | 13725.29 | 480 | 9298 | 17045 | 29865.0 | 59985 | ▇▆▅▂▁ |
out_of_state_total | 0 | 1 | 27298.60 | 18220.62 | 1376 | 11018 | 23036 | 40154.0 | 75003 | ▇▅▅▂▁ |
# In-state tuition by school type, split by the diversity label, with the
# y axis formatted as dollars.
university_df %>%
ggplot(aes(type, in_state_tuition, fill = diversity)) +
geom_boxplot(alpha = 0.8) +
scale_y_continuous(labels = scales::dollar_format()) +
labs(x = NULL, y = "In-State Tuition", fill = "Diversity")
建立模型 (Building Models)
library(tidymodels)
## -- Attaching packages ---------------------------------------------------------- tidymodels 0.1.3 --
## v broom 0.7.6 v rsample 0.0.9
## v dials 0.0.9 v tune 0.1.5
## v infer 0.5.4 v workflows 0.2.2
## v modeldata 0.1.0 v workflowsets 0.0.2
## v parsnip 0.1.5 v yardstick 0.0.8
## v recipes 0.1.16
## -- Conflicts ------------------------------------------------------------- tidymodels_conflicts() --
## x scales::discard() masks purrr::discard()
## x dplyr::filter() masks stats::filter()
## x recipes::fixed() masks stringr::fixed()
## x dplyr::lag() masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step() masks stats::step()
## * Use tidymodels_prefer() to resolve common conflicts.
# Reproducible train/test split (initial_split's default proportion),
# stratified on the outcome so both sets keep similar high/low shares.
set.seed(1234)
uni_split <- initial_split(university_df, strata = diversity)
uni_train <- training(uni_split)
uni_test <- testing(uni_split)
# Preprocessing recipe, estimated on the training set only:
uni_rec <- recipe(diversity ~ ., data = uni_train) %>%
# drop numeric predictors that are highly correlated with others
step_corr(all_numeric()) %>%
# dummy columns for nominal predictors (the outcome is excluded)
step_dummy(all_nominal(), -all_outcomes()) %>%
# remove any zero-variance columns
step_zv(all_numeric()) %>%
# center and scale all numeric columns
step_normalize(all_numeric())
# prep() trains the recipe's steps on the training data.
uni_prep <- uni_rec %>%
prep()
uni_prep
## Data Recipe
##
## Inputs:
##
## role #variables
## outcome 1
## predictor 8
##
## Training data contained 1620 data points and no missing data.
##
## Operations:
##
## Correlation filter removed in_state_tuition, ... [trained]
## Dummy variables from type, degree_length, region [trained]
## Zero variance filter removed no terms [trained]
## Centering and scaling for total_enrollment, ... [trained]
# juice() extracts the fully preprocessed training data from the trained recipe.
uni_juiced <- juice(uni_prep)
# Model 1: logistic regression via the "glm" engine.
glm_spec <- logistic_reg() %>%
set_engine("glm")
# Fit on the preprocessed (juiced) training data.
glm_fit <- glm_spec %>%
fit(diversity ~ ., data = uni_juiced)
glm_fit
## parsnip model object
##
## Fit time: 20ms
##
## Call: stats::glm(formula = diversity ~ ., family = stats::binomial,
## data = data)
##
## Coefficients:
## (Intercept) total_enrollment out_of_state_total
## 0.3704 -0.4581 0.5074
## type_Private type_Public degree_length_X4.Year
## -0.1656 0.2058 0.2082
## region_South region_North.Central region_West
## -0.5175 0.3004 -0.5363
##
## Degrees of Freedom: 1619 Total (i.e. Null); 1611 Residual
## Null Deviance: 2210
## Residual Deviance: 1859 AIC: 1877
# Model 2: k-nearest neighbors via the "kknn" engine; nearest_neighbor()
# needs the mode set explicitly.
knn_spec <- nearest_neighbor() %>%
set_engine("kknn") %>%
set_mode("classification")
knn_fit <- knn_spec %>%
fit(diversity ~ ., data = uni_juiced)
knn_fit
## parsnip model object
##
## Fit time: 120ms
##
## Call:
## kknn::train.kknn(formula = diversity ~ ., data = data, ks = min_rows(5, data, 5))
##
## Type of response variable: nominal
## Minimal misclassification: 0.3277778
## Best kernel: optimal
## Best k: 5
# Model 3: decision tree via the "rpart" engine.
tree_spec <- decision_tree() %>%
set_engine("rpart") %>%
set_mode("classification")
tree_fit <- tree_spec %>%
fit(diversity ~ ., data = uni_juiced)
tree_fit
## parsnip model object
##
## Fit time: 31ms
## n= 1620
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 1620 689 low (0.4253086 0.5746914)
## 2) region_North.Central< 0.5346496 1192 586 high (0.5083893 0.4916107)
## 4) out_of_state_total< -0.7087237 418 130 high (0.6889952 0.3110048) *
## 5) out_of_state_total>=-0.7087237 774 318 low (0.4108527 0.5891473)
## 10) out_of_state_total< 0.35164 362 180 low (0.4972376 0.5027624)
## 20) region_South>=0.3002561 212 86 high (0.5943396 0.4056604)
## 40) degree_length_X4.Year>=-0.2001293 172 62 high (0.6395349 0.3604651) *
## 41) degree_length_X4.Year< -0.2001293 40 16 low (0.4000000 0.6000000) *
## 21) region_South< 0.3002561 150 54 low (0.3600000 0.6400000)
## 42) region_West>=0.8128302 64 28 high (0.5625000 0.4375000) *
## 43) region_West< 0.8128302 86 18 low (0.2093023 0.7906977) *
## 11) out_of_state_total>=0.35164 412 138 low (0.3349515 0.6650485)
## 22) region_West>=0.8128302 88 38 high (0.5681818 0.4318182)
## 44) out_of_state_total>=1.547681 30 5 high (0.8333333 0.1666667) *
## 45) out_of_state_total< 1.547681 58 25 low (0.4310345 0.5689655) *
## 23) region_West< 0.8128302 324 88 low (0.2716049 0.7283951) *
## 3) region_North.Central>=0.5346496 428 83 low (0.1939252 0.8060748)
## 6) out_of_state_total< -1.19287 17 5 high (0.7058824 0.2941176) *
## 7) out_of_state_total>=-1.19287 411 71 low (0.1727494 0.8272506) *
评价模型 (Evaluating Models)
# Cross-validation folds on the training set (vfold_cv's default is 10
# folds), stratified on the outcome.
set.seed(123)
folds <- vfold_cv(uni_train, strata = diversity)
# Resample the logistic regression over the CV folds, recording ROC AUC,
# sensitivity, and specificity, and saving the out-of-fold predictions for
# the ROC plot later.
set.seed(234)
glm_rs <- fit_resamples(
  glm_spec,
  preprocessor = uni_rec,
  resamples = folds,
  metrics = metric_set(roc_auc, sens, spec),
  control = control_resamples(save_pred = TRUE)
)
# Resample the k-nearest-neighbors model with the same folds, metrics, and
# prediction-saving control as the logistic regression.
set.seed(234)
knn_rs <- fit_resamples(
  knn_spec,
  preprocessor = uni_rec,
  resamples = folds,
  metrics = metric_set(roc_auc, sens, spec),
  control = control_resamples(save_pred = TRUE)
)
# Resample the decision tree with the same folds, metrics, and
# prediction-saving control, then print the resampling results.
set.seed(234)
tree_rs <- fit_resamples(
  tree_spec,
  preprocessor = uni_rec,
  resamples = folds,
  metrics = metric_set(roc_auc, sens, spec),
  control = control_resamples(save_pred = TRUE)
)
tree_rs
## # Resampling results
## # 10-fold cross-validation using stratification
## # A tibble: 10 x 5
## splits id .metrics .notes .predictions
## <list> <chr> <list> <list> <list>
## 1 <split [1457/1~ Fold01 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [163 x~
## 2 <split [1458/1~ Fold02 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 3 <split [1458/1~ Fold03 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 4 <split [1458/1~ Fold04 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 5 <split [1458/1~ Fold05 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 6 <split [1458/1~ Fold06 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 7 <split [1458/1~ Fold07 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 8 <split [1458/1~ Fold08 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 9 <split [1458/1~ Fold09 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [162 x~
## 10 <split [1459/1~ Fold10 <tibble[,4] [3 x~ <tibble[,1] [0 ~ <tibble[,6] [161 x~
# Aggregated resampling metrics (mean and standard error across the folds)
# for each of the three models.
glm_rs %>%
collect_metrics()
## # A tibble: 3 x 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 roc_auc binary 0.758 10 0.00891 Preprocessor1_Model1
## 2 sens binary 0.617 10 0.0179 Preprocessor1_Model1
## 3 spec binary 0.737 10 0.00677 Preprocessor1_Model1
knn_rs %>%
collect_metrics()
## # A tibble: 3 x 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 roc_auc binary 0.728 10 0.00652 Preprocessor1_Model1
## 2 sens binary 0.595 10 0.00978 Preprocessor1_Model1
## 3 spec binary 0.733 10 0.0121 Preprocessor1_Model1
tree_rs %>%
collect_metrics()
## # A tibble: 3 x 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 roc_auc binary 0.723 10 0.00578 Preprocessor1_Model1
## 2 sens binary 0.642 10 0.0182 Preprocessor1_Model1
## 3 spec binary 0.745 10 0.00941 Preprocessor1_Model1
# Combine the saved out-of-fold predictions from all three models, tag each
# with a model label, and draw their ROC curves on one plot; the dashed
# diagonal is the no-skill reference line.
glm_rs %>%
unnest(.predictions) %>%
mutate(model = "glm") %>%
bind_rows(knn_rs %>%
unnest(.predictions) %>%
mutate(model = "knn")) %>%
bind_rows(tree_rs %>%
unnest(.predictions) %>%
mutate(model = "rpart")) %>%
group_by(model) %>%
# .pred_high = predicted probability of the "high" diversity class
roc_curve(diversity, .pred_high) %>%
ggplot(aes(x = 1 - specificity, y = sensitivity, color = model)) +
geom_line(size = 1.5) +
geom_abline(
lty = 2, alpha = 0.5,
color = "gray50",
size = 1.2
)
# Tidy the logistic-regression coefficients (estimated on the scaled
# predictors) and sort with the largest estimate first.
glm_fit %>%
  tidy() %>%
  # desc() is the idiomatic dplyr way to sort in decreasing order
  # (equivalent to arrange(-estimate) for a numeric column)
  arrange(desc(estimate))
## # A tibble: 9 x 5
## term estimate std.error statistic p.value
## <chr> <dbl> <dbl> <dbl> <dbl>
## 1 out_of_state_total 0.507 0.0934 5.43 5.58e- 8
## 2 (Intercept) 0.370 0.0572 6.47 9.72e-11
## 3 region_North.Central 0.300 0.0825 3.64 2.72e- 4
## 4 degree_length_X4.Year 0.208 0.0863 2.41 1.58e- 2
## 5 type_Public 0.206 0.193 1.07 2.86e- 1
## 6 type_Private -0.166 0.197 -0.838 4.02e- 1
## 7 total_enrollment -0.458 0.0737 -6.22 5.07e-10
## 8 region_South -0.517 0.0782 -6.62 3.64e-11
## 9 region_West -0.536 0.0724 -7.41 1.27e-13
# Final check on the held-out test set: apply the trained recipe with
# bake(), predict class probabilities, and compute the test-set ROC AUC.
glm_fit %>%
predict(
new_data = bake(uni_prep, uni_test),
type = "prob"
) %>%
mutate(truth = uni_test$diversity) %>%
roc_auc(truth, .pred_high)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 roc_auc binary 0.756
# Specificity on the held-out test set from hard class predictions.
# Note: spec() here resolves to yardstick::spec, which masks readr::spec
# (per the conflict messages printed when tidymodels was attached).
glm_fit %>%
predict(
new_data = bake(uni_prep, new_data = uni_test),
type = "class"
) %>%
mutate(truth = uni_test$diversity) %>%
spec(truth, .pred_class)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 spec binary 0.719