from os import listdir
from numpy import *
import time
import operator
def classify(inputPoint,dataSet,labels,k):
dataSetSize = dataSet.shape[0] #已知分类的数据集(训练集)的行数
#先tile函数将输入点拓展成与训练集相同维数的矩阵,再计算欧氏距离
diffMat = tile(inputPoint,(dataSetSize,1))-dataSet #样本与训练集的差值矩阵
sqDiffMat = diffMat ** 2 #差值矩阵平方
sqDistances = sqDiffMat.sum(axis=1) #计算每一行上元素的和
distances = sqDistances ** 0.5 #开方得到欧拉距离矩阵
sortedDistIndicies = distances.argsort() #按distances中元素进行升序排序后得到的对应下标的列表
#选择距离最小的k个点
classCount = {}
for i in range(k):
voteIlabel = labels[ sortedDistIndicies ]
classCount[voteIlabel] = classCount.get(voteIlabel,0)+1
#按classCount字典的第2个元素(即类别出现的次数)从大到小排序
sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
return sortedClassCount[0][0]
#文本向量化 32x32 -> 1x1024
def img2vecor(filename):
returnVect=[]
fr=open(filename)
for i in range(32):
lineStr=fr.readline()
for j in range(32):
returnVect.append(int(lineStr[j]))
return returnVect
#从文件名中解析分类数字
def classnumCut(filename):
fileStr=filename.split('.')[0]
classNumStr=int(fileStr.split('_')[0])
return classNumStr
#构建训练集数据向量,及对应分类标签向量
def trainingDataSet(file):
hwLabels=[]
trainingFileList=listdir(file) #获取目录内容
m=len(trainingFileList)
trainingMat=zeros((m,1024)) #获取m维向量的训练集
for i in range(m):
fileNameStr=trainingFileList
hwLabels.append(classnumCut(fileNameStr))
trainingMat[i,:]=img2vecor(file+'/%s' % fileNameStr)
return hwLabels,trainingMat
def handwritingTest(file):
hwLabels,trainingMat = trainingDataSet('/root/python_test/data/data/trainingDigits') #构建训练集
testFileList = listdir(file) #获取测试集
errorCount = 0.0 #错误数
mTest = len(testFileList) #测试集总样本数
t1 = time.time()
for i in range(mTest):
fileNameStr = testFileList
classNumStr = classnumCut(fileNameStr)
vectorUnderTest=img2vecor(file+'/%s' % fileNameStr)
classifierResult =classify(vectorUnderTest,trainingMat,hwLabels, 3)
print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
if (classifierResult != classNumStr):
errorCount += 1.0
print ("\nthe total number of tests is: %d" % mTest) #输出测试总样本数
print ("the total number of errors is: %d" % errorCount) #输出测试错误样本数
print("the total error rate is: %f" % (errorCount/float(mTest))) #输出错误率
t2 = time.time()
print("Cost time: %.2fmin, %.4fs."%((t2-t1)//60,(t2-t1)%60) ) #测试耗时
if __name__ == "__main__":
handwritingTest('/root/python_test/data/data/testDigits')
---------------------
【转载】仅作分享,侵删