Java 类名:com.alibaba.alink.pipeline.tuning.RandomSearchTVSplit
Python 类名:RandomSearchTVSplit

功能介绍

RandomSearch 通过随机采样生成若干组参数,对其中每一组参数组合分别进行训练、预测和评估,取评估指标最优的模型作为最终返回的模型。
TV(Train-Validation,训练验证)是将数据按照比例切分为两份,对其中一份数据做训练,对剩余一份数据做预测和评估,得到一个评估结果。
此组件用 TV 方法得到每一组随机采样参数对应的评估结果,并据此得到最优模型。

参数说明

| 名称 | 中文名称 | 描述 | 类型 | 是否必须? | 默认值 |
| --- | --- | --- | --- | --- | --- |
| trainRatio | 训练集比例 | 训练集与验证集的划分比例,取值范围为(0, 1]。 | Double | | 0.8 |
| ParamDist | 参数分布 | 指定搜索的参数的分布 | ParamDist | ✓ | --- |
| Estimator | Estimator | 用于调优的Estimator | Estimator | ✓ | --- |
| TuningEvaluator | 评估指标 | 用于选择最优模型的评估指标 | TuningEvaluator | ✓ | --- |

代码示例

Python 代码

  from pyalink.alink import *
  import pandas as pd

  # Start a local PyAlink execution environment before any operator is built;
  # the argument is the parallelism (see the PyAlink useLocalEnv documentation).
  useLocalEnv(1)
  def adult(url):
      """Build a batch CSV source for an Adult census data file.

      Args:
          url: Path or URL of the CSV file to read.

      Returns:
          A ``CsvSourceBatchOp`` configured with ``url`` and the Adult schema
          (14 feature columns followed by the string ``label`` column).
      """
      # Adjacent string literals are concatenated into one schema string.
      schema = (
          'age bigint, workclass string, fnlwgt bigint,'
          'education string, education_num bigint,'
          'marital_status string, occupation string,'
          'relationship string, race string, sex string,'
          'capital_gain bigint, capital_loss bigint,'
          'hours_per_week bigint, native_country string,'
          'label string'
      )
      # Read from the caller-supplied location instead of a hard-coded file,
      # so that adult_train() and adult_test() really load different data.
      return CsvSourceBatchOp().setFilePath(url).setSchemaStr(schema)
  def adult_train():
      """Return the source that ``adult`` builds for the adult_train.csv URL."""
      return adult('https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_train.csv')
  def adult_test():
      """Return the source that ``adult`` builds for the adult_test.csv URL."""
      return adult('https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_test.csv')
  def adult_numerical_feature_strs():
      """Return the names of the numerical feature columns as a new list."""
      return [
          "age", "fnlwgt", "education_num",
          "capital_gain", "capital_loss", "hours_per_week",
      ]
  def adult_categorical_feature_strs():
      """Return the names of the categorical feature columns as a new list."""
      return [
          "workclass", "education", "marital_status",
          "occupation", "relationship", "race", "sex",
          "native_country",
      ]
  def adult_features_strs():
      """Return all feature column names: numerical first, then categorical."""
      return adult_numerical_feature_strs() + adult_categorical_feature_strs()
  def rf_grid_search_cv(featureCols, categoryFeatureCols, label, metric):
      """Build a random-search, cross-validated tuner for a random forest.

      Despite the ``grid`` in its name, the tuner built here is a
      ``RandomSearchCV``: ``NUM_TREES`` is drawn from
      ``ValueDist.randInteger(1, 10)`` and the number of folds is set to 2.

      Args:
          featureCols: Names of all feature columns.
          categoryFeatureCols: Names of the feature columns that are categorical.
          label: Name of the label column.
          metric: Binary-classification tuning metric passed to the evaluator
              (the example passes ``'AUC'``).

      Returns:
          The configured, not yet fitted ``RandomSearchCV``.
      """
      rf = (
          RandomForestClassifier()
          .setFeatureCols(featureCols)
          .setCategoricalCols(categoryFeatureCols)
          .setLabelCol(label)
          .setPredictionCol('prediction')
          .setPredictionDetailCol('prediction_detail')
      )

      # Search space: only the number of trees of ``rf`` is tuned.
      paramDist = (
          ParamDist()
          .addDist(rf, 'NUM_TREES', ValueDist.randInteger(1, 10))
      )

      # The evaluator reads the same detail column that ``rf`` writes.
      tuningEvaluator = (
          BinaryClassificationTuningEvaluator()
          .setLabelCol(label)
          .setPredictionDetailCol("prediction_detail")
          .setTuningBinaryClassMetric(metric)
      )

      cv = (
          RandomSearchCV()
          .setEstimator(rf)
          .setParamDist(paramDist)
          .setTuningEvaluator(tuningEvaluator)
          .setNumFolds(2)
          .enableLazyPrintTrainInfo("TrainInfo")
      )
      return cv
  def rf_grid_search_tv(featureCols, categoryFeatureCols, label, metric):
      """Build a random-search, train-validation-split tuner for a random forest.

      Despite the ``grid`` in its name, the tuner built here is a
      ``RandomSearchTVSplit``: ``NUM_TREES`` is drawn from
      ``ValueDist.randInteger(1, 10)``.

      Args:
          featureCols: Names of all feature columns.
          categoryFeatureCols: Names of the feature columns that are categorical.
          label: Name of the label column.
          metric: Binary-classification tuning metric passed to the evaluator
              (the example passes ``'AUC'``).

      Returns:
          The configured, not yet fitted ``RandomSearchTVSplit``.
      """
      rf = (
          RandomForestClassifier()
          .setFeatureCols(featureCols)
          .setCategoricalCols(categoryFeatureCols)
          .setLabelCol(label)
          .setPredictionCol('prediction')
          .setPredictionDetailCol('prediction_detail')
      )

      # Search space: only the number of trees of ``rf`` is tuned.
      paramDist = (
          ParamDist()
          .addDist(rf, 'NUM_TREES', ValueDist.randInteger(1, 10))
      )

      # The evaluator reads the same detail column that ``rf`` writes.
      tuningEvaluator = (
          BinaryClassificationTuningEvaluator()
          .setLabelCol(label)
          .setPredictionDetailCol("prediction_detail")
          .setTuningBinaryClassMetric(metric)
      )

      tv_split = (
          RandomSearchTVSplit()
          .setEstimator(rf)
          .setParamDist(paramDist)
          .setTuningEvaluator(tuningEvaluator)
          # 0.8 is the documented default; it is spelled out so the Python
          # example demonstrates trainRatio the same way the Java example does.
          .setTrainRatio(0.8)
          .enableLazyPrintTrainInfo("TrainInfo")
      )
      return tv_split
  def tuningcv(cv_estimator, input):
      """Fit the cross-validation tuner on ``input`` and return what ``fit`` returns.

      Args:
          cv_estimator: Object exposing ``fit(data)``.
          input: Training data handed to ``fit`` unchanged.
      """
      return cv_estimator.fit(input)
  def tuningtv(tv_estimator, input):
      """Fit the train-validation tuner on ``input`` and return what ``fit`` returns.

      Args:
          tv_estimator: Object exposing ``fit(data)``.
          input: Training data handed to ``fit`` unchanged.
      """
      return tv_estimator.fit(input)
  def main():
      """Run both tuning examples (CV, then train-validation split) on the train data."""
      print('rf cv tuning')
      # fit() returns the model trained with the best-scoring parameters.
      cv_model = tuningcv(
          rf_grid_search_cv(
              adult_features_strs(),
              adult_categorical_feature_strs(),
              'label',
              'AUC',
          ),
          adult_train(),
      )

      print('rf tv tuning')
      tv_model = tuningtv(
          rf_grid_search_tv(
              adult_features_strs(),
              adult_categorical_feature_strs(),
              'label',
              'AUC',
          ),
          adult_train(),
      )


  # Run the example only when executed as a script or notebook cell,
  # not when this code is imported as a module.
  if __name__ == "__main__":
      main()

