Java 类名:com.alibaba.alink.pipeline.tuning.GridSearchCV
Python 类名:GridSearchCV

功能介绍

GridSearch 在由参数数组组成的网格上,对其中每一组输入参数的组合分别进行训练、预测和评估,取评估指标最优的模型作为最终返回的模型。
CV 为交叉验证(cross validation):将数据切分为 k 份(k-folds),每次用其中 k-1 份数据做训练,用剩余的一份数据做预测和评估,得到一个评估结果。
GridSearchCV 使用交叉验证方法得到网格中每一组参数的评估结果,并据此选出最优模型。

参数说明

| 名称 | 中文名称 | 描述 | 类型 | 是否必须? | 默认值 |
| --- | --- | --- | --- | --- | --- |
| NumFolds | 折数 | 交叉验证的折数,即数据切分的份数(需大于等于 2) | Integer | | 10 |
| ParamGrid | 参数网格 | 指定参数的网格 | ParamGrid | ✓ | — |
| Estimator | Estimator | 用于调优的 Estimator | Estimator | ✓ | — |
| TuningEvaluator | 评估指标 | 用于选择最优模型的评估指标 | TuningEvaluator | ✓ | — |

代码示例

Python 代码

  1. from pyalink.alink import *
  2. import pandas as pd
  3. useLocalEnv(1)
  4. def adult(url):
  5. data = (
  6. CsvSourceBatchOp()
  7. .setFilePath('https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_train.csv')
  8. .setSchemaStr(
  9. 'age bigint, workclass string, fnlwgt bigint,'
  10. 'education string, education_num bigint,'
  11. 'marital_status string, occupation string,'
  12. 'relationship string, race string, sex string,'
  13. 'capital_gain bigint, capital_loss bigint,'
  14. 'hours_per_week bigint, native_country string,'
  15. 'label string'
  16. )
  17. )
  18. return data
  19. def adult_train():
  20. return adult('https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_train.csv')
  21. def adult_test():
  22. return adult('https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_test.csv')
  23. def adult_numerical_feature_strs():
  24. return [
  25. "age", "fnlwgt", "education_num",
  26. "capital_gain", "capital_loss", "hours_per_week"
  27. ]
  28. def adult_categorical_feature_strs():
  29. return [
  30. "workclass", "education", "marital_status",
  31. "occupation", "relationship", "race", "sex",
  32. "native_country"
  33. ]
  34. def adult_features_strs():
  35. feature = adult_numerical_feature_strs()
  36. feature.extend(adult_categorical_feature_strs())
  37. return feature
  38. def rf_grid_search_cv(featureCols, categoryFeatureCols, label, metric):
  39. rf = (
  40. RandomForestClassifier()
  41. .setFeatureCols(featureCols)
  42. .setCategoricalCols(categoryFeatureCols)
  43. .setLabelCol(label)
  44. .setPredictionCol('prediction')
  45. .setPredictionDetailCol('prediction_detail')
  46. )
  47. paramGrid = (
  48. ParamGrid()
  49. .addGrid(rf, 'SUBSAMPLING_RATIO', [1.0, 0.99, 0.98])
  50. .addGrid(rf, 'NUM_TREES', [3, 6, 9])
  51. )
  52. tuningEvaluator = (
  53. BinaryClassificationTuningEvaluator()
  54. .setLabelCol(label)
  55. .setPredictionDetailCol("prediction_detail")
  56. .setTuningBinaryClassMetric(metric)
  57. )
  58. cv = (
  59. GridSearchCV()
  60. .setEstimator(rf)
  61. .setParamGrid(paramGrid)
  62. .setTuningEvaluator(tuningEvaluator)
  63. .setNumFolds(2)
  64. .enableLazyPrintTrainInfo("TrainInfo")
  65. )
  66. return cv
  67. def rf_grid_search_tv(featureCols, categoryFeatureCols, label, metric):
  68. rf = (
  69. RandomForestClassifier()
  70. .setFeatureCols(featureCols)
  71. .setCategoricalCols(categoryFeatureCols)
  72. .setLabelCol(label)
  73. .setPredictionCol('prediction')
  74. .setPredictionDetailCol('prediction_detail')
  75. )
  76. paramGrid = (
  77. ParamGrid()
  78. .addGrid(rf, 'SUBSAMPLING_RATIO', [1.0, 0.99, 0.98])
  79. .addGrid(rf, 'NUM_TREES', [3, 6, 9])
  80. )
  81. tuningEvaluator = (
  82. BinaryClassificationTuningEvaluator()
  83. .setLabelCol(label)
  84. .setPredictionDetailCol("prediction_detail")
  85. .setTuningBinaryClassMetric(metric)
  86. )
  87. cv = (
  88. GridSearchTVSplit()
  89. .setEstimator(rf)
  90. .setParamGrid(paramGrid)
  91. .setTuningEvaluator(tuningEvaluator)
  92. .enableLazyPrintTrainInfo("TrainInfo")
  93. )
  94. return cv
  95. def tuningcv(cv_estimator, input):
  96. return cv_estimator.fit(input)
  97. def tuningtv(tv_estimator, input):
  98. return tv_estimator.fit(input)
  99. def main():
  100. print('rf cv tuning')
  101. model = tuningcv(
  102. rf_grid_search_cv(adult_features_strs(),
  103. adult_categorical_feature_strs(), 'label', 'AUC'),
  104. adult_train()
  105. )
  106. print('rf tv tuning')
  107. model = tuningtv(
  108. rf_grid_search_tv(adult_features_strs(),
  109. adult_categorical_feature_strs(), 'label', 'AUC'),
  110. adult_train()
  111. )
  112. main()

