
Model Tuning

When you are not satisfied with your model's performance, you may want to improve it, either by tuning its hyperparameters or by trying a model that suits your data better. This chapter walks through both.

The chapter has three main parts:

Hyperparameter tuning

Machine learning models come with default hyperparameters, but these are not adapted to your data automatically and often leave performance on the table. Manual tuning rarely finds the best values either. mlr3 ships with automated tuning strategies; to use them you must specify a search space (search_space), an optimization algorithm (the tuning method), an evaluation method (a resampling strategy), and a performance measure (a compact code preview follows this outline).

Feature selection

Carried out mainly with the mlr3filters and mlr3fselect packages.

Nested resampling
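As a quick preview (a minimal sketch; every object is explained in detail below), the four ingredients of automated tuning map onto mlr3's sugar functions like this:

library(mlr3verse)

search_space <- ps(cp = p_dbl(0.001, 0.1)) # search space
tuner <- tnr("grid_search")                # optimization algorithm (tuning method)
resampling <- rsmp("cv", folds = 3)        # evaluation method (resampling strategy)
measure <- msr("classif.ce")               # performance measure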

Tuning Hyperparameters

Many people joke that tuning is like "alchemy", and the joke is apt: quite often the tuned result is no better than the defaults! It is like the gaming meme, "a flurry of fierce moves, and the scoreboard reads 0:5".

Model tuning must be grounded in an understanding of the algorithm and the data; it is not something to do at random.

We demonstrate on the well-known Pima Indians diabetes dataset. First, create the task:

library(mlr3verse)
## Loading required package: mlr3
task <- tsk("pima")
print(task)
## <TaskClassif:pima> (768 x 9)
## * Target: diabetes
## * Properties: twoclass
## * Features (8):
##   - dbl (8): age, glucose, insulin, mass, pedigree, pregnant, pressure,
##     triceps

Choose an algorithm and inspect the hyperparameters it supports:

learner <- lrn("classif.rpart")
learner$param_set
## <ParamSet>
##                 id    class lower upper nlevels        default value
##  1:             cp ParamDbl     0     1     Inf           0.01      
##  2:     keep_model ParamLgl    NA    NA       2          FALSE      
##  3:     maxcompete ParamInt     0   Inf     Inf              4      
##  4:       maxdepth ParamInt     1    30      30             30      
##  5:   maxsurrogate ParamInt     0   Inf     Inf              5      
##  6:      minbucket ParamInt     1   Inf     Inf <NoDefault[3]>      
##  7:       minsplit ParamInt     1   Inf     Inf             20      
##  8: surrogatestyle ParamInt     0     1       2              0      
##  9:   usesurrogate ParamInt     0     2       3              2      
## 10:           xval ParamInt     0   Inf     Inf             10     0

Here we choose to tune the complexity parameter cp and the minimum split count minsplit, and define the ranges to search over:

search_space <- ps(
  cp = p_dbl(lower = 0.001, upper = 0.1),
  minsplit = p_int(lower = 1, upper = 10)
)
search_space
## <ParamSet>
##          id    class lower upper nlevels        default value
## 1:       cp ParamDbl 0.001   0.1     Inf <NoDefault[3]>      
## 2: minsplit ParamInt 1.000  10.0      10 <NoDefault[3]>      

Then choose a resampling method and a performance measure:

hout <- rsmp("holdout", ratio = 0.7)
measure <- msr("classif.ce")

From here there are two ways to carry out the tuning.

Method 1: train via TuningInstanceSingleCrit plus a Tuner

library(mlr3tuning)
## Loading required package: paradox
evals20 <- trm("evals", n_evals = 20) # when to stop tuning
# bundle everything into one instance
instance <- TuningInstanceSingleCrit$new(
  task = task,
  learner = learner,
  resampling = hout,
  measure = measure,
  terminator = evals20,
  search_space = search_space
)
instance
## <TuningInstanceSingleCrit>
## * State:  Not optimized
## * Objective: <ObjectiveTuning:classif.rpart_on_pima>
## * Search Space:
## <ParamSet>
##          id    class lower upper nlevels        default value
## 1:       cp ParamDbl 0.001   0.1     Inf <NoDefault[3]>      
## 2: minsplit ParamInt 1.000  10.0      10 <NoDefault[3]>      
## * Terminator: <TerminatorEvals>
## * Terminated: FALSE
## * Archive:
## <ArchiveTuning>
## Null data.table (0 rows and 0 cols)

As for when to stop, mlr3 provides 5 kinds of terminators (constructor sketches follow the list):

  • Terminate after a given time: stop after a fixed amount of wall-clock time
  • Terminate after a given number of iterations: stop after a fixed number of evaluations
  • Terminate after a specific performance has been reached: stop once a target performance is achieved
  • Terminate when tuning does not find a better configuration for a given number of iterations: stop on stagnation
  • A combination of the above, in ALL or ANY fashion
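A minimal sketch of how such terminators are constructed (the keys come from the mlr_terminators dictionary; the specific argument values are illustrative):

trm("run_time", secs = 300)                    # stop after a given time
trm("evals", n_evals = 20)                     # stop after a given number of evaluations
trm("perf_reached", level = 0.25)              # stop once the measure reaches 0.25
trm("stagnation", iters = 5, threshold = 1e-3) # stop when no improvement is found
# combine terminators; any = TRUE stops as soon as either one fires
t_combo <- TerminatorCombo$new(list(trm("evals", n_evals = 50),
                                    trm("run_time", secs = 300)))
t_combo$param_set$values$any <- TRUE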

Next, choose a search strategy.

mlr3tuning currently supports the following search strategies:

  • Grid search
  • Random search
  • Generalized simulated annealing
  • Non-linear optimization

# here we choose grid search
tuner <- tnr("grid_search", resolution = 5)

Now we can run the tuning. We set the grid resolution to 5 and we are tuning 2 hyperparameters, so in principle there are 5 * 5 = 25 combinations to try; but since our stopping rule says n_evals = 20, tuning actually stops after 20 of them have been evaluated.
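If you want to see exactly which combinations grid search will walk through, you can materialize the grid with paradox's generate_design_grid() (a minimal sketch):

design <- generate_design_grid(search_space, resolution = 5)
nrow(design$data) # 25 rows: 5 values of cp x 5 values of minsplit
head(design$data)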

# lgr::get_logger("mlr3")$set_threshold("warn")
# lgr::get_logger("bbotk")$set_threshold("warn") # reduce console output
tuner$optimize(instance)
## INFO  [20:51:28.312] [bbotk] Starting to optimize 2 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=20, k=0]'
## INFO  [20:51:28.331] [bbotk] Evaluating 1 configuration(s)
## INFO  [20:51:29.309] [bbotk] Finished optimizing after 20 evaluation(s)
## INFO  [20:51:29.310] [bbotk] Result:
## INFO  [20:51:29.310] [bbotk]       cp minsplit learner_param_vals  x_domain classif.ce
## INFO  [20:51:29.310] [bbotk]  0.02575        3          <list[3]> <list[2]>  0.2130435
##         cp minsplit learner_param_vals  x_domain classif.ce
## 1: 0.02575        3          <list[3]> <list[2]>  0.2130435

Inspect the tuned hyperparameters:

instance$result_learner_param_vals
## $xval
## [1] 0
## 
## $cp
## [1] 0.02575
## 
## $minsplit
## [1] 3

Inspect the model performance:

instance$result_y
## classif.ce 
##  0.2130435

Inspect every evaluated configuration in the archive; there are exactly 20:

instance$archive
## <ArchiveTuning>
##        cp minsplit classif.ce runtime_learners           timestamp batch_nr
##  1: 0.026        3       0.21             0.02 2022-02-27 20:51:28        1
##  2: 0.075        8       0.21             0.00 2022-02-27 20:51:28        2
##  3: 0.050        5       0.21             0.00 2022-02-27 20:51:28        3
##  4: 0.001        1       0.30             0.00 2022-02-27 20:51:28        4
##  5: 0.100        3       0.21             0.02 2022-02-27 20:51:28        5
##  6: 0.026        5       0.21             0.02 2022-02-27 20:51:28        6
##  7: 0.100        8       0.21             0.01 2022-02-27 20:51:28        7
##  8: 0.001        8       0.27             0.00 2022-02-27 20:51:28        8
##  9: 0.001        5       0.28             0.00 2022-02-27 20:51:28        9
## 10: 0.100        5       0.21             0.02 2022-02-27 20:51:28       10
## 11: 0.075       10       0.21             0.00 2022-02-27 20:51:28       11
## 12: 0.050       10       0.21             0.01 2022-02-27 20:51:28       12
## 13: 0.075        5       0.21             0.00 2022-02-27 20:51:28       13
## 14: 0.050        8       0.21             0.01 2022-02-27 20:51:29       14
## 15: 0.001       10       0.26             0.00 2022-02-27 20:51:29       15
## 16: 0.050        3       0.21             0.00 2022-02-27 20:51:29       16
## 17: 0.050        1       0.21             0.02 2022-02-27 20:51:29       17
## 18: 0.100       10       0.21             0.00 2022-02-27 20:51:29       18
## 19: 0.075        1       0.21             0.01 2022-02-27 20:51:29       19
## 20: 0.026        1       0.21             0.00 2022-02-27 20:51:29       20
##     warnings errors      resample_result
##  1:        0      0 <ResampleResult[22]>
##  2:        0      0 <ResampleResult[22]>
##  3:        0      0 <ResampleResult[22]>
##  4:        0      0 <ResampleResult[22]>
##  5:        0      0 <ResampleResult[22]>
##  6:        0      0 <ResampleResult[22]>
##  7:        0      0 <ResampleResult[22]>
##  8:        0      0 <ResampleResult[22]>
##  9:        0      0 <ResampleResult[22]>
## 10:        0      0 <ResampleResult[22]>
## 11:        0      0 <ResampleResult[22]>
## 12:        0      0 <ResampleResult[22]>
## 13:        0      0 <ResampleResult[22]>
## 14:        0      0 <ResampleResult[22]>
## 15:        0      0 <ResampleResult[22]>
## 16:        0      0 <ResampleResult[22]>
## 17:        0      0 <ResampleResult[22]>
## 18:        0      0 <ResampleResult[22]>
## 19:        0      0 <ResampleResult[22]>
## 20:        0      0 <ResampleResult[22]>

Next, assign the tuned hyperparameters to the learner and retrain it on the data:

learner$param_set$values <- instance$result_learner_param_vals
learner$train(task)

This trained model is now ready to make predictions via learner$predict(), as sketched below.
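A minimal sketch, predicting back on the training task purely for illustration:

pred <- learner$predict(task)
pred$confusion                 # confusion matrix of truth vs. response
pred$score(msr("classif.acc")) # score the prediction with any measure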

Written out like this, the steps are somewhat verbose — less concise and readable than tidymodels — and when I was first learning I could never remember them. A later release finally added a shorthand:

instance <- tune(
  task = task,
  learner = learner,
  resampling = hout,
  measure = measure,
  search_space = search_space,
  method = "grid_search",
  resolution = 5,
  term_evals = 25
)
## INFO  [20:51:29.402] [bbotk] Starting to optimize 2 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=25, k=0]'
## INFO  [20:51:29.403] [bbotk] Evaluating 1 configuration(s)
## INFO  [20:51:30.534] [bbotk] Finished optimizing after 25 evaluation(s)
## INFO  [20:51:30.534] [bbotk] Result:
## INFO  [20:51:30.535] [bbotk]       cp minsplit learner_param_vals  x_domain classif.ce
## INFO  [20:51:30.535] [bbotk]  0.02575       10          <list[3]> <list[2]>  0.2347826
instance$result_learner_param_vals
## $xval
## [1] 0
## 
## $cp
## [1] 0.02575
## 
## $minsplit
## [1] 10
instance$result_y
## classif.ce 
##  0.2347826
learner$param_set$values <- instance$result_learner_param_vals
learner$train(task)

mlr3 also supports specifying several performance measures at once:

measures <- msrs(c("classif.ce", "time_train")) # multiple measures
evals20 <- trm("evals", n_evals = 20)
instance <- TuningInstanceMultiCrit$new(
  task = task,
  learner = learner,
  resampling = hout,
  measures = measures,
  search_space = search_space,
  terminator = evals20
)
tuner$optimize(instance)
## INFO  [20:51:30.595] [bbotk] Starting to optimize 2 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=20, k=0]'
## INFO  [20:51:30.597] [bbotk] Evaluating 1 configuration(s)
## INFO  [20:51:30.605] [mlr3] Running benchmark with 1 resampling iterations
## INFO  [20:51:30.608] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO  [20:51:30.620] [mlr3] Finished benchmark
## INFO  [20:51:30.642] [bbotk] Result of batch 1:
## INFO  [20:51:30.643] [bbotk]      cp minsplit classif.ce time_train warnings errors runtime_learners
## INFO  [20:51:30.643] [bbotk]  0.0505        1  0.2347826          0        0      0             0.02
##          cp minsplit learner_param_vals  x_domain classif.ce time_train
##  1: 0.05050        1          <list[3]> <list[2]>  0.2347826          0
##  2: 0.07525        1          <list[3]> <list[2]>  0.2347826          0
##  3: 0.07525       10          <list[3]> <list[2]>  0.2347826          0
##  4: 0.10000        8          <list[3]> <list[2]>  0.2347826          0
##  5: 0.02575        3          <list[3]> <list[2]>  0.2347826          0
##  6: 0.07525        8          <list[3]> <list[2]>  0.2347826          0
##  7: 0.10000        3          <list[3]> <list[2]>  0.2347826          0
##  8: 0.10000        5          <list[3]> <list[2]>  0.2347826          0
##  9: 0.02575        5          <list[3]> <list[2]>  0.2347826          0
## 10: 0.07525        5          <list[3]> <list[2]>  0.2347826          0
## 11: 0.05050        8          <list[3]> <list[2]>  0.2347826          0
## 12: 0.05050        3          <list[3]> <list[2]>  0.2347826          0
## 13: 0.07525        3          <list[3]> <list[2]>  0.2347826          0
## 14: 0.05050        5          <list[3]> <list[2]>  0.2347826          0
## 15: 0.02575        1          <list[3]> <list[2]>  0.2347826          0

Inspect the results:

instance$result_learner_param_vals
## [[1]]
## [[1]]$xval
## [1] 0
## 
## [[1]]$cp
## [1] 0.0505
## 
## [[1]]$minsplit
## [1] 1
## 
## 
## [[2]]
## [[2]]$xval
## [1] 0
## 
## [[2]]$cp
## [1] 0.07525
## 
## [[2]]$minsplit
## [1] 1
## 
## 
## [[3]]
## [[3]]$xval
## [1] 0
## 
## [[3]]$cp
## [1] 0.07525
## 
## [[3]]$minsplit
## [1] 10
## 
## 
## [[4]]
## [[4]]$xval
## [1] 0
## 
## [[4]]$cp
## [1] 0.1
## 
## [[4]]$minsplit
## [1] 8
## 
## 
## [[5]]
## [[5]]$xval
## [1] 0
## 
## [[5]]$cp
## [1] 0.02575
## 
## [[5]]$minsplit
## [1] 3
## 
## 
## [[6]]
## [[6]]$xval
## [1] 0
## 
## [[6]]$cp
## [1] 0.07525
## 
## [[6]]$minsplit
## [1] 8
## 
## 
## [[7]]
## [[7]]$xval
## [1] 0
## 
## [[7]]$cp
## [1] 0.1
## 
## [[7]]$minsplit
## [1] 3
## 
## 
## [[8]]
## [[8]]$xval
## [1] 0
## 
## [[8]]$cp
## [1] 0.1
## 
## [[8]]$minsplit
## [1] 5
## 
## 
## [[9]]
## [[9]]$xval
## [1] 0
## 
## [[9]]$cp
## [1] 0.02575
## 
## [[9]]$minsplit
## [1] 5
## 
## 
## [[10]]
## [[10]]$xval
## [1] 0
## 
## [[10]]$cp
## [1] 0.07525
## 
## [[10]]$minsplit
## [1] 5
## 
## 
## [[11]]
## [[11]]$xval
## [1] 0
## 
## [[11]]$cp
## [1] 0.0505
## 
## [[11]]$minsplit
## [1] 8
## 
## 
## [[12]]
## [[12]]$xval
## [1] 0
## 
## [[12]]$cp
## [1] 0.0505
## 
## [[12]]$minsplit
## [1] 3
## 
## 
## [[13]]
## [[13]]$xval
## [1] 0
## 
## [[13]]$cp
## [1] 0.07525
## 
## [[13]]$minsplit
## [1] 3
## 
## 
## [[14]]
## [[14]]$xval
## [1] 0
## 
## [[14]]$cp
## [1] 0.0505
## 
## [[14]]$minsplit
## [1] 5
## 
## 
## [[15]]
## [[15]]$xval
## [1] 0
## 
## [[15]]$cp
## [1] 0.02575
## 
## [[15]]$minsplit
## [1] 1
instance$result_y # performance values of the non-dominated configurations

That covers the first method; on to the second.

Method 2: train via AutoTuner

This approach bundles the tuning step and the final fit with the tuned parameters into a single object, though all the ingredients still have to be set up in advance.

task <- tsk("pima")              # create the task
learner <- lrn("classif.rpart")  # choose a learner
search_space <- ps(
  cp = p_dbl(0.001, 0.1),
  minsplit = p_int(1, 10)
)                                # define the search space
terminator <- trm("evals", n_evals = 10) # stopping rule
tuner <- tnr("random_search")    # search strategy
resampling <- rsmp("holdout")    # resampling method
measure <- msr("classif.acc")    # performance measure
# build the AutoTuner
at <- AutoTuner$new(
  learner = learner,
  resampling = resampling,
  search_space = search_space,
  measure = measure,
  tuner = tuner,
  terminator = terminator
)

Automatically select the best hyperparameters and fit to the data:

at$train(task)
## INFO  [20:51:31.873] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=10, k=0]'
## INFO  [20:51:32.332] [bbotk]  0.02278977        3          <list[3]> <list[2]>   0.7695312
at$predict(task)
## <PredictionClassif> for 768 observations:
##     row_ids truth response
##           1   pos      pos
##           2   neg      neg
##           3   pos      neg
## ---
##         766   neg      neg
##         767   pos      neg
##         768   neg      neg

This method also has a shorthand:

auto_learner <- auto_tuner(
  learner = learner,
  resampling = resampling,
  measure = measure,
  search_space = search_space,
  method = "random_search",
  term_evals = 10
)
auto_learner$train(task)
## INFO  [20:51:32.407] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=10, k=0]'
## INFO  [20:51:32.858] [bbotk] Finished optimizing after 10 evaluation(s)
## INFO  [20:51:32.859] [bbotk] Result:
## INFO  [20:51:32.859] [bbotk]          cp minsplit learner_param_vals  x_domain classif.acc
## INFO  [20:51:32.859] [bbotk]  0.02922122        8          <list[3]> <list[2]>   0.7539062
auto_learner$predict(task)
## <PredictionClassif> for 768 observations:
##     row_ids truth response
##           1   pos      pos
##           2   neg      neg
##           3   pos      neg
## ---
##         766   neg      neg
##         767   pos      neg
##         768   neg      neg

Another Way to Set Hyperparameters

Spelling out the ranges separately every time can feel clumsy, so mlr3 also lets you declare the tunable hyperparameters directly when creating the learner.

# declare the tuning range when creating the learner
learner <- lrn("classif.svm")
learner$param_set$values$kernel <- "polynomial"
learner$param_set$values$degree <- to_tune(lower = 1, upper = 3)
print(learner$param_set$search_space())
## <ParamSet>
##        id    class lower upper nlevels        default value
## 1: degree ParamInt     1     3       3 <NoDefault[3]>      

The catch is that this style assumes you know the algorithm well enough to remember every hyperparameter and its exact spelling in mlr3, which is clearly asking a lot. So I still recommend the first style: define the ranges each time, and look the hyperparameters up whenever you forget them.
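For completeness, a minimal sketch of the to_tune() style with several parameters at once, using the same rpart parameters tuned earlier:

learner <- lrn("classif.rpart",
  cp = to_tune(0.001, 0.1), # numeric range
  minsplit = to_tune(1, 10) # integer range
)
learner$param_set$search_space() # the search space implied by the tune tokens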

Parameter Dependencies

Some hyperparameters only take effect under certain conditions. For a support vector machine (SVM), degree matters only when kernel is "polynomial". Such dependencies can be declared in mlr3:

library(data.table)
search_space = ps(
  cost = p_dbl(-1, 1, trafo = function(x) 10^x),         # values can be transformed
  kernel = p_fct(c("polynomial", "radial")),
  degree = p_int(1, 3, depends = kernel == "polynomial") # declare the dependency
)
rbindlist(generate_design_grid(search_space, 3)$transpose(), fill = TRUE)
##     cost     kernel degree
##  1:  0.1 polynomial      1
##  2:  0.1 polynomial      2
##  3:  0.1 polynomial      3
##  4:  0.1     radial     NA
##  5:  1.0 polynomial      1
##  6:  1.0 polynomial      2
##  7:  1.0 polynomial      3
##  8:  1.0     radial     NA
##  9: 10.0 polynomial      1
## 10: 10.0 polynomial      2
## 11: 10.0 polynomial      3
## 12: 10.0     radial     NA

With the dependency declared, downstream steps handle it automatically and will not error.

Nested Resampling

Nested resampling wraps an inner resampling loop (for tuning) inside an outer resampling loop (for evaluation). This guards against the overfitting that tuning can introduce and yields a more stable, trustworthy model assessment.

If the concept is unfamiliar, look it up on your own; we will not belabor it here.

The figure below may help: [Figure 1: schematic of nested resampling in mlr3]

Performing Nested Resampling

The inner loop uses 4-fold cross-validation:

rm(list = ls())
library(mlr3verse)
library(mlr3tuning)
learner <- lrn("classif.rpart")
resampling <- rsmp("cv", folds = 4)
measure <- msr("classif.ce")
search_space <- ps(cp = p_dbl(lower = 0.001, upper = 0.1))
terminator <- trm("evals", n_evals = 5)
tuner <- tnr("grid_search", resolution = 10)
at <- AutoTuner$new(learner, resampling, measure, terminator, tuner, search_space)

The outer loop uses 3-fold cross-validation:

  1. task <- tsk("pima")
  2. outer_resampling <- rsmp("cv", folds = 3)
  3. rr <- resample(task, at, outer_resampling, store_models = T)
  4. ## INFO [20:51:33.072] [mlr3] Applying learner 'classif.rpart.tuned' on task 'pima' (iter 3/3)
  5. ## INFO [20:51:34.391] [mlr3] Finished benchmark
  6. ## INFO [20:51:34.411] [bbotk] Result of batch 5:
  7. ## INFO [20:51:34.412] [bbotk] cp classif.ce warnings errors runtime_learners
  8. ## INFO [20:51:34.412] [bbotk] 0.012 0.2382812 0 0 0.02
  9. ## INFO [20:51:34.412] [bbotk] uhash
  10. ## INFO [20:51:34.412] [bbotk] 375b973a-a946-4c77-94f7-451b29f07cb6
  11. ## INFO [20:51:34.415] [bbotk] Finished optimizing after 5 evaluation(s)
  12. ## INFO [20:51:34.415] [bbotk] Result:
  13. ## INFO [20:51:34.416] [bbotk] cp learner_param_vals x_domain classif.ce
  14. ## INFO [20:51:34.416] [bbotk] 0.023 <list[2]> <list[1]> 0.2382812

The dataset demonstrated here is small; for big data you can parallelize, which a later chapter covers — a minimal sketch follows.
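mlr3 parallelizes through the future framework, so a minimal sketch looks like this (the worker count of 4 is just an example):

library(future)
plan("multisession", workers = 4) # subsequent resample()/tuning calls run in parallel
# ... run resample(), benchmark(), or tuning as usual ...
plan("sequential")                # restore sequential execution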

Evaluating the Model

Extract the inner resampling performance:

extract_inner_tuning_results(rr)
##    iteration    cp classif.ce learner_param_vals  x_domain task_id
## 1:         1 0.078  0.2812500          <list[2]> <list[1]>    pima
## 2:         2 0.023  0.2382812          <list[2]> <list[1]>    pima
## 3:         3 0.023  0.2480469          <list[2]> <list[1]>    pima
##             learner_id resampling_id
## 1: classif.rpart.tuned            cv
## 2: classif.rpart.tuned            cv
## 3: classif.rpart.tuned            cv

Extract the inner tuning archives:

extract_inner_tuning_archives(rr)
##     iteration    cp classif.ce x_domain_cp runtime_learners           timestamp
##  1:         1 0.078  0.2812500       0.078             0.03 2022-02-27 20:51:33
##  2:         1 0.067  0.2871094       0.067             0.03 2022-02-27 20:51:33
##  3:         1 0.100  0.2812500       0.100             0.02 2022-02-27 20:51:33
##  4:         1 0.089  0.2812500       0.089             0.03 2022-02-27 20:51:33
##  5:         1 0.023  0.2949219       0.023             0.04 2022-02-27 20:51:33
##  6:         2 0.023  0.2382812       0.023             0.02 2022-02-27 20:51:34
##  7:         2 0.089  0.2617188       0.089             0.02 2022-02-27 20:51:34
##  8:         2 0.078  0.2617188       0.078             0.03 2022-02-27 20:51:34
##  9:         2 0.034  0.2421875       0.034             0.01 2022-02-27 20:51:34
## 10:         2 0.012  0.2382812       0.012             0.02 2022-02-27 20:51:34
## 11:         3 0.012  0.2519531       0.012             0.04 2022-02-27 20:51:33
## 12:         3 0.089  0.2636719       0.089             0.03 2022-02-27 20:51:33
## 13:         3 0.067  0.2519531       0.067             0.02 2022-02-27 20:51:33
## 14:         3 0.023  0.2480469       0.023             0.04 2022-02-27 20:51:33
## 15:         3 0.078  0.2636719       0.078             0.04 2022-02-27 20:51:33
##     batch_nr warnings errors      resample_result task_id          learner_id
##  1:        1        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  2:        2        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  3:        3        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  4:        4        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  5:        5        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  6:        1        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  7:        2        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  8:        3        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  9:        4        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 10:        5        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 11:        1        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 12:        2        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 13:        3        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 14:        4        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 15:        5        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##     resampling_id
##  1:            cv
##  2:            cv
##  3:            cv
##  4:            cv
##  5:            cv
##  6:            cv
##  7:            cv
##  8:            cv
##  9:            cv
## 10:            cv
## 11:            cv
## 12:            cv
## 13:            cv
## 14:            cv
## 15:            cv

Note that these differ from the earlier results: each outer fold ran its own 5 tuning iterations, exactly as our settings dictate.

Inspect the outer resampling performance:

rr$score()[, 9]
##    classif.ce
## 1:  0.2460938
## 2:  0.2656250
## 3:  0.2890625

And the aggregated performance:

rr$aggregate()
## classif.ce 
##  0.2669271

Applying the Hyperparameters to the Model

at$train(task)
## INFO  [20:51:34.578] [bbotk] Starting to optimize 1 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=5, k=0]' 
## INFO  [20:51:34.970] [mlr3] Finished benchmark 
## INFO  [20:51:34.990] [bbotk] Result of batch 5: 
## INFO  [20:51:34.991] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [20:51:34.991] [bbotk]  0.012  0.2434896        0      0             0.04 
## INFO  [20:51:34.991] [bbotk]                                 uhash 
## INFO  [20:51:34.991] [bbotk]  fdca679b-4117-4d26-974c-26509cba1d9d 
## INFO  [20:51:34.993] [bbotk] Finished optimizing after 5 evaluation(s) 
## INFO  [20:51:34.994] [bbotk] Result: 
## INFO  [20:51:34.994] [bbotk]     cp learner_param_vals  x_domain classif.ce 
## INFO  [20:51:34.994] [bbotk]  0.012          <list[2]> <list[1]>  0.2434896

The model is now ready to be applied to new data, as sketched below.
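A minimal sketch; newdata here is just a stand-in built from the first rows of pima, and in practice it would be genuinely new observations with the same feature columns:

newdata <- task$data(rows = 1:5, cols = task$feature_names)
at$predict_newdata(newdata) # the fitted AutoTuner behaves like any trained learner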

The whole nested procedure has a shorthand too, but note: at the time of writing it required the GitHub version of mlr3tuning, as the CRAN version still had a bug (it may be fixed by now):

rr1 <- tune_nested(
  method = "grid_search",
  resolution = 10,
  task = task,
  learner = learner,
  inner_resampling = resampling,
  outer_resampling = outer_resampling,
  measure = measure,
  term_evals = 20,
  search_space = search_space
  )
## INFO  [20:51:35.045] [mlr3] Applying learner 'classif.rpart.tuned' on task 'pima' (iter 1/3) 
## INFO  [20:51:35.067] [bbotk] Starting to optimize 1 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=20, k=0]' 
## INFO  [20:51:37.665] [mlr3] Finished benchmark 
## INFO  [20:51:37.684] [bbotk] Result of batch 10: 
## INFO  [20:51:37.685] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [20:51:37.685] [bbotk]  0.056  0.2441406        0      0             0.02 
## INFO  [20:51:37.685] [bbotk]                                 uhash 
## INFO  [20:51:37.685] [bbotk]  a289e821-a615-414e-a68f-ba66ed39508b 
## INFO  [20:51:37.688] [bbotk] Finished optimizing after 10 evaluation(s) 
## INFO  [20:51:37.688] [bbotk] Result: 
## INFO  [20:51:37.689] [bbotk]     cp learner_param_vals  x_domain classif.ce 
## INFO  [20:51:37.689] [bbotk]  0.089          <list[2]> <list[1]>  0.2441406

This rr1 is essentially the same kind of object as rr:

print(rr1)
## <ResampleResult> of 3 iterations
## * Task: pima
## * Learner: classif.rpart.tuned
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations
print(rr)
## <ResampleResult> of 3 iterations
## * Task: pima
## * Learner: classif.rpart.tuned
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations

Inspect the inner resampling performance:

extract_inner_tuning_results(rr1)
##    iteration    cp classif.ce learner_param_vals  x_domain task_id
## 1:         1 0.100  0.2578125          <list[2]> <list[1]>    pima
## 2:         2 0.012  0.2500000          <list[2]> <list[1]>    pima
## 3:         3 0.089  0.2441406          <list[2]> <list[1]>    pima
##             learner_id resampling_id
## 1: classif.rpart.tuned            cv
## 2: classif.rpart.tuned            cv
## 3: classif.rpart.tuned            cv

Extract the archives:

extract_inner_tuning_archives(rr1)
##     iteration    cp classif.ce x_domain_cp runtime_learners           timestamp
##  1:         1 0.100  0.2578125       0.100             0.01 2022-02-27 20:51:35
##  2:         1 0.034  0.2578125       0.034             0.03 2022-02-27 20:51:35
##  3:         1 0.001  0.2832031       0.001             0.04 2022-02-27 20:51:35
##  4:         1 0.023  0.2734375       0.023             0.05 2022-02-27 20:51:35
##  5:         1 0.078  0.2578125       0.078             0.03 2022-02-27 20:51:35
##  6:         1 0.067  0.2578125       0.067             0.04 2022-02-27 20:51:35
##  7:         1 0.012  0.2910156       0.012             0.01 2022-02-27 20:51:35
##  8:         1 0.089  0.2578125       0.089             0.01 2022-02-27 20:51:35
##  9:         1 0.056  0.2578125       0.056             0.03 2022-02-27 20:51:35
## 10:         1 0.045  0.2578125       0.045             0.04 2022-02-27 20:51:35
## 11:         2 0.089  0.2597656       0.089             0.02 2022-02-27 20:51:36
## 12:         2 0.056  0.2597656       0.056             0.03 2022-02-27 20:51:36
## 13:         2 0.100  0.2636719       0.100             0.04 2022-02-27 20:51:36
## 14:         2 0.067  0.2519531       0.067             0.02 2022-02-27 20:51:36
## 15:         2 0.045  0.2558594       0.045             0.02 2022-02-27 20:51:36
## 16:         2 0.001  0.2675781       0.001             0.05 2022-02-27 20:51:36
## 17:         2 0.078  0.2597656       0.078             0.01 2022-02-27 20:51:36
## 18:         2 0.034  0.2558594       0.034             0.04 2022-02-27 20:51:36
## 19:         2 0.012  0.2500000       0.012             0.03 2022-02-27 20:51:36
## 20:         2 0.023  0.2597656       0.023             0.02 2022-02-27 20:51:36
## 21:         3 0.089  0.2441406       0.089             0.02 2022-02-27 20:51:36
## 22:         3 0.034  0.2500000       0.034             0.03 2022-02-27 20:51:37
## 23:         3 0.100  0.2441406       0.100             0.00 2022-02-27 20:51:37
## 24:         3 0.023  0.2617188       0.023             0.04 2022-02-27 20:51:37
## 25:         3 0.067  0.2441406       0.067             0.03 2022-02-27 20:51:37
## 26:         3 0.045  0.2441406       0.045             0.03 2022-02-27 20:51:37
## 27:         3 0.001  0.2832031       0.001             0.03 2022-02-27 20:51:37
## 28:         3 0.078  0.2441406       0.078             0.04 2022-02-27 20:51:37
## 29:         3 0.012  0.2675781       0.012             0.04 2022-02-27 20:51:37
## 30:         3 0.056  0.2441406       0.056             0.02 2022-02-27 20:51:37
##     iteration    cp classif.ce x_domain_cp runtime_learners           timestamp
##     batch_nr warnings errors      resample_result task_id          learner_id
##  1:        1        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  2:        2        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  3:        3        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  4:        4        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  5:        5        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  6:        6        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  7:        7        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  8:        8        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##  9:        9        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 10:       10        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 11:        1        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 12:        2        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 13:        3        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 14:        4        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 15:        5        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 16:        6        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 17:        7        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 18:        8        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 19:        9        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 20:       10        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 21:        1        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 22:        2        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 23:        3        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 24:        4        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 25:        5        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 26:        6        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 27:        7        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 28:        8        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 29:        9        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
## 30:       10        0      0 <ResampleResult[22]>    pima classif.rpart.tuned
##     batch_nr warnings errors      resample_result task_id          learner_id
##     resampling_id
##  1:            cv
##  2:            cv
##  3:            cv
##  4:            cv
##  5:            cv
##  6:            cv
##  7:            cv
##  8:            cv
##  9:            cv
## 10:            cv
## 11:            cv
## 12:            cv
## 13:            cv
## 14:            cv
## 15:            cv
## 16:            cv
## 17:            cv
## 18:            cv
## 19:            cv
## 20:            cv
## 21:            cv
## 22:            cv
## 23:            cv
## 24:            cv
## 25:            cv
## 26:            cv
## 27:            cv
## 28:            cv
## 29:            cv
## 30:            cv
##     resampling_id

Inspect the model performance:

rr1$aggregate()
## classif.ce 
##  0.2682292

rr1$score()
##                 task task_id         learner          learner_id
## 1: <TaskClassif[49]>    pima <AutoTuner[41]> classif.rpart.tuned
## 2: <TaskClassif[49]>    pima <AutoTuner[41]> classif.rpart.tuned
## 3: <TaskClassif[49]>    pima <AutoTuner[41]> classif.rpart.tuned
##            resampling resampling_id iteration              prediction
## 1: <ResamplingCV[19]>            cv         1 <PredictionClassif[20]>
## 2: <ResamplingCV[19]>            cv         2 <PredictionClassif[20]>
## 3: <ResamplingCV[19]>            cv         3 <PredictionClassif[20]>
##    classif.ce
## 1:  0.2539062
## 2:  0.2578125
## 3:  0.2929688

Note that tune_nested() provides no way to apply the result to a new dataset. I asked the developers, and their answer was: tune_nested() is a way to evaluate how the algorithm performs on the dataset as a whole, not a way to pick a final set of hyperparameters. The resampling process produces many hyperparameter configurations, and none of them should be carried over into a final model.

Hyperband Tuning

Hyperband can be viewed as a special kind of random search. As the saying goes, "you cannot have both the fish and the bear's paw": hyperband trades off trying many configurations cheaply against evaluating the promising ones thoroughly. Interested readers can study the algorithm in depth on their own.

Here is a small example to make the idea concrete:
Suppose you have 8 horses, and each horse needs 4 units of food to perform at its best, but you only have 32 units. You need a strategy that spends those 32 units (your computational budget) so as to find the best horse.
Strategy 1: discard 4 horses outright and spend all the food on the other 4; in the end you can identify the best of those 4. The drawback is that you never learn whether one of the discarded horses would have been better.
Strategy 2: start by giving each horse 1 unit of food and watch how they perform; keep the best 4 and discard the rest, give the survivors more food, then keep the best 2, and so on, until the remaining food goes to the last horse standing.
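The bookkeeping of the second strategy (successive halving) can be written out as plain R arithmetic mirroring the story; the halving factor of 2 here plays the role of hyperband's eta parameter:

horses <- 8; food <- 32; ration <- 1
while (horses > 1) {
  food <- food - horses * ration # feed every remaining horse
  cat(sprintf("%d horses fed %d unit(s) each, %d units left\n",
              horses, ration, food))
  horses <- horses / 2           # keep the better half
  ration <- ration * 2           # survivors get a bigger ration next round
}
cat(sprintf("the remaining %d units go to the final horse\n", food))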

Here we focus on implementing this with the mlr3hyperband package.

library(mlr3verse)

set.seed(123)

ll = po("subsample") %>>% lrn("classif.rpart") # mlr3自带的管道符,先进行预处理

search_space = ps(
  classif.rpart.cp = p_dbl(lower = 0.001, upper = 0.1),
  classif.rpart.minsplit = p_int(lower = 1, upper = 10),
  subsample.frac = p_dbl(lower = 0.1, upper = 1, tags = "budget")
) # the "budget" tag marks the budget (fidelity) parameter

instance = TuningInstanceSingleCrit$new(
  task = tsk("iris"),
  learner = ll,
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  terminator = trm("none"), # hyperband terminates itself
  search_space = search_space
)

Now run the hyperband tuning:

library(mlr3hyperband)

tuner <- tnr("hyperband", eta = 3)

lgr::get_logger("bbotk")$set_threshold("warn")

tuner$optimize(instance)
## INFO  [20:51:38.099] [mlr3] Running benchmark with 9 resampling iterations 
## INFO  [20:51:39.424] [mlr3] Finished benchmark
##    classif.rpart.cp classif.rpart.minsplit subsample.frac learner_param_vals
## 1:       0.07348139                      5      0.1111111          <list[6]>
##     x_domain classif.ce
## 1: <list[3]>       0.02

Inspect the results:

instance$result
##    classif.rpart.cp classif.rpart.minsplit subsample.frac learner_param_vals
## 1:       0.07348139                      5      0.1111111          <list[6]>
##     x_domain classif.ce
## 1: <list[3]>       0.02
instance$result_learner_param_vals
## $subsample.frac
## [1] 0.1111111
## 
## $subsample.stratify
## [1] FALSE
## 
## $subsample.replace
## [1] FALSE
## 
## $classif.rpart.xval
## [1] 0
## 
## $classif.rpart.cp
## [1] 0.07348139
## 
## $classif.rpart.minsplit
## [1] 5
instance$result_y
## classif.ce 
##       0.02

Feature Selection

Feature selection is an art in itself. The data we receive usually contains information that is redundant or uninformative and of no help for modeling; such variables only add noise and drag model performance down. Removing the redundant information and picking the most suitable variables is what we call feature selection.

Filters

Filter methods first compute a score for every predictor, then rank the predictors by score, so we can pick suitable predictors based on their scores.

List the supported scoring methods:

mlr_filters
## <DictionaryFilter> with 20 stored values
## Keys: anova, auc, carscore, cmim, correlation, disr, find_correlation,
##   importance, information_gain, jmi, jmim, kruskal_test, mim, mrmr,
##   njmim, performance, permutation, relief, selected_features, variance

Feature engineering is a complex topic; consult dedicated books for a thorough treatment.

Calculating Scores

Currently only classification and regression are supported.

filter <- flt("jmim")

task <- tsk("iris")
filter$calculate(task)

filter
## <FilterJMIM:jmim>
## Task Types: classif, regr
## Task Properties: -
## Packages: mlr3filters, praznik
## Feature types: integer, numeric, factor, ordered
##         feature     score
## 1:  Petal.Width 1.0000000
## 2: Sepal.Length 0.6666667
## 3: Petal.Length 0.3333333
## 4:  Sepal.Width 0.0000000

As you can see, every variable has been assigned a score.

# pick variables by correlation
filter_cor <- flt("correlation")

# parameters can be changed; the default method is pearson
filter_cor$param_set
## <ParamSet>
##        id    class lower upper nlevels    default value
## 1:    use ParamFct    NA    NA       5 everything      
## 2: method ParamFct    NA    NA       3    pearson
# change it to spearman
filter_cor$param_set$values <- list(method = "spearman")
filter_cor$param_set
## <ParamSet>
##        id    class lower upper nlevels    default    value
## 1:    use ParamFct    NA    NA       5 everything         
## 2: method ParamFct    NA    NA       3    pearson spearman
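To act on the scores, convert the filter to a table and keep the top-ranked features; a minimal sketch using the jmim filter computed above (the cutoff of 2 is arbitrary):

scores <- as.data.table(filter) # feature/score table, sorted by score
keep <- head(scores$feature, 2) # keep the 2 highest-scoring features
task$select(keep)               # restrict the task to those features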

Calculating Variable Importance

Any learner that supports the importance property can be used for this.

For example:

lrn <- lrn("classif.ranger", importance = "impurity")

task <- tsk("iris")
filter <- flt("importance", learner = lrn)
filter$calculate(task)
filter
## <FilterImportance:importance>
## Task Types: classif
## Task Properties: -
## Packages: mlr3filters, mlr3, mlr3learners, ranger
## Feature types: logical, integer, numeric, character, factor, ordered
##         feature     score
## 1: Petal.Length 44.420716
## 2:  Petal.Width 43.235616
## 3: Sepal.Length  9.470614
## 4:  Sepal.Width  2.180197

Wrapper Methods

Wrapper methods work much like hyperparameter tuning; they are provided by the mlr3fselect package.

library(mlr3fselect)

task <- tsk("pima")
learner <- lrn("classif.rpart")
hout <- rsmp("holdout")
measure <- msr("classif.ce")

evals20 <- trm("evals", n_evals = 20) # when to stop

# build the instance
instance <- FSelectInstanceSingleCrit$new(
  task = task,
  learner = learner,
  resampling = hout,
  measure = measure,
  terminator = evals20
)
instance
## <FSelectInstanceSingleCrit>
## * State:  Not optimized
## * Objective: <ObjectiveFSelect:classif.rpart_on_pima>
## * Search Space:
## <ParamSet>
##          id    class lower upper nlevels        default value
## 1:      age ParamLgl    NA    NA       2 <NoDefault[3]>      
## 2:  glucose ParamLgl    NA    NA       2 <NoDefault[3]>      
## 3:  insulin ParamLgl    NA    NA       2 <NoDefault[3]>      
## 4:     mass ParamLgl    NA    NA       2 <NoDefault[3]>      
## 5: pedigree ParamLgl    NA    NA       2 <NoDefault[3]>      
## 6: pregnant ParamLgl    NA    NA       2 <NoDefault[3]>      
## 7: pressure ParamLgl    NA    NA       2 <NoDefault[3]>      
## 8:  triceps ParamLgl    NA    NA       2 <NoDefault[3]>      
## * Terminator: <TerminatorEvals>
## * Terminated: FALSE
## * Archive:
## <ArchiveFSelect>
## Null data.table (0 rows and 0 cols)

mlr3fselect currently supports the following strategies:

  • Random Search (FSelectorRandomSearch)
  • Exhaustive Search (FSelectorExhaustiveSearch)
  • Sequential Search (FSelectorSequential)
  • Recursive Feature Elimination (FSelectorRFE)
  • Design Points (FSelectorDesignPoints)

We pick random search:

fselector <- fs("random_search")

Run it:

lgr::get_logger("bbotk")$set_threshold("warn")

fselector$optimize(instance)
## INFO  [20:51:39.787] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [20:51:41.955] [mlr3] Finished benchmark
##      age glucose insulin  mass pedigree pregnant pressure triceps features
## 1: FALSE    TRUE   FALSE FALSE    FALSE    FALSE    FALSE   FALSE  glucose
##    classif.ce
## 1:     0.1875

Inspect the selected variables:

instance$result_feature_set
## [1] "glucose"

Inspect the results:

instance$result_y
## classif.ce 
##     0.1875
as.data.table(instance$archive)
##       age glucose insulin  mass pedigree pregnant pressure triceps classif.ce
##  1: FALSE   FALSE   FALSE FALSE    FALSE    FALSE     TRUE   FALSE  0.3828125
##  2:  TRUE   FALSE    TRUE FALSE     TRUE     TRUE     TRUE    TRUE  0.3593750
##  3: FALSE   FALSE    TRUE FALSE    FALSE    FALSE    FALSE   FALSE  0.2890625
##  4: FALSE    TRUE    TRUE  TRUE    FALSE    FALSE     TRUE   FALSE  0.2343750
##  5: FALSE    TRUE   FALSE  TRUE    FALSE     TRUE    FALSE   FALSE  0.2226562
##  6: FALSE    TRUE   FALSE FALSE    FALSE    FALSE    FALSE   FALSE  0.1875000
##  7: FALSE    TRUE   FALSE  TRUE    FALSE    FALSE    FALSE   FALSE  0.2226562
##  8: FALSE   FALSE    TRUE FALSE    FALSE     TRUE    FALSE   FALSE  0.2812500
##  9:  TRUE    TRUE    TRUE  TRUE     TRUE     TRUE     TRUE    TRUE  0.2265625
## 10:  TRUE   FALSE   FALSE FALSE    FALSE     TRUE     TRUE   FALSE  0.3085938
## 11:  TRUE    TRUE   FALSE FALSE    FALSE    FALSE    FALSE    TRUE  0.2343750
## 12: FALSE    TRUE   FALSE FALSE     TRUE    FALSE    FALSE    TRUE  0.2460938
## 13:  TRUE    TRUE    TRUE  TRUE    FALSE     TRUE     TRUE    TRUE  0.2539062
## 14: FALSE    TRUE   FALSE FALSE     TRUE    FALSE     TRUE    TRUE  0.2148438
## 15: FALSE    TRUE    TRUE  TRUE     TRUE     TRUE    FALSE    TRUE  0.2226562
## 16: FALSE   FALSE    TRUE FALSE     TRUE     TRUE    FALSE   FALSE  0.2968750
## 17: FALSE    TRUE   FALSE FALSE    FALSE    FALSE    FALSE   FALSE  0.1875000
## 18: FALSE   FALSE    TRUE  TRUE     TRUE     TRUE     TRUE   FALSE  0.3750000
## 19: FALSE    TRUE    TRUE  TRUE     TRUE     TRUE     TRUE    TRUE  0.2343750
## 20:  TRUE   FALSE    TRUE FALSE     TRUE     TRUE     TRUE    TRUE  0.3593750
##     runtime_learners           timestamp batch_nr      resample_result
##  1:             0.03 2022-02-27 20:51:39        1 <ResampleResult[22]>
##  2:             0.05 2022-02-27 20:51:39        2 <ResampleResult[22]>
##  3:             0.03 2022-02-27 20:51:40        3 <ResampleResult[22]>
##  4:             0.03 2022-02-27 20:51:40        4 <ResampleResult[22]>
##  5:             0.03 2022-02-27 20:51:40        5 <ResampleResult[22]>
##  6:             0.03 2022-02-27 20:51:40        6 <ResampleResult[22]>
##  7:             0.04 2022-02-27 20:51:40        7 <ResampleResult[22]>
##  8:             0.04 2022-02-27 20:51:40        8 <ResampleResult[22]>
##  9:             0.03 2022-02-27 20:51:40        9 <ResampleResult[22]>
## 10:             0.03 2022-02-27 20:51:40       10 <ResampleResult[22]>
## 11:             0.03 2022-02-27 20:51:40       11 <ResampleResult[22]>
## 12:             0.03 2022-02-27 20:51:41       12 <ResampleResult[22]>
## 13:             0.05 2022-02-27 20:51:41       13 <ResampleResult[22]>
## 14:             0.05 2022-02-27 20:51:41       14 <ResampleResult[22]>
## 15:             0.03 2022-02-27 20:51:41       15 <ResampleResult[22]>
## 16:             0.03 2022-02-27 20:51:41       16 <ResampleResult[22]>
## 17:             0.04 2022-02-27 20:51:41       17 <ResampleResult[22]>
## 18:             0.05 2022-02-27 20:51:41       18 <ResampleResult[22]>
## 19:             0.03 2022-02-27 20:51:41       19 <ResampleResult[22]>
## 20:             0.04 2022-02-27 20:51:41       20 <ResampleResult[22]>
instance$archive$benchmark_result$data
## <ResultData>
##   Public:
##     as_data_table: function (view = NULL, reassemble_learners = TRUE, convert_predictions = TRUE, 
##     clone: function (deep = FALSE) 
##     combine: function (rdata) 
##     data: list
##     discard: function (backends = FALSE, models = FALSE) 
##     initialize: function (data = NULL, store_backends = TRUE) 
##     iterations: function (view = NULL) 
##     learner_states: function (view = NULL) 
##     learners: function (view = NULL, states = TRUE, reassemble = TRUE) 
##     logs: function (view = NULL, condition) 
##     prediction: function (view = NULL, predict_sets = "test") 
##     predictions: function (view = NULL, predict_sets = "test") 
##     resamplings: function (view = NULL) 
##     sweep: function () 
##     task_type: active binding
##     tasks: function (view = NULL) 
##     uhashes: function (view = NULL) 
##   Private:
##     deep_clone: function (name, value) 
##     get_view_index: function (view)

Apply the selection to the task and train:

task$select(instance$result_feature_set) # keep only the selected variables
learner$train(task)

Automated Selection

learner = lrn("classif.rpart")
terminator = trm("evals", n_evals = 10)
fselector = fs("random_search")

at = AutoFSelector$new(
  learner = learner,
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  terminator = terminator,
  fselector = fselector
)
at
## <AutoFSelector:classif.rpart.fselector>
## * Model: -
## * Parameters: xval=0
## * Packages: mlr3, mlr3fselect, rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights

Compare the model performance obtained with different feature subsets:

grid = benchmark_grid(
  task = tsk("pima"),
  learner = list(at, lrn("classif.rpart")),
  resampling = rsmp("cv", folds = 3)
)

bmr = benchmark(grid, store_models = TRUE)
## INFO  [20:51:42.111] [mlr3] Running benchmark with 6 resampling iterations 
## INFO  [20:51:45.620] [mlr3] Finished benchmark 
## INFO  [20:51:45.672] [mlr3] Finished benchmark
bmr$aggregate(msrs(c("classif.ce", "time_train")))
##    nr      resample_result task_id              learner_id resampling_id iters
## 1:  1 <ResampleResult[22]>    pima classif.rpart.fselector            cv     3
## 2:  2 <ResampleResult[22]>    pima           classif.rpart            cv     3
##    classif.ce time_train
## 1:  0.2539062          0
## 2:  0.2539062          0
