Java 类名:com.alibaba.alink.pipeline.tuning.RandomSearchCV
Python 类名:RandomSearchCV

功能介绍

RandomSearch是通过对参数进行随机采样,对其中每一组输入参数的组合分别进行训练,预测,评估。取得评估指标最优的模型,作为最终的返回模型。
cv为交叉验证,将数据切分为k-folds,对每k-1份数据做训练,对剩余一份数据做预测和评估,得到一个评估结果。
此函数用cv方法得到每一组随机采样参数对应的评估结果,并据此得到最优模型。

参数说明

| 名称 | 中文名称 | 描述 | 类型 | 是否必须? | 默认值 |
| --- | --- | --- | --- | --- | --- |
| NumFolds | 折数 | 交叉验证的参数,数据的折数(大于等于2) | Integer | | 10 |
| ParamDist | 参数分布 | 指定搜索的参数的分布 | ParamDist | ✓ | --- |
| Estimator | Estimator | 用于调优的Estimator | Estimator | ✓ | --- |
| TuningEvaluator | 评估指标 | 用于选择最优模型的评估指标 | TuningEvaluator | ✓ | --- |

代码示例

Python 代码

  1. from pyalink.alink import *
  2. import pandas as pd
  3. useLocalEnv(1)
  4. def adult(url):
  5. data = (
  6. CsvSourceBatchOp()
  7. .setFilePath('https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_train.csv')
  8. .setSchemaStr(
  9. 'age bigint, workclass string, fnlwgt bigint,'
  10. 'education string, education_num bigint,'
  11. 'marital_status string, occupation string,'
  12. 'relationship string, race string, sex string,'
  13. 'capital_gain bigint, capital_loss bigint,'
  14. 'hours_per_week bigint, native_country string,'
  15. 'label string'
  16. )
  17. )
  18. return data
  19. def adult_train():
  20. return adult('https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_train.csv')
  21. def adult_test():
  22. return adult('https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_test.csv')
  23. def adult_numerical_feature_strs():
  24. return [
  25. "age", "fnlwgt", "education_num",
  26. "capital_gain", "capital_loss", "hours_per_week"
  27. ]
  28. def adult_categorical_feature_strs():
  29. return [
  30. "workclass", "education", "marital_status",
  31. "occupation", "relationship", "race", "sex",
  32. "native_country"
  33. ]
  34. def adult_features_strs():
  35. feature = adult_numerical_feature_strs()
  36. feature.extend(adult_categorical_feature_strs())
  37. return feature
  38. def rf_grid_search_cv(featureCols, categoryFeatureCols, label, metric):
  39. rf = (
  40. RandomForestClassifier()
  41. .setFeatureCols(featureCols)
  42. .setCategoricalCols(categoryFeatureCols)
  43. .setLabelCol(label)
  44. .setPredictionCol('prediction')
  45. .setPredictionDetailCol('prediction_detail')
  46. )
  47. paramDist = (
  48. ParamDist()
  49. .addDist(rf, 'NUM_TREES', ValueDist.randInteger(1, 10))
  50. )
  51. tuningEvaluator = (
  52. BinaryClassificationTuningEvaluator()
  53. .setLabelCol(label)
  54. .setPredictionDetailCol("prediction_detail")
  55. .setTuningBinaryClassMetric(metric)
  56. )
  57. cv = (
  58. RandomSearchCV()
  59. .setEstimator(rf)
  60. .setParamDist(paramDist)
  61. .setTuningEvaluator(tuningEvaluator)
  62. .setNumFolds(2)
  63. )
  64. return cv
  65. def rf_grid_search_tv(featureCols, categoryFeatureCols, label, metric):
  66. rf = (
  67. RandomForestClassifier()
  68. .setFeatureCols(featureCols)
  69. .setCategoricalCols(categoryFeatureCols)
  70. .setLabelCol(label)
  71. .setPredictionCol('prediction')
  72. .setPredictionDetailCol('prediction_detail')
  73. )
  74. paramDist = (
  75. ParamDist()
  76. .addDist(rf, 'NUM_TREES', ValueDist.randInteger(1, 10))
  77. )
  78. tuningEvaluator = (
  79. BinaryClassificationTuningEvaluator()
  80. .setLabelCol(label)
  81. .setPredictionDetailCol("prediction_detail")
  82. .setTuningBinaryClassMetric(metric)
  83. )
  84. cv = (
  85. RandomSearchTVSplit()
  86. .setEstimator(rf)
  87. .setParamDist(paramDist)
  88. .setTuningEvaluator(tuningEvaluator)
  89. )
  90. return cv
  91. def tuningcv(cv_estimator, input):
  92. return cv_estimator.enableLazyPrintTrainInfo("CVTrainInfo").fit(input)
  93. def tuningtv(tv_estimator, input):
  94. return tv_estimator.enableLazyPrintTrainInfo("TVTrainInfo").fit(input)
  95. def main():
  96. print('rf cv tuning')
  97. model = tuningcv(
  98. rf_grid_search_cv(adult_features_strs(),
  99. adult_categorical_feature_strs(), 'label', 'AUC'),
  100. adult_train()
  101. )
  102. print('rf tv tuning')
  103. model = tuningtv(
  104. rf_grid_search_tv(adult_features_strs(),
  105. adult_categorical_feature_strs(), 'label', 'AUC'),
  106. adult_train()
  107. )
  108. main()

Java 代码

  1. import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
  2. import com.alibaba.alink.pipeline.classification.RandomForestClassifier;
  3. import com.alibaba.alink.pipeline.tuning.BinaryClassificationTuningEvaluator;
  4. import com.alibaba.alink.pipeline.tuning.ParamDist;
  5. import com.alibaba.alink.pipeline.tuning.RandomSearchCV;
  6. import com.alibaba.alink.pipeline.tuning.RandomSearchCVModel;
  7. import com.alibaba.alink.pipeline.tuning.ValueDist;
  8. import org.junit.Test;
  9. public class RandomSearchCVTest {
  10. @Test
  11. public void testRandomSearchCV() throws Exception {
  12. String[] featureCols = new String[] {
  13. "age", "fnlwgt", "education_num",
  14. "capital_gain", "capital_loss", "hours_per_week",
  15. "workclass", "education", "marital_status",
  16. "occupation", "relationship", "race", "sex",
  17. "native_country"
  18. };
  19. String[] categoryFeatureCols = new String[] {
  20. "workclass", "education", "marital_status",
  21. "occupation", "relationship", "race", "sex",
  22. "native_country"
  23. };
  24. String label = "label";
  25. CsvSourceBatchOp data = new CsvSourceBatchOp()
  26. .setFilePath("https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_train.csv")
  27. .setSchemaStr(
  28. "age bigint, workclass string, fnlwgt bigint, education string, education_num bigint, marital_status "
  29. + "string, occupation string, relationship string, race string, sex string, capital_gain bigint, "
  30. + "capital_loss bigint, hours_per_week bigint, native_country string, label string");
  31. RandomForestClassifier rf = new RandomForestClassifier()
  32. .setFeatureCols(featureCols)
  33. .setCategoricalCols(categoryFeatureCols)
  34. .setLabelCol(label)
  35. .setPredictionCol("prediction")
  36. .setPredictionDetailCol("prediction_detail");
  37. ParamDist paramDist = new ParamDist()
  38. .addDist(rf, RandomForestClassifier.NUM_TREES, ValueDist.randInteger(1, 10));
  39. BinaryClassificationTuningEvaluator tuningEvaluator = new BinaryClassificationTuningEvaluator()
  40. .setLabelCol(label)
  41. .setPredictionDetailCol("prediction_detail")
  42. .setTuningBinaryClassMetric("AUC");
  43. RandomSearchCV cv = new RandomSearchCV()
  44. .setEstimator(rf)
  45. .setParamDist(paramDist)
  46. .setTuningEvaluator(tuningEvaluator)
  47. .setNumFolds(2)
  48. .enableLazyPrintTrainInfo("TrainInfo");
  49. RandomSearchCVModel model = cv.fit(data);
  50. }
  51. }

运行结果

TrainInfo
Metric information:
Metric name: AUC
Larger is better: true
Tuning information:

| AUC | stage | param | value |
| --- | --- | --- | --- |
| 0.9134148020313496 | RandomForestClassifier | numTrees | 10 |
| 0.9123992401477525 | RandomForestClassifier | numTrees | 10 |
| 0.9107724678432794 | RandomForestClassifier | numTrees | 8 |
| 0.905703319906151 | RandomForestClassifier | numTrees | 6 |
| 0.9052924036494705 | RandomForestClassifier | numTrees | 7 |
| 0.8927397325721704 | RandomForestClassifier | numTrees | 3 |
| 0.8887150253364192 | RandomForestClassifier | numTrees | 3 |
| 0.885191174049819 | RandomForestClassifier | numTrees | 3 |
| 0.8837444737636566 | RandomForestClassifier | numTrees | 2 |
| 0.8774725529763574 | RandomForestClassifier | numTrees | 2 |

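运行结果

注:以下输出中的类名为 GridSearchCV / GridSearchTVSplit,且包含 subsamplingRatio 参数,与上文 RandomSearchCV 示例代码不一致,疑似沿用自网格搜索示例的输出,仅供了解输出格式时参考。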

  1. rf cv tuning
  2. com.alibaba.alink.pipeline.tuning.GridSearchCV
  3. [ {
  4. "param" : [ {
  5. "stage" : "RandomForestClassifier",
  6. "paramName" : "numTrees",
  7. "paramValue" : 3
  8. }, {
  9. "stage" : "RandomForestClassifier",
  10. "paramName" : "subsamplingRatio",
  11. "paramValue" : 1.0
  12. } ],
  13. "metric" : 0.8922549257899725
  14. }, {
  15. "param" : [ {
  16. "stage" : "RandomForestClassifier",
  17. "paramName" : "numTrees",
  18. "paramValue" : 3
  19. }, {
  20. "stage" : "RandomForestClassifier",
  21. "paramName" : "subsamplingRatio",
  22. "paramValue" : 0.99
  23. } ],
  24. "metric" : 0.8920255970548456
  25. }, {
  26. "param" : [ {
  27. "stage" : "RandomForestClassifier",
  28. "paramName" : "numTrees",
  29. "paramValue" : 3
  30. }, {
  31. "stage" : "RandomForestClassifier",
  32. "paramName" : "subsamplingRatio",
  33. "paramValue" : 0.98
  34. } ],
  35. "metric" : 0.8944982480437225
  36. }, {
  37. "param" : [ {
  38. "stage" : "RandomForestClassifier",
  39. "paramName" : "numTrees",
  40. "paramValue" : 6
  41. }, {
  42. "stage" : "RandomForestClassifier",
  43. "paramName" : "subsamplingRatio",
  44. "paramValue" : 1.0
  45. } ],
  46. "metric" : 0.8923867598288401
  47. }, {
  48. "param" : [ {
  49. "stage" : "RandomForestClassifier",
  50. "paramName" : "numTrees",
  51. "paramValue" : 6
  52. }, {
  53. "stage" : "RandomForestClassifier",
  54. "paramName" : "subsamplingRatio",
  55. "paramValue" : 0.99
  56. } ],
  57. "metric" : 0.9012141767959505
  58. }, {
  59. "param" : [ {
  60. "stage" : "RandomForestClassifier",
  61. "paramName" : "numTrees",
  62. "paramValue" : 6
  63. }, {
  64. "stage" : "RandomForestClassifier",
  65. "paramName" : "subsamplingRatio",
  66. "paramValue" : 0.98
  67. } ],
  68. "metric" : 0.8993774036693788
  69. }, {
  70. "param" : [ {
  71. "stage" : "RandomForestClassifier",
  72. "paramName" : "numTrees",
  73. "paramValue" : 9
  74. }, {
  75. "stage" : "RandomForestClassifier",
  76. "paramName" : "subsamplingRatio",
  77. "paramValue" : 1.0
  78. } ],
  79. "metric" : 0.8981738808130779
  80. }, {
  81. "param" : [ {
  82. "stage" : "RandomForestClassifier",
  83. "paramName" : "numTrees",
  84. "paramValue" : 9
  85. }, {
  86. "stage" : "RandomForestClassifier",
  87. "paramName" : "subsamplingRatio",
  88. "paramValue" : 0.99
  89. } ],
  90. "metric" : 0.9029671873892725
  91. }, {
  92. "param" : [ {
  93. "stage" : "RandomForestClassifier",
  94. "paramName" : "numTrees",
  95. "paramValue" : 9
  96. }, {
  97. "stage" : "RandomForestClassifier",
  98. "paramName" : "subsamplingRatio",
  99. "paramValue" : 0.98
  100. } ],
  101. "metric" : 0.905228896323363
  102. } ]
  103. rf tv tuning
  104. com.alibaba.alink.pipeline.tuning.GridSearchTVSplit
  105. [ {
  106. "param" : [ {
  107. "stage" : "RandomForestClassifier",
  108. "paramName" : "numTrees",
  109. "paramValue" : 3
  110. }, {
  111. "stage" : "RandomForestClassifier",
  112. "paramName" : "subsamplingRatio",
  113. "paramValue" : 1.0
  114. } ],
  115. "metric" : 0.9022694229691741
  116. }, {
  117. "param" : [ {
  118. "stage" : "RandomForestClassifier",
  119. "paramName" : "numTrees",
  120. "paramValue" : 3
  121. }, {
  122. "stage" : "RandomForestClassifier",
  123. "paramName" : "subsamplingRatio",
  124. "paramValue" : 0.99
  125. } ],
  126. "metric" : 0.8963559966080328
  127. }, {
  128. "param" : [ {
  129. "stage" : "RandomForestClassifier",
  130. "paramName" : "numTrees",
  131. "paramValue" : 3
  132. }, {
  133. "stage" : "RandomForestClassifier",
  134. "paramName" : "subsamplingRatio",
  135. "paramValue" : 0.98
  136. } ],
  137. "metric" : 0.9041948454957178
  138. }, {
  139. "param" : [ {
  140. "stage" : "RandomForestClassifier",
  141. "paramName" : "numTrees",
  142. "paramValue" : 6
  143. }, {
  144. "stage" : "RandomForestClassifier",
  145. "paramName" : "subsamplingRatio",
  146. "paramValue" : 1.0
  147. } ],
  148. "metric" : 0.8982021117392784
  149. }, {
  150. "param" : [ {
  151. "stage" : "RandomForestClassifier",
  152. "paramName" : "numTrees",
  153. "paramValue" : 6
  154. }, {
  155. "stage" : "RandomForestClassifier",
  156. "paramName" : "subsamplingRatio",
  157. "paramValue" : 0.99
  158. } ],
  159. "metric" : 0.9031851535310546
  160. }, {
  161. "param" : [ {
  162. "stage" : "RandomForestClassifier",
  163. "paramName" : "numTrees",
  164. "paramValue" : 6
  165. }, {
  166. "stage" : "RandomForestClassifier",
  167. "paramName" : "subsamplingRatio",
  168. "paramValue" : 0.98
  169. } ],
  170. "metric" : 0.9034443322241488
  171. }, {
  172. "param" : [ {
  173. "stage" : "RandomForestClassifier",
  174. "paramName" : "numTrees",
  175. "paramValue" : 9
  176. }, {
  177. "stage" : "RandomForestClassifier",
  178. "paramName" : "subsamplingRatio",
  179. "paramValue" : 1.0
  180. } ],
  181. "metric" : 0.8993474753000145
  182. }, {
  183. "param" : [ {
  184. "stage" : "RandomForestClassifier",
  185. "paramName" : "numTrees",
  186. "paramValue" : 9
  187. }, {
  188. "stage" : "RandomForestClassifier",
  189. "paramName" : "subsamplingRatio",
  190. "paramValue" : 0.99
  191. } ],
  192. "metric" : 0.9090250137144916
  193. }, {
  194. "param" : [ {
  195. "stage" : "RandomForestClassifier",
  196. "paramName" : "numTrees",
  197. "paramValue" : 9
  198. }, {
  199. "stage" : "RandomForestClassifier",
  200. "paramName" : "subsamplingRatio",
  201. "paramValue" : 0.98
  202. } ],
  203. "metric" : 0.9129786771786127
  204. } ]