Java 代码

  1. import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
  2. import com.alibaba.alink.pipeline.classification.RandomForestClassifier;
  3. import com.alibaba.alink.pipeline.tuning.BinaryClassificationTuningEvaluator;
  4. import com.alibaba.alink.pipeline.tuning.GridSearchCV;
  5. import com.alibaba.alink.pipeline.tuning.GridSearchCVModel;
  6. import com.alibaba.alink.pipeline.tuning.ParamGrid;
  7. import org.junit.Test;
  8. public class GridSearchCVTest {
  9. @Test
  10. public void testGridSearchCV() throws Exception {
  11. String[] featureCols = new String[] {
  12. "age", "fnlwgt", "education_num",
  13. "capital_gain", "capital_loss", "hours_per_week",
  14. "workclass", "education", "marital_status",
  15. "occupation", "relationship", "race", "sex",
  16. "native_country"
  17. };
  18. String[] categoryFeatureCols = new String[] {
  19. "workclass", "education", "marital_status",
  20. "occupation", "relationship", "race", "sex",
  21. "native_country"
  22. };
  23. String label = "label";
  24. CsvSourceBatchOp data = new CsvSourceBatchOp()
  25. .setFilePath("https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_train.csv")
  26. .setSchemaStr(
  27. "age bigint, workclass string, fnlwgt bigint, education string, education_num bigint, marital_status "
  28. + "string, occupation string, relationship string, race string, sex string, capital_gain bigint, "
  29. + "capital_loss bigint, hours_per_week bigint, native_country string, label string");
  30. RandomForestClassifier rf = new RandomForestClassifier()
  31. .setFeatureCols(featureCols)
  32. .setCategoricalCols(categoryFeatureCols)
  33. .setLabelCol(label)
  34. .setPredictionCol("prediction")
  35. .setPredictionDetailCol("prediction_detail");
  36. ParamGrid paramGrid = new ParamGrid()
  37. .addGrid(rf, RandomForestClassifier.SUBSAMPLING_RATIO, new Double[] {1.0, 0.99, 0.98})
  38. .addGrid(rf, RandomForestClassifier.NUM_TREES, new Integer[] {3, 6, 9});
  39. BinaryClassificationTuningEvaluator tuningEvaluator = new BinaryClassificationTuningEvaluator()
  40. .setLabelCol(label)
  41. .setPredictionDetailCol("prediction_detail")
  42. .setTuningBinaryClassMetric("AUC");
  43. GridSearchCV cv = new GridSearchCV()
  44. .setEstimator(rf)
  45. .setParamGrid(paramGrid)
  46. .setTuningEvaluator(tuningEvaluator)
  47. .setNumFolds(2)
  48. .enableLazyPrintTrainInfo("TrainInfo");
  49. GridSearchCVModel model = cv.fit(data);
  50. }
  51. }

运行结果

TrainInfo
Metric information:
Metric name: AUC
Larger is better: true
Tuning information:

