Basic Concepts

Overfitting

  • While learning the general, reproducible patterns in the data, a model can also pick up the noise unique to each individual sample; such a model is said to be overfit.

  • An overfit model can perform very well on the data it was built on yet generalize poorly: move to a new dataset and performance drops sharply. That gap is exactly what overfitting produces.

Model Tuning

  • Nearly all predictive modeling techniques have tuning parameters. Using the existing data to adjust these parameters so that the model yields its best predictions is called model tuning.

Data Splitting

  • When the amount of data is small, avoid carving out a separate test set.

  • If one class has markedly fewer samples than the others, a simple random split can produce training and test sets with very different compositions; use stratified random sampling instead.

Resampling Techniques

  • K-fold cross-validation
  • Generalized cross-validation (GCV)
  • Repeated training/test set splits
  • The bootstrap (see the sketch after this list)
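
The Computation section below demonstrates splits and K-fold cross-validation but not the bootstrap, so here is a minimal sketch of drawing bootstrap resamples with caret's createResample(); the data-loading lines mirror the Computation section.

```r
## Minimal bootstrap sketch: each resample draws n rows with replacement,
## so rows can repeat and roughly one third stay out ("out of bag")
library(caret)
library(AppliedPredictiveModeling)
data(twoClassData)

set.seed(1)
bootSamples <- createResample(classes, times = 3) # three index vectors
str(bootSamples)
```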

Data Splitting Recommendations

  • When the sample size is small, I recommend 10-fold cross-validation.
  • If the goal is not to obtain the best possible estimate of model performance but to choose among several competing models, the bootstrap is the better choice (see the trainControl() sketch below).
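
In caret, each of these resampling choices maps onto a trainControl() setting, which the Computation section uses later. A quick sketch (the counts are illustrative choices, not caret's defaults):

```r
## How the resampling choices above map onto caret's trainControl()
library(caret)
ctrlCV    <- trainControl(method = "cv", number = 10)              # 10-fold CV
ctrlRepCV <- trainControl(method = "repeatedcv", number = 10, repeats = 5)
ctrlLGOCV <- trainControl(method = "LGOCV", p = 0.8, number = 25)  # repeated train/test splits
ctrlBoot  <- trainControl(method = "boot", number = 50)            # bootstrap
```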

Computation

```r
## Load the package and data
library(AppliedPredictiveModeling)
data(twoClassData)

str(predictors)
## 'data.frame': 208 obs. of 2 variables:
##  $ PredictorA: num 0.158 0.655 0.706 0.199 0.395 ...
##  $ PredictorB: num 0.1609 0.4918 0.6333 0.0881 0.4152 ...

str(classes)
## Factor w/ 2 levels "Class1","Class2": 2 2 2 2 2 2 2 2 2 2 ...

set.seed(1)
```

```r
# Split the data
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2

# createDataPartition draws a stratified random sample within each class
trainingRows <- createDataPartition(classes, p = 0.8, list = FALSE)
head(trainingRows)
##      Resample1
## [1,]         1
## [2,]         2
## [3,]         3
## [4,]         7
## [5,]         8
## [6,]         9
```
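
For contrast, here is what a plain, non-stratified split looks like with base R's sample(); with imbalanced classes this is exactly the kind of split the warning above is about:

```r
## A plain (non-stratified) 80/20 random split, for comparison
set.seed(1)
n <- length(classes)
randomRows <- sample(n, size = floor(0.8 * n))
table(classes[randomRows])  # class balance is not guaranteed here
```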

Turn these row numbers into the training and test sets:

```r
trainPredictors <- predictors[trainingRows, ]
trainClasses <- classes[trainingRows]
testPredictors <- predictors[-trainingRows, ]
testClasses <- classes[-trainingRows]

str(trainPredictors)
## 'data.frame': 167 obs. of 2 variables:
##  $ PredictorA: num 0.1582 0.6552 0.706 0.0658 0.3086 ...
##  $ PredictorB: num 0.161 0.492 0.633 0.179 0.28 ...

str(testPredictors)
## 'data.frame': 41 obs. of 2 variables:
##  $ PredictorA: num 0.1992 0.3952 0.425 0.0847 0.2909 ...
##  $ PredictorB: num 0.0881 0.4152 0.2988 0.0548 0.3021 ...
```
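
Because the partition is stratified on the class, both sets should show nearly identical class proportions; a quick sanity check:

```r
## Check that the stratified split preserved the class proportions
prop.table(table(trainClasses))
prop.table(table(testClasses))
```
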
```r
## Resampling: repeated training/test set splits
set.seed(1)
repeatedSplits <- createDataPartition(trainClasses, p = 0.8, times = 3)
str(repeatedSplits)
## List of 3
## $ Resample1: int [1:135] 1 2 3 4 6 7 9 10 11 12 ...
## $ Resample2: int [1:135] 1 2 3 4 5 6 7 9 10 11 ...
## $ Resample3: int [1:135] 1 2 3 4 5 7 8 9 11 12 ...
```

```r
## K-fold cross-validation
set.seed(1)
cvSplits <- createFolds(trainClasses, k = 10, returnTrain = TRUE)
str(cvSplits)
## List of 10
## $ Fold01: int [1:150] 1 2 4 5 6 7 8 10 11 13 ...
## $ Fold02: int [1:150] 1 2 3 4 6 7 8 9 10 11 ...
## $ Fold03: int [1:150] 1 3 4 5 6 7 8 9 10 11 ...
## $ Fold04: int [1:150] 1 2 3 4 5 6 7 8 9 10 ...
## $ Fold05: int [1:150] 2 3 4 5 6 7 8 9 10 11 ...
## $ Fold06: int [1:150] 1 2 3 4 5 6 7 8 9 11 ...
## $ Fold07: int [1:150] 1 2 3 4 5 6 7 9 10 12 ...
## $ Fold08: int [1:151] 1 2 3 4 5 6 8 9 10 11 ...
## $ Fold09: int [1:151] 1 2 3 5 6 7 8 9 10 11 ...
## $ Fold10: int [1:151] 1 2 3 4 5 7 8 9 10 11 ...

fold1 <- cvSplits[[1]]                    # row numbers of the first fold's training set
cvPredictors1 <- trainPredictors[fold1, ] # the first ~90% training subset
cvClass1 <- trainClasses[fold1]
nrow(trainPredictors)
## [1] 167
nrow(cvPredictors1)
## [1] 150
```
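
To turn the folds into an actual performance estimate, loop over them: fit on each fold's ~90% and score the held-out ~10%. A minimal sketch using caret's knn3(), the same 5-nearest-neighbor model fit in the next block (k = 5 is just an illustrative choice):

```r
## Out-of-fold accuracy across all 10 folds (returnTrain = TRUE means each
## element holds the training indices, so the complement is the held-out fold)
cvAccuracy <- sapply(cvSplits, function(inFold) {
  fit  <- knn3(x = as.matrix(trainPredictors[inFold, ]),
               y = trainClasses[inFold], k = 5)
  pred <- predict(fit, newdata = as.matrix(trainPredictors[-inFold, ]),
                  type = "class")
  mean(pred == trainClasses[-inFold])
})
mean(cvAccuracy)  # the cross-validated accuracy estimate
```
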
```r
## Basic modeling in R
## Fit
trainPredictors <- as.matrix(trainPredictors)
knnFit <- knn3(x = trainPredictors, y = trainClasses, k = 5)
knnFit
## 5-nearest neighbor model
## Training set outcome distribution:
##
## Class1 Class2
##     89     78

## Predict
testPredictions <- predict(knnFit, newdata = testPredictors, type = "class")
head(testPredictions)
## [1] Class2 Class1 Class1 Class2 Class1 Class2
## Levels: Class1 Class2
str(testPredictions)
## Factor w/ 2 levels "Class1","Class2": 2 1 1 2 1 2 2 1 2 2 ...
```
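
A vector of predicted classes by itself says little; caret's confusionMatrix() summarizes test-set performance (accuracy, kappa, sensitivity, specificity, and more) in one call:

```r
## Summarize test-set performance against the true labels
confusionMatrix(data = testPredictions, reference = testClasses)
```
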
```r
## Choosing tuning parameters
library(caret)
data("GermanCredit")

set.seed(1056)
svmFit <- train(Class ~ .,
                data = GermanCredit,
                method = "svmRadial")

## Add preprocessing, tune over 10 cost values, and use 10-fold CV repeated 5 times
set.seed(1056)
svmfit <- train(Class ~ .,
                data = GermanCredit,
                method = "svmRadial",
                preProc = c("center", "scale"),
                tuneLength = 10,
                trControl = trainControl(method = "repeatedcv", repeats = 5)
) # honestly, this interface feels more concise than today's tidymodels or mlr3...
svmfit
## Support Vector Machines with Radial Basis Function Kernel
##
## 1000 samples
##   61 predictor
##    2 classes: 'Bad', 'Good'
##
## Pre-processing: centered (61), scaled (61)
## Resampling: Cross-Validated (10 fold, repeated 5 times)
## Summary of sample sizes: 900, 900, 900, 900, 900, 900, ...
## Resampling results across tuning parameters:
##
##   C       Accuracy  Kappa
##     0.25  0.7040    0.01934723
##     0.50  0.7430    0.24527603
##     1.00  0.7610    0.35046362
##     2.00  0.7628    0.38285072
##     4.00  0.7610    0.39239970
##     8.00  0.7616    0.40357861
##    16.00  0.7542    0.39860268
##    32.00  0.7418    0.37677389
##    64.00  0.7344    0.36165095
##   128.00  0.7348    0.36361822
##
## Tuning parameter 'sigma' was held constant at a value of 0.009718427
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were sigma = 0.009718427 and C = 2.

plot(svmfit, scales = list(x = list(log = 2)))
```

(Figure 1: resampled accuracy profile of the SVM across the cost parameter C, x-axis on a log-2 scale)
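
With tuneLength, caret picks the candidate values itself; to control them exactly, pass an explicit tuneGrid. A minimal sketch that reuses the sigma estimate from the output above (the C values are illustrative):

```r
## Explicit tuning grid for svmRadial (its parameters are sigma and C)
svmGrid <- expand.grid(sigma = 0.009718427, C = 2^(-2:7))
set.seed(1056)
svmGridFit <- train(Class ~ .,
                    data = GermanCredit,
                    method = "svmRadial",
                    preProc = c("center", "scale"),
                    tuneGrid = svmGrid,
                    trControl = trainControl(method = "repeatedcv", repeats = 5))
```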

```r
## Comparing models
set.seed(1056)
logistic <- train(Class ~ .,
                  data = GermanCredit,
                  method = "glm",
                  trControl = trainControl(method = "repeatedcv", repeats = 5)
)
logistic
## Generalized Linear Model
##
## 1000 samples
##   61 predictor
##    2 classes: 'Bad', 'Good'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 5 times)
## Summary of sample sizes: 900, 900, 900, 900, 900, 900, ...
## Resampling results:
##
##   Accuracy  Kappa
##   0.749     0.3661277

resamp <- resamples(list(svm = svmfit, logi = logistic))
summary(resamp)
##
## Call:
## summary.resamples(object = resamp)
##
## Models: svm, logi
## Number of resamples: 50
##
## Accuracy
##      Min. 1st Qu. Median   Mean 3rd Qu. Max. NA's
## svm  0.69  0.7425   0.77 0.7628  0.7800 0.84    0
## logi 0.65  0.7200   0.75 0.7490  0.7775 0.88    0
##
## Kappa
##           Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## svm  0.1944444 0.3385694 0.3882979 0.3828507 0.4293478 0.5959596    0
## logi 0.1581633 0.2993889 0.3779762 0.3661277 0.4240132 0.7029703    0

summary(diff(resamp))
##
## Call:
## summary.diff.resamples(object = diff(resamp))
##
## p-value adjustment: bonferroni
## Upper diagonal: estimates of the difference
## Lower diagonal: p-value for H0: difference = 0
##
## Accuracy
##      svm       logi
## svm            0.0138
## logi 0.0002436
##
## Kappa
##      svm     logi
## svm          0.01672
## logi 0.07449
```
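
Because both models were fit with the same seed, and therefore the same resamples, the paired comparison above is valid. caret also ships lattice methods for visualizing it; a brief sketch:

```r
## Visualize the resampling distributions and the paired differences
bwplot(resamp, metric = "Accuracy")  # box-and-whisker plot per model
dotplot(diff(resamp))                # interval estimates of the differences
```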