数据探索
library(tidyverse)
## -- Attaching packages ----------------------------------------------------------- tidyverse 1.3.1 --
## v ggplot2 3.3.3 v purrr 0.3.4
## v tibble 3.1.1 v dplyr 1.0.6
## v tidyr 1.1.3 v stringr 1.4.0
## v readr 1.4.0 v forcats 0.5.1
## -- Conflicts -------------------------------------------------------------- tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
# Read the TidyTuesday `sf_trees` CSV (2020-01-28 release, per the path) from
# a sibling `datasets` checkout. Column types are left for readr to guess.
sf_trees <- "../datasets/tidytuesday/data/2020/2020-01-28/sf_trees.csv" %>%
  read_csv()
##
## -- Column specification ----------------------------------------------------------------------------
## cols(
## tree_id = col_double(),
## legal_status = col_character(),
## species = col_character(),
## address = col_character(),
## site_order = col_double(),
## site_info = col_character(),
## caretaker = col_character(),
## date = col_date(format = ""),
## dbh = col_double(),
## plot_size = col_character(),
## latitude = col_double(),
## longitude = col_double()
## )
# Build the modelling data set from the raw tree records:
#   * collapse legal_status to two levels, "DPW Maintained" vs. "Other"
#     (a missing legal_status also falls through to "Other");
#   * parse the free-text plot_size into a number;
#   * drop the address column;
#   * keep only rows without any missing value;
#   * turn every remaining character column into a factor.
trees_df <- sf_trees %>%
  mutate(
    legal_status = case_when(
      legal_status == "DPW Maintained" ~ legal_status,
      TRUE ~ "Other"
    ),
    plot_size = parse_number(plot_size)
  ) %>%
  select(-address) %>%
  drop_na() %>%
  # across() + where() replaces the superseded mutate_if().
  mutate(across(where(is.character), factor))
# Scatter plot of tree coordinates coloured by legal status. Points are small
# and semi-transparent; the legend title and the panel border are removed.
ggplot(trees_df, aes(x = longitude, y = latitude, colour = legal_status)) +
  geom_point(size = 0.5, alpha = 0.4) +
  labs(colour = NULL) +
  theme_bw() +
  theme(panel.border = element_blank())

# For each legal status, show how its trees are distributed over caretakers.
# Caretakers with 50 or fewer trees in total (summed over both statuses) are
# dropped before the shares are computed.
trees_df %>%
  count(legal_status, caretaker) %>%
  # Total trees per caretaker across legal statuses (weighted by the counts).
  add_count(caretaker, wt = n, name = "caretaker_count") %>%
  filter(caretaker_count > 50) %>%
  group_by(legal_status) %>%
  mutate(percent_legal = n / sum(n)) %>%
  # Do not leak the grouping into later steps.
  ungroup() %>%
  ggplot(aes(percent_legal, caretaker, fill = legal_status)) +
  geom_col(position = "dodge") +
  # percent_legal is a 0-1 fraction; label it as a percentage so the axis
  # agrees with its "% of trees" title.
  scale_x_continuous(labels = scales::percent_format()) +
  labs(
    fill = NULL,
    x = "% of trees in each category"
  )

建立模型
library(tidymodels)
## -- Attaching packages ---------------------------------------------------------- tidymodels 0.1.3 --
## v broom 0.7.6 v rsample 0.0.9
## v dials 0.0.9 v tune 0.1.5
## v infer 0.5.4 v workflows 0.2.2
## v modeldata 0.1.0 v workflowsets 0.0.2
## v parsnip 0.1.5 v yardstick 0.0.8
## v recipes 0.1.16
## -- Conflicts ------------------------------------------------------------- tidymodels_conflicts() --
## x scales::discard() masks purrr::discard()
## x dplyr::filter() masks stats::filter()
## x recipes::fixed() masks stringr::fixed()
## x dplyr::lag() masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step() masks stats::step()
## * Use tidymodels_prefer() to resolve common conflicts.
# Keep only the first 2000 rows (a positional subset, not a random sample).
# slice_head() returns fewer rows when fewer are available, instead of
# indexing past the end of the data as `[1:2000, ]` would.
trees_df <- slice_head(trees_df, n = 2000)

