为给定文档的特征值(频数统计,预测文档提供),
为文档 ,则
每个文档类别的概率(某文档类别词数/总文档词数)
给定类别下特征(被预测文档中出现的词)的概率
- 特征
有特征词
- 特征
- 计算方法
(训练文档中去计算)
为
词在
类别所有文档中出现的次数
为所属类别
下的文档所有词出现的次数和
- 拉普拉斯平滑系数
如果词频列表里面 有很多出现次数都为 0,很可能计算结果都为零
拉普拉斯平滑系数一般为1, m 为训练文档中统计出的特征词个数
- 文本分类实例
- 加载数据集 ```python news = fetch20newsgroups(subset=’all’, data_home=’data’) #subset: ‘train’或者’test’,’all’,可选,选择要加载的数据集,fetch*的文件较大,所以需要下载,data_home是下载路径 print(news.target) print(news.target_names)
结果: [10 3 17 … 3 1 7] [‘alt.atheism’, ‘comp.graphics’, ‘comp.os.ms-windows.misc’, ‘comp.sys.ibm.pc.hardware’, ‘comp.sys.mac.hardware’, ‘comp.windows.x’, ‘misc.forsale’, ‘rec.autos’, ‘rec.motorcycles’, ‘rec.sport.baseball’, ‘rec.sport.hockey’, ‘sci.crypt’, ‘sci.electronics’, ‘sci.med’, ‘sci.space’, ‘soc.religion.christian’, ‘talk.politics.guns’, ‘talk.politics.mideast’, ‘talk.politics.misc’, ‘talk.religion.misc’]
b. 划分训练集和测试集,特征提取```python# 进行数据分割x_train, x_test, y_train, y_test = train_test_split(news.data, news.target, test_size=0.25, random_state=1)# 对数据集进行特征抽取tf = TfidfVectorizer()x_train = tf.fit_transform(x_train)x_test = tf.transform(x_test)
c. 朴素贝叶斯预测
# 进行朴素贝叶斯算法的预测,alpha是拉普拉斯平滑系数,分子和分母加上一个系数,分母加alpha*特征词数目mlt = MultinomialNB(alpha=1.0)mlt.fit(x_train, y_train)y_predict = mlt.predict(x_test)print("预测的文章类别为:", y_predict)print("准确率为:", mlt.score(x_test, y_test))print("每个类别的精确率和召回率:", classification_report(y_test, y_predict, target_names=news.target_names))结果:预测的文章类别为: [16 19 18 ... 13 7 14]准确率为: 0.8518675721561969每个类别的精确率和召回率:precision recall f1-score supportalt.atheism 0.91 0.77 0.83 199comp.graphics 0.83 0.79 0.81 242comp.os.ms-windows.misc 0.89 0.83 0.86 263comp.sys.ibm.pc.hardware 0.80 0.83 0.81 262comp.sys.mac.hardware 0.90 0.88 0.89 234comp.windows.x 0.92 0.85 0.88 230misc.forsale 0.96 0.67 0.79 257rec.autos 0.90 0.87 0.88 265rec.motorcycles 0.90 0.95 0.92 251rec.sport.baseball 0.89 0.96 0.93 226rec.sport.hockey 0.95 0.98 0.96 262sci.crypt 0.76 0.97 0.85 257sci.electronics 0.84 0.80 0.82 229sci.med 0.97 0.86 0.91 249sci.space 0.92 0.96 0.94 256soc.religion.christian 0.55 0.98 0.70 243talk.politics.guns 0.76 0.96 0.85 234talk.politics.mideast 0.93 0.99 0.96 224talk.politics.misc 0.98 0.56 0.72 197talk.religion.misc 0.97 0.26 0.41 132accuracy 0.85 4712macro avg 0.88 0.84 0.84 4712weighted avg 0.87 0.85 0.85 4712
d. 计算AUC
# 把0-19总计20个分类,变为0和1y_test1 = np.where(y_test == 5, 1, 0)y_predict1 = np.where(y_predict == 5, 1, 0)# roc_auc_score的y_test只能是二分类,针对多分类如何计算AUCprint("AUC指标:", roc_auc_score(y_test1, y_predict1))