| AUC | stage | param | value | stage 2 | param 2 | value 2 |
| --- | --- | --- | --- | --- | --- | --- |
| 0.912327454540025 | RandomForestClassifier | numTrees | 9 | RandomForestClassifier | subsamplingRatio | 0.98 |
| 0.9113181022628927 | RandomForestClassifier | numTrees | 9 | RandomForestClassifier | subsamplingRatio | 1.0 |
| 0.9109408773009041 | RandomForestClassifier | numTrees | 9 | RandomForestClassifier | subsamplingRatio | 0.99 |
| 0.9084745064874684 | RandomForestClassifier | numTrees | 6 | RandomForestClassifier | subsamplingRatio | 1.0 |
| 0.9066321684664669 | RandomForestClassifier | numTrees | 6 | RandomForestClassifier | subsamplingRatio | 0.98 |
| 0.9045123178682739 | RandomForestClassifier | numTrees | 6 | RandomForestClassifier | subsamplingRatio | 0.99 |
| 0.8908957160768797 | RandomForestClassifier | numTrees | 3 | RandomForestClassifier | subsamplingRatio | 1.0 |
| 0.8903604608878586 | RandomForestClassifier | numTrees | 3 | RandomForestClassifier | subsamplingRatio | 0.98 |
| 0.888885807369439 | RandomForestClassifier | numTrees | 3 | RandomForestClassifier | subsamplingRatio | 0.99 |