Java 代码

  1. import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
  2. import com.alibaba.alink.pipeline.classification.RandomForestClassifier;
  3. import com.alibaba.alink.pipeline.tuning.BinaryClassificationTuningEvaluator;
  4. import com.alibaba.alink.pipeline.tuning.ParamDist;
  5. import com.alibaba.alink.pipeline.tuning.RandomSearchTVSplit;
  6. import com.alibaba.alink.pipeline.tuning.RandomSearchTVSplitModel;
  7. import com.alibaba.alink.pipeline.tuning.ValueDist;
  8. import org.junit.Test;
  9. public class RandomSearchTVSplitTest {
  10. @Test
  11. public void testRandomSearchTVSplit() throws Exception {
  12. String[] featureCols = new String[] {
  13. "age", "fnlwgt", "education_num",
  14. "capital_gain", "capital_loss", "hours_per_week",
  15. "workclass", "education", "marital_status",
  16. "occupation", "relationship", "race", "sex",
  17. "native_country"
  18. };
  19. String[] categoryFeatureCols = new String[] {
  20. "workclass", "education", "marital_status",
  21. "occupation", "relationship", "race", "sex",
  22. "native_country"
  23. };
  24. String label = "label";
  25. CsvSourceBatchOp data = new CsvSourceBatchOp()
  26. .setFilePath("https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_train.csv")
  27. .setSchemaStr(
  28. "age bigint, workclass string, fnlwgt bigint, education string, education_num bigint, marital_status "
  29. + "string, occupation string, relationship string, race string, sex string, capital_gain bigint, "
  30. + "capital_loss bigint, hours_per_week bigint, native_country string, label string");
  31. RandomForestClassifier rf = new RandomForestClassifier()
  32. .setFeatureCols(featureCols)
  33. .setCategoricalCols(categoryFeatureCols)
  34. .setLabelCol(label)
  35. .setPredictionCol("prediction")
  36. .setPredictionDetailCol("prediction_detail");
  37. ParamDist paramDist = new ParamDist()
  38. .addDist(rf, RandomForestClassifier.NUM_TREES, ValueDist.randInteger(1, 10));
  39. BinaryClassificationTuningEvaluator tuningEvaluator = new BinaryClassificationTuningEvaluator()
  40. .setLabelCol(label)
  41. .setPredictionDetailCol("prediction_detail")
  42. .setTuningBinaryClassMetric("AUC");
  43. RandomSearchTVSplit cv = new RandomSearchTVSplit()
  44. .setEstimator(rf)
  45. .setParamDist(paramDist)
  46. .setTuningEvaluator(tuningEvaluator)
  47. .setTrainRatio(0.8)
  48. .enableLazyPrintTrainInfo("TrainInfo");
  49. RandomSearchTVSplitModel model = cv.fit(data);
  50. }
  51. }