# Train/test split stratified on the outcome; the seed makes it reproducible.
set.seed(123)
trees_split <- initial_split(trees_df, strata = legal_status)
trees_train <- training(trees_split)
trees_test <- testing(trees_split)
# Preprocessing recipe: predict legal_status from all other columns.
tree_rec <- recipe(legal_status ~ ., data = trees_train) %>%
  # tree_id is kept in the data but given the "ID" role instead of "predictor".
  update_role(tree_id, new_role = "ID") %>%
  # Pool infrequent factor levels (below the given frequency threshold).
  step_other(species, caretaker, threshold = 0.01) %>%
  step_other(site_info, threshold = 0.005) %>%
  # Indicator columns for every nominal variable except the outcome.
  step_dummy(all_nominal(), -all_outcomes()) %>%
  # Derive a year feature from the date, then drop the raw date column.
  step_date(date, features = c("year")) %>%
  step_rm(date) %>%
  # Balance the outcome classes by downsampling.
  # NOTE(review): newer recipes releases provide step_downsample() via the
  # themis package instead -- confirm against the installed version.
  step_downsample(legal_status)

# Estimate the recipe on the training data and extract the processed set.
tree_prep <- prep(tree_rec)
# bake(new_data = NULL) returns the processed training data and replaces the
# superseded juice().
juiced <- bake(tree_prep, new_data = NULL)
# Random-forest classifier (ranger engine) with 1000 trees; mtry and min_n
# are left as tune() placeholders to be filled in by the tuning runs.
tune_spec <- set_engine(
  rand_forest(
    mode = "classification",
    mtry = tune(),
    trees = 1000,
    min_n = tune()
  ),
  "ranger"
)

# Workflow bundling the preprocessing recipe with the tunable model.
tune_wf <- add_model(add_recipe(workflow(), tree_rec), tune_spec)
调整参数
# Five cross-validation folds of the training data (seeded).
set.seed(234)
trees_folds <- vfold_cv(data = trees_train, v = 5)

# Register a parallel backend for the tuning runs.
doParallel::registerDoParallel()

# First pass: let tune_grid() generate 5 candidate parameter combinations.
set.seed(345)
tune_res <- tune_grid(object = tune_wf, resamples = trees_folds, grid = 5)
## i Creating pre-processing data to finalize unknown parameter: mtry
tune_res
## # Tuning results
## # 5-fold cross-validation
## # A tibble: 5 x 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [1200/301]> Fold1 <tibble[,6] [10 x 6]> <tibble[,1] [0 x 1]>
## 2 <split [1201/300]> Fold2 <tibble[,6] [10 x 6]> <tibble[,1] [0 x 1]>
## 3 <split [1201/300]> Fold3 <tibble[,6] [10 x 6]> <tibble[,1] [0 x 1]>
## 4 <split [1201/300]> Fold4 <tibble[,6] [10 x 6]> <tibble[,1] [0 x 1]>
## 5 <split [1201/300]> Fold5 <tibble[,6] [10 x 6]> <tibble[,1] [0 x 1]>
# Mean ROC AUC of the first tuning pass against each tuned parameter, one
# facet per parameter with its own x scale.
tune_res %>%
  collect_metrics() %>%
  filter(.metric == "roc_auc") %>%
  select(mean, min_n, mtry) %>%
  pivot_longer(
    cols = c(min_n, mtry),
    names_to = "parameter",
    values_to = "value"
  ) %>%
  ggplot(aes(x = value, y = mean, colour = parameter)) +
  geom_point(show.legend = FALSE) +
  facet_wrap(vars(parameter), scales = "free_x") +
  labs(x = NULL, y = "AUC")

# Regular grid for the second tuning pass: mtry in [20, 25] crossed with
# min_n in [2, 5], asking for up to 5 levels of each. As the printed grid
# shows, this yields 20 combinations (mtry 24 is not among the 5 mtry levels
# and min_n only has 4 distinct values).
rf_grid <- grid_regular(
  mtry(c(20, 25)),
  min_n(c(2, 5)),
  levels = 5
)
rf_grid
## # A tibble: 20 x 2
## mtry min_n
## <int> <int>
## 1 20 2
## 2 21 2
## 3 22 2
## 4 23 2
## 5 25 2
## 6 20 3
## 7 21 3
## 8 22 3
## 9 23 3
## 10 25 3
## 11 20 4
## 12 21 4
## 13 22 4
## 14 23 4
## 15 25 4
## 16 20 5
## 17 21 5
## 18 22 5
## 19 23 5
## 20 25 5
# Second pass: evaluate every combination in rf_grid on the same folds.
set.seed(456)
regular_res <- tune_grid(object = tune_wf, resamples = trees_folds, grid = rf_grid)
regular_res
## # Tuning results
## # 5-fold cross-validation
## # A tibble: 5 x 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [1200/301]> Fold1 <tibble[,6] [40 x 6]> <tibble[,1] [0 x 1]>
## 2 <split [1201/300]> Fold2 <tibble[,6] [40 x 6]> <tibble[,1] [0 x 1]>
## 3 <split [1201/300]> Fold3 <tibble[,6] [40 x 6]> <tibble[,1] [0 x 1]>
## 4 <split [1201/300]> Fold4 <tibble[,6] [40 x 6]> <tibble[,1] [0 x 1]>
## 5 <split [1201/300]> Fold5 <tibble[,6] [40 x 6]> <tibble[,1] [0 x 1]>
评价模型
# Mean ROC AUC over mtry for the grid search, one line per min_n value
# (min_n is made a factor so it maps to discrete colours).
regular_res %>%
  collect_metrics() %>%
  filter(.metric == "roc_auc") %>%
  mutate(min_n = factor(min_n)) %>%
  ggplot(aes(x = mtry, y = mean, colour = min_n)) +
  geom_line(size = 1.5, alpha = 0.5) +
  geom_point() +
  labs(y = "AUC")

