前言：之前在《朴素贝叶斯算法》这篇文章中，对朴素贝叶斯分类算法的原理做了一个总结。这里我们就从实战的角度来看朴素贝叶斯类库，重点讲述朴素贝叶斯类库的使用要点和参数选择。
这里的实战是利用朴素贝叶斯来进行垃圾邮件的分类。
实战部分:
from numpy import *
import codecs,re
from sklearn.naive_bayes import MultinomialNB
# Build a list of the distinct words that appear across all documents.
def createVocabList(dadaSet):
    """Return every distinct token found in ``dadaSet`` as a list.

    Args:
        dadaSet: Iterable of documents, each an iterable of hashable tokens.

    Returns:
        A list holding each distinct token once. The list is produced from a
        set, so its order is not guaranteed.
    """
    unique_words = set()
    for doc in dadaSet:
        unique_words.update(doc)
    return list(unique_words)
def setOfWords2Vec(vocabList, inputSet):
    """Encode a document as a presence/absence (set-of-words) vector.

    Args:
        vocabList: Sequence of vocabulary words; position ``i`` of the result
            corresponds to ``vocabList[i]``.
        inputSet: Iterable of words making up the document.

    Returns:
        A list of ``len(vocabList)`` integers: 1 where the vocabulary word
        occurs in ``inputSet``, 0 otherwise. Words missing from the
        vocabulary are reported on stdout and otherwise ignored.
    """
    # One pass builds a word -> position table so each lookup is O(1) instead
    # of the two linear scans done by `in` + `list.index`. setdefault keeps
    # the first occurrence, which is the position list.index would return.
    position = {}
    for idx, term in enumerate(vocabList):
        position.setdefault(term, idx)

    returnVec = [0] * len(vocabList)
    for word in inputSet:
        idx = position.get(word)
        if idx is None:
            print("the word:%s is not in my Vocabulary" % word)
        else:
            returnVec[idx] = 1
    return returnVec
def bagOfWords2VecMN(vocabList, inputSet):
    """Encode a document as a word-count (bag-of-words) vector.

    Args:
        vocabList: Sequence of vocabulary words; position ``i`` of the result
            corresponds to ``vocabList[i]``.
        inputSet: Iterable of words making up the document.

    Returns:
        A list of ``len(vocabList)`` integers giving how many times each
        vocabulary word occurs in ``inputSet``. Words that are not in the
        vocabulary are silently skipped.
    """
    # One pass builds a word -> position table so each lookup is O(1) instead
    # of the two linear scans done by `in` + `list.index`. setdefault keeps
    # the first occurrence, which is the position list.index would return.
    position = {}
    for idx, term in enumerate(vocabList):
        position.setdefault(term, idx)

    returnVec = [0] * len(vocabList)
    for word in inputSet:
        idx = position.get(word)
        if idx is not None:
            returnVec[idx] += 1
    return returnVec
def textParse(bigString):
    """Split a text into lower-cased word tokens longer than two characters.

    Args:
        bigString: The raw text to tokenize.

    Returns:
        A list of lower-cased tokens, in order of appearance, keeping only
        those with more than two characters.
    """
    # Split on runs of one or more non-word characters. The pattern must not
    # be able to match the empty string: on Python 3.7+ re.split also splits
    # on empty matches, so r'\W*' would break the text into single characters
    # and the length filter below would then discard every token.
    listOfTokens = re.split(r'\W+', bigString)
    return [tok.lower() for tok in listOfTokens if len(tok) > 2]
# Core of the example: train and evaluate the spam classifier.
def spamtest():
    """Train a multinomial naive Bayes spam filter and print its error rate.

    Reads ``./spam/1.txt`` .. ``./spam/25.txt`` (labelled 1) and
    ``./ham/1.txt`` .. ``./ham/25.txt`` (labelled 0) as UTF-8 text, holds out
    10 randomly chosen documents as a test set, fits ``MultinomialNB`` on
    word-count vectors of the remaining documents, and prints the fraction of
    held-out documents that were misclassified.
    """
    doclist = []
    classlist = []
    for i in range(1, 26):
        # Same order as before: spam document i first, then ham document i.
        for path, label in (('./spam/%d.txt' % i, 1), ('./ham/%d.txt' % i, 0)):
            # Context manager so the file handle is closed after reading.
            with codecs.open(path, mode='r', encoding='utf-8') as handle:
                doclist.append(textParse(handle.read()))
            classlist.append(label)
    vocabList = createVocabList(doclist)

    # Split into training and test sets: move 10 randomly chosen document
    # indices out of the training pool. `random` is numpy's random module,
    # brought in by the file's `from numpy import *`.
    trainSet = list(range(len(doclist)))
    testSet = []
    for _ in range(10):
        # Index into trainSet (not a document index); uniform() draws from
        # [0, len), so int() always yields a valid position.
        randindex = int(random.uniform(0, len(trainSet)))
        testSet.append(trainSet.pop(randindex))

    trainMat = [bagOfWords2VecMN(vocabList, doclist[idx]) for idx in trainSet]
    trainClasses = [classlist[idx] for idx in trainSet]

    # Train the model.
    clf = MultinomialNB()
    clf.fit(array(trainMat), array(trainClasses))

    # Evaluate on the held-out documents.
    errorCount = 0
    for idx in testSet:
        testVec = bagOfWords2VecMN(vocabList, doclist[idx])
        # predict() takes a 2-D array and returns one label per row.
        if clf.predict(array([testVec]))[0] != classlist[idx]:
            errorCount += 1
    # Divide by the actual test-set size rather than a hard-coded 10.
    print("error rate is :", errorCount / len(testSet))
# Run the spam-classification experiment; it prints the resulting error rate.
spamtest()
程序运行结果:
FutureWarning: split() requires a non-empty pattern match.
  return _compile(pattern, flags).split(string, maxsplit)
error rate is : 0.0
Process finished with exit code 0
---------------------
【转载】
作者:不曾走远~
原文:https://blog.csdn.net/qq_20412595/article/details/82467042
|
|