运行结果

TrainInfo
Metric information:
Metric name: AUC
Larger is better: true
Tuning information:

| AUC | stage | param | value |
| --- | --- | --- | --- |
| 0.9121169398031198 | RandomForestClassifier | numTrees | 8 |
| 0.9105096451486404 | RandomForestClassifier | numTrees | 7 |
| 0.9105087086051442 | RandomForestClassifier | numTrees | 6 |
| 0.9098174499836453 | RandomForestClassifier | numTrees | 6 |
| 0.9089294943807537 | RandomForestClassifier | numTrees | 4 |
| 0.8910848199717841 | RandomForestClassifier | numTrees | 2 |
| 0.8862271106520978 | RandomForestClassifier | numTrees | 2 |
| 0.8748876808857913 | RandomForestClassifier | numTrees | 2 |
| 0.858989501722944 | RandomForestClassifier | numTrees | 1 |
| 0.8553973913661752 | RandomForestClassifier | numTrees | 1 |

运行结果

  1. rf cv tuning
  2. com.alibaba.alink.pipeline.tuning.GridSearchCV
  3. [ {
  4. "param" : [ {
  5. "stage" : "RandomForestClassifier",
  6. "paramName" : "numTrees",
  7. "paramValue" : 3
  8. }, {
  9. "stage" : "RandomForestClassifier",
  10. "paramName" : "subsamplingRatio",
  11. "paramValue" : 1.0
  12. } ],
  13. "metric" : 0.8922549257899725
  14. }, {
  15. "param" : [ {
  16. "stage" : "RandomForestClassifier",
  17. "paramName" : "numTrees",
  18. "paramValue" : 3
  19. }, {
  20. "stage" : "RandomForestClassifier",
  21. "paramName" : "subsamplingRatio",
  22. "paramValue" : 0.99
  23. } ],
  24. "metric" : 0.8920255970548456
  25. }, {
  26. "param" : [ {
  27. "stage" : "RandomForestClassifier",
  28. "paramName" : "numTrees",
  29. "paramValue" : 3
  30. }, {
  31. "stage" : "RandomForestClassifier",
  32. "paramName" : "subsamplingRatio",
  33. "paramValue" : 0.98
  34. } ],
  35. "metric" : 0.8944982480437225
  36. }, {
  37. "param" : [ {
  38. "stage" : "RandomForestClassifier",
  39. "paramName" : "numTrees",
  40. "paramValue" : 6
  41. }, {
  42. "stage" : "RandomForestClassifier",
  43. "paramName" : "subsamplingRatio",
  44. "paramValue" : 1.0
  45. } ],
  46. "metric" : 0.8923867598288401
  47. }, {
  48. "param" : [ {
  49. "stage" : "RandomForestClassifier",
  50. "paramName" : "numTrees",
  51. "paramValue" : 6
  52. }, {
  53. "stage" : "RandomForestClassifier",
  54. "paramName" : "subsamplingRatio",
  55. "paramValue" : 0.99
  56. } ],
  57. "metric" : 0.9012141767959505
  58. }, {
  59. "param" : [ {
  60. "stage" : "RandomForestClassifier",
  61. "paramName" : "numTrees",
  62. "paramValue" : 6
  63. }, {
  64. "stage" : "RandomForestClassifier",
  65. "paramName" : "subsamplingRatio",
  66. "paramValue" : 0.98
  67. } ],
  68. "metric" : 0.8993774036693788
  69. }, {
  70. "param" : [ {
  71. "stage" : "RandomForestClassifier",
  72. "paramName" : "numTrees",
  73. "paramValue" : 9
  74. }, {
  75. "stage" : "RandomForestClassifier",
  76. "paramName" : "subsamplingRatio",
  77. "paramValue" : 1.0
  78. } ],
  79. "metric" : 0.8981738808130779
  80. }, {
  81. "param" : [ {
  82. "stage" : "RandomForestClassifier",
  83. "paramName" : "numTrees",
  84. "paramValue" : 9
  85. }, {
  86. "stage" : "RandomForestClassifier",
  87. "paramName" : "subsamplingRatio",
  88. "paramValue" : 0.99
  89. } ],
  90. "metric" : 0.9029671873892725
  91. }, {
  92. "param" : [ {
  93. "stage" : "RandomForestClassifier",
  94. "paramName" : "numTrees",
  95. "paramValue" : 9
  96. }, {
  97. "stage" : "RandomForestClassifier",
  98. "paramName" : "subsamplingRatio",
  99. "paramValue" : 0.98
  100. } ],
  101. "metric" : 0.905228896323363
  102. } ]
  103. rf tv tuning
  104. com.alibaba.alink.pipeline.tuning.GridSearchTVSplit
  105. [ {
  106. "param" : [ {
  107. "stage" : "RandomForestClassifier",
  108. "paramName" : "numTrees",
  109. "paramValue" : 3
  110. }, {
  111. "stage" : "RandomForestClassifier",
  112. "paramName" : "subsamplingRatio",
  113. "paramValue" : 1.0
  114. } ],
  115. "metric" : 0.9022694229691741
  116. }, {
  117. "param" : [ {
  118. "stage" : "RandomForestClassifier",
  119. "paramName" : "numTrees",
  120. "paramValue" : 3
  121. }, {
  122. "stage" : "RandomForestClassifier",
  123. "paramName" : "subsamplingRatio",
  124. "paramValue" : 0.99
  125. } ],
  126. "metric" : 0.8963559966080328
  127. }, {
  128. "param" : [ {
  129. "stage" : "RandomForestClassifier",
  130. "paramName" : "numTrees",
  131. "paramValue" : 3
  132. }, {
  133. "stage" : "RandomForestClassifier",
  134. "paramName" : "subsamplingRatio",
  135. "paramValue" : 0.98
  136. } ],
  137. "metric" : 0.9041948454957178
  138. }, {
  139. "param" : [ {
  140. "stage" : "RandomForestClassifier",
  141. "paramName" : "numTrees",
  142. "paramValue" : 6
  143. }, {
  144. "stage" : "RandomForestClassifier",
  145. "paramName" : "subsamplingRatio",
  146. "paramValue" : 1.0
  147. } ],
  148. "metric" : 0.8982021117392784
  149. }, {
  150. "param" : [ {
  151. "stage" : "RandomForestClassifier",
  152. "paramName" : "numTrees",
  153. "paramValue" : 6
  154. }, {
  155. "stage" : "RandomForestClassifier",
  156. "paramName" : "subsamplingRatio",
  157. "paramValue" : 0.99
  158. } ],
  159. "metric" : 0.9031851535310546
  160. }, {
  161. "param" : [ {
  162. "stage" : "RandomForestClassifier",
  163. "paramName" : "numTrees",
  164. "paramValue" : 6
  165. }, {
  166. "stage" : "RandomForestClassifier",
  167. "paramName" : "subsamplingRatio",
  168. "paramValue" : 0.98
  169. } ],
  170. "metric" : 0.9034443322241488
  171. }, {
  172. "param" : [ {
  173. "stage" : "RandomForestClassifier",
  174. "paramName" : "numTrees",
  175. "paramValue" : 9
  176. }, {
  177. "stage" : "RandomForestClassifier",
  178. "paramName" : "subsamplingRatio",
  179. "paramValue" : 1.0
  180. } ],
  181. "metric" : 0.8993474753000145
  182. }, {
  183. "param" : [ {
  184. "stage" : "RandomForestClassifier",
  185. "paramName" : "numTrees",
  186. "paramValue" : 9
  187. }, {
  188. "stage" : "RandomForestClassifier",
  189. "paramName" : "subsamplingRatio",
  190. "paramValue" : 0.99
  191. } ],
  192. "metric" : 0.9090250137144916
  193. }, {
  194. "param" : [ {
  195. "stage" : "RandomForestClassifier",
  196. "paramName" : "numTrees",
  197. "paramValue" : 9
  198. }, {
  199. "stage" : "RandomForestClassifier",
  200. "paramName" : "subsamplingRatio",
  201. "paramValue" : 0.98
  202. } ],
  203. "metric" : 0.9129786771786127
  204. } ]