# Parameter combination with the best mean ROC AUC in the grid search.
# `metric` is passed by name: it is a named argument of select_best(), and
# newer tune releases no longer accept it positionally.
best_auc <- select_best(regular_res, metric = "roc_auc")

# Substitute the chosen mtry / min_n for the tune() placeholders.
final_rf <- finalize_model(tune_spec, best_auc)
final_rf
## Random Forest Model Specification (classification)
##
## Main Arguments:
## mtry = 23
## trees = 1000
## min_n = 3
##
## Computational engine: ranger
library(vip)
##
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
##
## vi
# Variable importance for the finalized model. The engine is re-set so that
# ranger computes permutation importance, and the model is fitted on the
# already-preprocessed training data in `juiced` (rather than re-extracting
# it from tree_prep), without the tree_id identifier column.
final_rf %>%
  set_engine("ranger", importance = "permutation") %>%
  fit(legal_status ~ ., data = select(juiced, -tree_id)) %>%
  vip(geom = "point")

# Plug the selected parameters into the workflow, fit it once on the training
# portion of the split and evaluate on the test portion.
final_wf <- finalize_workflow(tune_wf, best_auc)
final_res <- last_fit(final_wf, split = trees_split)

# Test-set metrics of the final fit.
collect_metrics(final_res)
## # A tibble: 2 x 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.780 Preprocessor1_Model1
## 2 roc_auc binary 0.901 Preprocessor1_Model1
# Class predictions of the fitted workflow (recipe + model) stored in the
# last_fit() result, for the first ten training rows.
predict(final_res$.workflow[[1]], new_data = trees_train[1:10, ])
## # A tibble: 10 x 1
## .pred_class
## <fct>
## 1 DPW Maintained
## 2 DPW Maintained
## 3 DPW Maintained
## 4 DPW Maintained
## 5 DPW Maintained
## 6 DPW Maintained
## 7 DPW Maintained
## 8 Other
## 9 Other
## 10 DPW Maintained
# Same check with last_fit() on the bare model spec and a formula instead of
# the recipe. The formula uses every other column of the split's data as a
# predictor, so none of the recipe steps (including the tree_id ID role) apply.
final_res1 <- last_fit(
  final_rf,
  preprocessor = legal_status ~ .,
  split = trees_split
)
predict(final_res1$.workflow[[1]], new_data = trees_train[1:10, ])
## # A tibble: 10 x 1
## .pred_class
## <fct>
## 1 DPW Maintained
## 2 DPW Maintained
## 3 DPW Maintained
## 4 DPW Maintained
## 5 DPW Maintained
## 6 DPW Maintained
## 7 DPW Maintained
## 8 Other
## 9 Other
## 10 DPW Maintained
# Same check with a plain parsnip fit of the model spec on the raw training
# data (formula interface, no recipe).
final_res2 <- fit(final_rf, formula = legal_status ~ ., data = trees_train)
predict(final_res2, new_data = trees_train[1:10, ])
## # A tibble: 10 x 1
## .pred_class
## <fct>
## 1 DPW Maintained
## 2 DPW Maintained
## 3 DPW Maintained
## 4 DPW Maintained
## 5 DPW Maintained
## 6 DPW Maintained
## 7 DPW Maintained
## 8 Other
## 9 Other
## 10 DPW Maintained
# Map of the test-set trees, coloured by whether the final model classified
# them correctly.
final_res %>%
  collect_predictions() %>%
  mutate(correct = case_when(
    legal_status == .pred_class ~ "Correct",
    TRUE ~ "Incorrect"
  )) %>%
  # Attach only the coordinates of the test rows (matched by position).
  # Binding all of trees_test would duplicate legal_status and force
  # bind_cols() to repair the column names.
  bind_cols(select(trees_test, longitude, latitude)) %>%
  ggplot(aes(longitude, latitude, color = correct)) +
  geom_point(size = 0.5, alpha = 0.5) +
  labs(color = NULL) +
  # Named values tie each colour to its label explicitly instead of relying
  # on the alphabetical order of the levels.
  scale_color_manual(values = c(Correct = "gray80", Incorrect = "darkred"))
