This post introduces the basic usage of the mlr3 package.

A simple machine learning workflow in mlr3 can be broken down into the following parts:

  • Create a task
    e.g. regression, classification, survival analysis, clustering, density estimation, and so on
  • Pick a learner (algorithm/model)
    e.g. random forest, decision tree, SVM, KNN, and so on
  • Train and predict

Creating a task

This example uses the mtcars dataset to create a regression task; the outcome variable (also called the dependent variable, target, etc.) is mpg.

  # load the data first
  data("mtcars", package = "datasets")
  data <- mtcars[, 1:3]
  str(data)
  ## 'data.frame': 32 obs. of 3 variables:
  ## $ mpg : num 21 21 22.8 21.4 18.7 18.1 14.3 24.4 22.8 19.2 ...
  ## $ cyl : num 6 6 4 6 8 6 8 4 4 6 ...
  ## $ disp: num 160 160 108 258 360 ...

Use as_task_regr() to create a regression task; as_task_classif() creates a classification task.

  library(mlr3)
  task_mtcars <- as_task_regr(data, target = "mpg", id = "cars") # the id is just an arbitrary name
  print(task_mtcars)
  ## <TaskRegr:cars> (32 x 3)
  ## * Target: mpg
  ## * Properties: -
  ## * Features (2):
  ## - dbl (2): cyl, disp

You can see that the data has 32 rows and 3 columns in total; the target is mpg, and the features are cyl and disp, both of type dbl.
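
For comparison, a classification task is built the same way with as_task_classif(); here is a minimal sketch using the built-in iris data (the id is again just an arbitrary name):

  # sketch: building a classification task from a data.frame
  data("iris", package = "datasets")
  task_iris <- as_task_classif(iris, target = "Species", id = "iris_demo")
  print(task_iris)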

Explore the data before building a model:

  1. library("mlr3viz") # 使用此包可视化数据
  2. autoplot(task_mtcars, type = "pairs") # 基于GGally,我之前介绍过
  3. ## Registered S3 method overwritten by 'GGally':
  4. ## method from
  5. ## +.gg ggplot2

[Figure 1: pairs plot of the mtcars task]

If loading one R package at a time feels tedious, you can simply run library(mlr3verse) to load all of the core packages at once!

If you want to practice with bundled data, the package also ships with many popular machine learning datasets.

List the built-in tasks:

  as.data.table(mlr_tasks)
  ## key task_type nrow ncol properties lgl int dbl chr fct ord pxc
  ## 1: boston_housing regr 506 19 0 3 13 0 2 0 0
  ## 2: breast_cancer classif 683 10 twoclass 0 0 0 0 0 9 0
  ## 3: german_credit classif 1000 21 twoclass 0 3 0 0 14 3 0
  ## 4: iris classif 150 5 multiclass 0 0 4 0 0 0 0
  ## 5: mtcars regr 32 11 0 0 10 0 0 0 0
  ## 6: penguins classif 344 8 multiclass 0 3 2 0 2 0 0
  ## 7: pima classif 768 9 twoclass 0 0 8 0 0 0 0
  ## 8: sonar classif 208 61 twoclass 0 0 60 0 0 0 0
  ## 9: spam classif 4601 58 twoclass 0 0 57 0 0 0 0
  ## 10: wine classif 178 14 multiclass 0 2 11 0 0 0 0
  ## 11: zoo classif 101 17 multiclass 15 1 0 0 0 0 0

The output is quite detailed: it gives the task type, the numbers of rows and columns, the feature types, and so on.

To use a built-in task, the following code works:

  task_penguin <- tsk("penguins")
  print(task_penguin)
  ## <TaskClassif:penguins> (344 x 8)
  ## * Target: species
  ## * Properties: multiclass
  ## * Features (7):
  ## - int (3): body_mass, flipper_length, year
  ## - dbl (2): bill_depth, bill_length
  ## - fct (2): island, sex

It is very convenient to take a subset for a quick look:

  1. library("mlr3verse")
  2. as.data.table(mlr_tasks)[, 1:4]
  3. ## key task_type nrow ncol
  4. ## 1: actg surv 1151 13
  5. ## 2: bike_sharing regr 17379 14
  6. ## 3: boston_housing regr 506 19
  7. ## 4: breast_cancer classif 683 10
  8. ## 5: faithful dens 272 1
  9. ## 6: gbcs surv 686 10
  10. ## 7: german_credit classif 1000 21
  11. ## 8: grace surv 1000 8
  12. ## 9: ilpd classif 583 11
  13. ## 10: iris classif 150 5
  14. ## 11: kc_housing regr 21613 20
  15. ## 12: lung surv 228 10
  16. ## 13: moneyball regr 1232 15
  17. ## 14: mtcars regr 32 11
  18. ## 15: optdigits classif 5620 65
  19. ## 16: penguins classif 344 8
  20. ## 17: pima classif 768 9
  21. ## 18: precip dens 70 1
  22. ## 19: rats surv 300 5
  23. ## 20: sonar classif 208 61
  24. ## 21: spam classif 4601 58
  25. ## 22: titanic classif 1309 11
  26. ## 23: unemployment surv 3343 6
  27. ## 24: usarrests clust 50 4
  28. ## 25: whas surv 481 11
  29. ## 26: wine classif 178 14
  30. ## 27: zoo classif 101 17
  31. ## key task_type nrow ncol

Tasks support a great many operations for exploring the data:

  task_penguin$ncol
  ## [1] 8
  task_penguin$nrow
  ## [1] 344
  task_penguin$feature_names
  ## [1] "bill_depth" "bill_length" "body_mass" "flipper_length"
  ## [5] "island" "sex" "year"
  task_penguin$feature_types
  ## id type
  ## 1: bill_depth numeric
  ## 2: bill_length numeric
  ## 3: body_mass integer
  ## 4: flipper_length integer
  ## 5: island factor
  ## 6: sex factor
  ## 7: year integer
  task_penguin$target_names
  ## [1] "species"
  task_penguin$task_type
  ## [1] "classif"
  task_penguin$data()
  ## species bill_depth bill_length body_mass flipper_length island sex
  ## 1: Adelie 18.7 39.1 3750 181 Torgersen male
  ## 2: Adelie 17.4 39.5 3800 186 Torgersen female
  ## 3: Adelie 18.0 40.3 3250 195 Torgersen female
  ## 4: Adelie NA NA NA NA Torgersen <NA>
  ## 5: Adelie 19.3 36.7 3450 193 Torgersen female
  ## ---
  ## 340: Chinstrap 19.8 55.8 4000 207 Dream male
  ## 341: Chinstrap 18.1 43.5 3400 202 Dream female
  ## 342: Chinstrap 18.2 49.6 3775 193 Dream male
  ## 343: Chinstrap 19.0 50.8 4100 210 Dream male
  ## 344: Chinstrap 18.7 50.2 3775 198 Dream female
  ## year
  ## 1: 2007
  ## 2: 2007
  ## 3: 2007
  ## 4: 2007
  ## 5: 2007
  ## ---
  ## 340: 2009
  ## 341: 2009
  ## 342: 2009
  ## 343: 2009
  ## 344: 2009
  task_penguin$head(3)
  ## species bill_depth bill_length body_mass flipper_length island sex
  ## 1: Adelie 18.7 39.1 3750 181 Torgersen male
  ## 2: Adelie 17.4 39.5 3800 186 Torgersen female
  ## 3: Adelie 18.0 40.3 3250 195 Torgersen female
  ## year
  ## 1: 2007
  ## 2: 2007
  ## 3: 2007
  # there are also many more operations: selecting rows/columns, changing a column's role (e.g. so that a variable is not used in model training), and so on; see the sketch below
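
As the last comment above hints, tasks also support mutating operations: selecting columns, filtering rows, and changing a column's role so that a variable no longer enters model training. A minimal sketch (working on a clone so task_penguin itself stays untouched; switching body_mass to the "order" role is just one example of a role change):

  task_copy <- task_penguin$clone()
  task_copy$select(c("bill_depth", "bill_length", "body_mass"))   # keep only these three features
  task_copy$filter(1:100)                                         # keep only the rows with ids 1 to 100
  task_copy$set_col_roles("body_mass", roles = "order")           # change body_mass from a feature to an ordering column
  task_copy$head(3)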

Visualizing the data: most of these plots are based on the GGally package, which I have introduced before.

  autoplot(task_penguin)

[Figure 2: default autoplot of the penguins task]

  autoplot(task_penguin, type = "pairs")
  ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
  ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
  ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
  ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
  ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
  ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
  ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
  ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
  ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
  ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
  ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
  ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
  ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
  ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
  ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

[Figure 3: pairs plot of the penguins task]

  1. autoplot(task_penguin, type = "duo")

[Figure 4: duo plot of the penguins task]

Creating a learner

All learners work through the following two steps:
[Figure 5: the two stages of a learner, training and prediction]
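
The two steps are training and prediction: $train() fits a model on a task and stores it inside the learner, and $predict() applies the stored model to data. A minimal sketch (both steps are walked through in detail later in this post):

  task_demo    <- tsk("penguins")
  learner_demo <- lrn("classif.rpart")
  learner_demo$train(task_demo)      # step 1: fit; the fitted model is stored in learner_demo$model
  learner_demo$predict(task_demo)    # step 2: apply the stored model to data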

mlr3verse only ships the common learners, such as random forest, decision tree, SVM, KNN, and so on. If you want access to all available learners, install mlr3extralearners.

See all the supported learners:
All learners
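
mlr3extralearners is typically installed from its GitHub repository rather than from CRAN; a sketch of one possible installation route (an assumption on my part, check the package README for the currently recommended way):

  # assumption: install from GitHub via the remotes package
  # remotes::install_github("mlr-org/mlr3extralearners")
  library(mlr3extralearners)
  mlr_learners   # the dictionary now lists many additional learners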

  # load the package; it provides the common algorithms
  library("mlr3verse")
  mlr_learners
  ## <DictionaryLearner> with 53 stored values
  ## Keys: classif.cv_glmnet, classif.debug, classif.featureless,
  ## classif.glmnet, classif.kknn, classif.lda, classif.log_reg,
  ## classif.multinom, classif.naive_bayes, classif.nnet, classif.qda,
  ## classif.ranger, classif.rpart, classif.svm, classif.xgboost,
  ## clust.agnes, clust.ap, clust.cmeans, clust.cobweb, clust.dbscan,
  ## clust.diana, clust.em, clust.fanny, clust.featureless, clust.ff,
  ## clust.hclust, clust.kkmeans, clust.kmeans, clust.MBatchKMeans,
  ## clust.meanshift, clust.pam, clust.SimpleKMeans, clust.xmeans,
  ## dens.hist, dens.kde, regr.cv_glmnet, regr.debug, regr.featureless,
  ## regr.glmnet, regr.kknn, regr.km, regr.lm, regr.ranger, regr.rpart,
  ## regr.svm, regr.xgboost, surv.coxph, surv.cv_glmnet, surv.glmnet,
  ## surv.kaplan, surv.ranger, surv.rpart, surv.xgboost

Create the learner:

  # decision tree
  learner = lrn("classif.rpart")
  print(learner)
  ## <LearnerClassifRpart:classif.rpart>
  ## * Model: -
  ## * Parameters: xval=0
  ## * Packages: mlr3, rpart
  ## * Predict Type: response
  ## * Feature types: logical, integer, numeric, factor, ordered
  ## * Properties: importance, missings, multiclass, selected_features,
  ## twoclass, weights
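
Note the Predict Type of response above: the learner will return hard class labels. If you also want class probabilities (which probability-based measures such as AUC need), you can request them when constructing the learner; a minimal sketch:

  # request probability predictions in addition to hard labels
  learner_prob = lrn("classif.rpart", predict_type = "prob")
  # equivalently, on an existing learner:
  # learner$predict_type = "prob"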

View the supported hyperparameters:

  learner$param_set
  ## <ParamSet>
  ## id class lower upper nlevels default value
  ## 1: cp ParamDbl 0 1 Inf 0.01
  ## 2: keep_model ParamLgl NA NA 2 FALSE
  ## 3: maxcompete ParamInt 0 Inf Inf 4
  ## 4: maxdepth ParamInt 1 30 30 30
  ## 5: maxsurrogate ParamInt 0 Inf Inf 5
  ## 6: minbucket ParamInt 1 Inf Inf <NoDefault[3]>
  ## 7: minsplit ParamInt 1 Inf Inf 20
  ## 8: surrogatestyle ParamInt 0 1 2 0
  ## 9: usesurrogate ParamInt 0 2 3 2
  ## 10: xval ParamInt 0 Inf Inf 10 0

Everything is laid out at a glance and easy to use; with so many hyperparameters, it is handy to be able to look them up here whenever you forget one. In this respect mlr3 is more considerate than tidymodels.

Set hyperparameter values:

  learner$param_set$values = list(cp = 0.01, xval = 0)
  learner
  ## <LearnerClassifRpart:classif.rpart>
  ## * Model: -
  ## * Parameters: cp=0.01, xval=0
  ## * Packages: mlr3, rpart
  ## * Predict Type: response
  ## * Feature types: logical, integer, numeric, factor, ordered
  ## * Properties: importance, missings, multiclass, selected_features,
  ## twoclass, weights

Hyperparameters can also be set when the learner is specified:

  1. learner = lrn("classif.rpart", xval=0, cp = 0.001)
  2. learner$param_set$values
  3. ## $xval
  4. ## [1] 0
  5. ##
  6. ## $cp
  7. ## [1] 0.001

Training, prediction, and performance evaluation

Create the task and pick a model:

  1. library("mlr3verse")
  2. task = tsk("penguins") # 使用内置数据集
  3. learner = lrn("classif.rpart") #决策树分类

Split into a training set and a test set:

  spilt <- partition(task, ratio = 0.6, stratify = T)
  spilt$train
  ## [1] 2 3 4 5 7 8 10 11 12 14 15 16 17 19 23 25 26 27
  ## [19] 28 30 31 33 34 36 37 40 42 45 46 48 50 51 53 56 59 60
  ## [37] 61 62 64 66 67 68 69 71 73 75 78 82 83 84 88 89 91 94
  ## [55] 96 97 99 100 101 102 104 107 108 113 114 115 118 120 121 123 125 126
  ## [73] 127 128 129 130 131 132 133 135 136 137 138 139 142 143 145 149 150 151
  ## [91] 152 154 156 157 159 160 163 169 170 171 172 173 175 176 179 180 181 182
  ## [109] 183 186 187 188 189 193 194 197 199 200 201 203 206 208 210 211 212 213
  ## [127] 214 215 216 218 219 220 222 223 224 225 226 228 229 230 233 236 237 239
  ## [145] 240 241 242 243 247 248 249 252 253 254 255 256 257 259 260 262 266 271
  ## [163] 272 273 274 277 279 280 285 288 290 291 293 294 295 296 297 299 300 301
  ## [181] 302 304 305 306 309 310 312 313 317 319 321 322 323 324 325 328 330 331
  ## [199] 332 334 337 338 339 340 341 342
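
partition() draws the split at random, so the indices above will differ from run to run; set a seed beforehand if you need a reproducible split. A minimal sketch:

  set.seed(123)                                    # fix the RNG so the split is reproducible
  spilt <- partition(task, ratio = 0.6, stratify = TRUE)
  length(spilt$train)                              # number of training rows
  length(spilt$test)                               # number of test rows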

Train the model:

  learner$train(task, row_ids = spilt$train)
  print(learner$model)
  ## n= 206
  ##
  ## node), split, n, loss, yval, (yprob)
  ## * denotes terminal node
  ##
  ## 1) root 206 115 Adelie (0.44174757 0.19902913 0.35922330)
  ## 2) flipper_length< 207.5 128 39 Adelie (0.69531250 0.30468750 0.00000000)
  ## 4) bill_length< 42.35 86 0 Adelie (1.00000000 0.00000000 0.00000000) *
  ## 5) bill_length>=42.35 42 3 Chinstrap (0.07142857 0.92857143 0.00000000) *
  ## 3) flipper_length>=207.5 78 4 Gentoo (0.02564103 0.02564103 0.94871795) *

Make predictions:

  prediction <- learner$predict(task, row_ids = spilt$test)
  print(prediction)
  ## <PredictionClassif> for 138 observations:
  ## row_ids truth response
  ## 1 Adelie Adelie
  ## 6 Adelie Adelie
  ## 9 Adelie Adelie
  ## ---
  ## 336 Chinstrap Chinstrap
  ## 343 Chinstrap Gentoo
  ## 344 Chinstrap Chinstrap
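
The prediction object can also be inspected directly: as.data.table() turns it into a table of row ids, truth, and response, and the individual fields are available as well. A minimal sketch:

  head(as.data.table(prediction))    # row_ids, truth and response as a data.table
  prediction$truth[1:5]              # true classes of the first five test rows
  prediction$response[1:5]           # predicted classes of the first five test rows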

Confusion matrix:

  prediction$confusion
  ## truth
  ## response Adelie Chinstrap Gentoo
  ## Adelie 53 1 0
  ## Chinstrap 8 24 2
  ## Gentoo 0 2 48

Visualization:

  autoplot(prediction)

[Figure 6: autoplot of the prediction]

Model evaluation

First, check which performance measures are available:

  mlr_measures
  ## <DictionaryMeasure> with 87 stored values
  ## Keys: aic, bic, classif.acc, classif.auc, classif.bacc, classif.bbrier,
  ## classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
  ## classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
  ## classif.logloss, classif.mbrier, classif.mcc, classif.npv,
  ## classif.ppv, classif.prauc, classif.precision, classif.recall,
  ## classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
  ## classif.tp, classif.tpr, clust.ch, clust.db, clust.dunn,
  ## clust.silhouette, clust.wss, debug, dens.logloss, oob_error,
  ## regr.bias, regr.ktau, regr.mae, regr.mape, regr.maxae, regr.medae,
  ## regr.medse, regr.mse, regr.msle, regr.pbias, regr.rae, regr.rmse,
  ## regr.rmsle, regr.rrse, regr.rse, regr.rsq, regr.sae, regr.smape,
  ## regr.srho, regr.sse, selected_features, sim.jaccard, sim.phi,
  ## surv.brier, surv.calib_alpha, surv.calib_beta, surv.chambless_auc,
  ## surv.cindex, surv.dcalib, surv.graf, surv.hung_auc, surv.intlogloss,
  ## surv.logloss, surv.mae, surv.mse, surv.nagelk_r2, surv.oquigley_r2,
  ## surv.rmse, surv.schmid, surv.song_auc, surv.song_tnr, surv.song_tpr,
  ## surv.uno_auc, surv.uno_tnr, surv.uno_tpr, surv.xu_r2, time_both,
  ## time_predict, time_train

Here we use accuracy:

  measure <- msr("classif.acc")
  prediction$score(measure)
  ## classif.acc
  ## 0.9057971

Selecting multiple measures at once (classif.auc comes back as NaN below because that measure only applies to binary tasks with probability predictions, while penguins is a three-class task and this prediction contains only response labels):

  measures <- msrs(c("classif.acc", "classif.auc", "classif.ce"))
  prediction$score(measures)
  ## classif.acc classif.auc classif.ce
  ## 0.9057971 NaN 0.0942029

For simple machine learning tasks, mlr3 is really convenient: a basic workflow takes just four lines of code!
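
As a recap, here is a minimal sketch of such a four-line workflow (training and scoring on the full built-in penguins task purely to illustrate the API; in practice you would evaluate on held-out data as shown above):

  task    = tsk("penguins")                              # 1. create a task
  learner = lrn("classif.rpart")                         # 2. pick a learner
  learner$train(task)                                    # 3. train
  learner$predict(task)$score(msr("classif.acc"))        # 4. predict and evaluate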
