准备数据
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
# Load the handwritten-digits dataset that ships with scikit-learn.
digits = datasets.load_digits()
X, y = digits.data, digits.target

# Hold out 20% of the samples for testing; the fixed seed keeps the split
# reproducible from run to run.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=666,
)
网格搜索语法
参数设置
from sklearn.model_selection import GridSearchCV
# Two parameter grids for the grid search:
#   1. uniform weighting, varying only the neighbor count (1..10);
#   2. distance weighting, varying the neighbor count (1..10) and `p` (1..5).
# `p` is searched only together with distance weighting, so the two cases are
# kept as separate grids instead of one full cross product.
param_grid = [
    {
        'weights': ['uniform'],
        'n_neighbors': list(range(1, 11)),
    },
    {
        'weights': ['distance'],
        'n_neighbors': list(range(1, 11)),
        'p': list(range(1, 6)),
    },
]
建模
# Imports for this step come before first use. KNeighborsClassifier has to be
# imported here: nothing earlier in the script brings it into scope, and
# without it the next statement fails with NameError.
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

# Base estimator with default settings; the hyperparameters being tuned are
# supplied by param_grid during the search.
knn_clf = KNeighborsClassifier()

# Evaluate every parameter combination in param_grid on the training data.
grid_search = GridSearchCV(knn_clf, param_grid)
grid_search.fit(X_train, y_train)
更多超参数
n_jobs: 代表并行计算所使用的 CPU 核数,-1 代表使用所有核