Python 示例运行结果(控制台输出)

  1. rf cv tuning
  2. com.alibaba.alink.pipeline.tuning.GridSearchCV
  3. [ {
  4. "param" : [ {
  5. "stage" : "RandomForestClassifier",
  6. "paramName" : "numTrees",
  7. "paramValue" : 3
  8. }, {
  9. "stage" : "RandomForestClassifier",
  10. "paramName" : "subsamplingRatio",
  11. "paramValue" : 1.0
  12. } ],
  13. "metric" : 0.8922549257899725
  14. }, {
  15. "param" : [ {
  16. "stage" : "RandomForestClassifier",
  17. "paramName" : "numTrees",
  18. "paramValue" : 3
  19. }, {
  20. "stage" : "RandomForestClassifier",
  21. "paramName" : "subsamplingRatio",
  22. "paramValue" : 0.99
  23. } ],
  24. "metric" : 0.8920255970548456
  25. }, {
  26. "param" : [ {
  27. "stage" : "RandomForestClassifier",
  28. "paramName" : "numTrees",
  29. "paramValue" : 3
  30. }, {
  31. "stage" : "RandomForestClassifier",
  32. "paramName" : "subsamplingRatio",
  33. "paramValue" : 0.98
  34. } ],
  35. "metric" : 0.8944982480437225
  36. }, {
  37. "param" : [ {
  38. "stage" : "RandomForestClassifier",
  39. "paramName" : "numTrees",
  40. "paramValue" : 6
  41. }, {
  42. "stage" : "RandomForestClassifier",
  43. "paramName" : "subsamplingRatio",
  44. "paramValue" : 1.0
  45. } ],
  46. "metric" : 0.8923867598288401
  47. }, {
  48. "param" : [ {
  49. "stage" : "RandomForestClassifier",
  50. "paramName" : "numTrees",
  51. "paramValue" : 6
  52. }, {
  53. "stage" : "RandomForestClassifier",
  54. "paramName" : "subsamplingRatio",
  55. "paramValue" : 0.99
  56. } ],
  57. "metric" : 0.9012141767959505
  58. }, {
  59. "param" : [ {
  60. "stage" : "RandomForestClassifier",
  61. "paramName" : "numTrees",
  62. "paramValue" : 6
  63. }, {
  64. "stage" : "RandomForestClassifier",
  65. "paramName" : "subsamplingRatio",
  66. "paramValue" : 0.98
  67. } ],
  68. "metric" : 0.8993774036693788
  69. }, {
  70. "param" : [ {
  71. "stage" : "RandomForestClassifier",
  72. "paramName" : "numTrees",
  73. "paramValue" : 9
  74. }, {
  75. "stage" : "RandomForestClassifier",
  76. "paramName" : "subsamplingRatio",
  77. "paramValue" : 1.0
  78. } ],
  79. "metric" : 0.8981738808130779
  80. }, {
  81. "param" : [ {
  82. "stage" : "RandomForestClassifier",
  83. "paramName" : "numTrees",
  84. "paramValue" : 9
  85. }, {
  86. "stage" : "RandomForestClassifier",
  87. "paramName" : "subsamplingRatio",
  88. "paramValue" : 0.99
  89. } ],
  90. "metric" : 0.9029671873892725
  91. }, {
  92. "param" : [ {
  93. "stage" : "RandomForestClassifier",
  94. "paramName" : "numTrees",
  95. "paramValue" : 9
  96. }, {
  97. "stage" : "RandomForestClassifier",
  98. "paramName" : "subsamplingRatio",
  99. "paramValue" : 0.98
  100. } ],
  101. "metric" : 0.905228896323363
  102. } ]
  103. rf tv tuning
  104. com.alibaba.alink.pipeline.tuning.GridSearchTVSplit
  105. [ {
  106. "param" : [ {
  107. "stage" : "RandomForestClassifier",
  108. "paramName" : "numTrees",
  109. "paramValue" : 3
  110. }, {
  111. "stage" : "RandomForestClassifier",
  112. "paramName" : "subsamplingRatio",
  113. "paramValue" : 1.0
  114. } ],
  115. "metric" : 0.9022694229691741
  116. }, {
  117. "param" : [ {
  118. "stage" : "RandomForestClassifier",
  119. "paramName" : "numTrees",
  120. "paramValue" : 3
  121. }, {
  122. "stage" : "RandomForestClassifier",
  123. "paramName" : "subsamplingRatio",
  124. "paramValue" : 0.99
  125. } ],
  126. "metric" : 0.8963559966080328
  127. }, {
  128. "param" : [ {
  129. "stage" : "RandomForestClassifier",
  130. "paramName" : "numTrees",
  131. "paramValue" : 3
  132. }, {
  133. "stage" : "RandomForestClassifier",
  134. "paramName" : "subsamplingRatio",
  135. "paramValue" : 0.98
  136. } ],
  137. "metric" : 0.9041948454957178
  138. }, {
  139. "param" : [ {
  140. "stage" : "RandomForestClassifier",
  141. "paramName" : "numTrees",
  142. "paramValue" : 6
  143. }, {
  144. "stage" : "RandomForestClassifier",
  145. "paramName" : "subsamplingRatio",
  146. "paramValue" : 1.0
  147. } ],
  148. "metric" : 0.8982021117392784
  149. }, {
  150. "param" : [ {
  151. "stage" : "RandomForestClassifier",
  152. "paramName" : "numTrees",
  153. "paramValue" : 6
  154. }, {
  155. "stage" : "RandomForestClassifier",
  156. "paramName" : "subsamplingRatio",
  157. "paramValue" : 0.99
  158. } ],
  159. "metric" : 0.9031851535310546
  160. }, {
  161. "param" : [ {
  162. "stage" : "RandomForestClassifier",
  163. "paramName" : "numTrees",
  164. "paramValue" : 6
  165. }, {
  166. "stage" : "RandomForestClassifier",
  167. "paramName" : "subsamplingRatio",
  168. "paramValue" : 0.98
  169. } ],
  170. "metric" : 0.9034443322241488
  171. }, {
  172. "param" : [ {
  173. "stage" : "RandomForestClassifier",
  174. "paramName" : "numTrees",
  175. "paramValue" : 9
  176. }, {
  177. "stage" : "RandomForestClassifier",
  178. "paramName" : "subsamplingRatio",
  179. "paramValue" : 1.0
  180. } ],
  181. "metric" : 0.8993474753000145
  182. }, {
  183. "param" : [ {
  184. "stage" : "RandomForestClassifier",
  185. "paramName" : "numTrees",
  186. "paramValue" : 9
  187. }, {
  188. "stage" : "RandomForestClassifier",
  189. "paramName" : "subsamplingRatio",
  190. "paramValue" : 0.99
  191. } ],
  192. "metric" : 0.9090250137144916
  193. }, {
  194. "param" : [ {
  195. "stage" : "RandomForestClassifier",
  196. "paramName" : "numTrees",
  197. "paramValue" : 9
  198. }, {
  199. "stage" : "RandomForestClassifier",
  200. "paramName" : "subsamplingRatio",
  201. "paramValue" : 0.98
  202. } ],
  203. "metric" : 0.9129786771786127
  204. } ]