A股上市公司传智教育(股票代码 003032)旗下技术交流社区北京昌平校区

 找回密码
 加入黑马

QQ登录

只需一步,快速开始

from os import listdir
from numpy import *
import time
import operator
def classify(inputPoint,dataSet,labels,k):
    dataSetSize = dataSet.shape[0]     #已知分类的数据集(训练集)的行数
    #先tile函数将输入点拓展成与训练集相同维数的矩阵,再计算欧氏距离
    diffMat = tile(inputPoint,(dataSetSize,1))-dataSet  #样本与训练集的差值矩阵
    sqDiffMat = diffMat ** 2                    #差值矩阵平方
    sqDistances = sqDiffMat.sum(axis=1)         #计算每一行上元素的和
    distances = sqDistances ** 0.5              #开方得到欧拉距离矩阵
    sortedDistIndicies = distances.argsort()    #按distances中元素进行升序排序后得到的对应下标的列表
    #选择距离最小的k个点
    classCount = {}
    for i in range(k):
        voteIlabel = labels[ sortedDistIndicies ]
        classCount[voteIlabel] = classCount.get(voteIlabel,0)+1
    #按classCount字典的第2个元素(即类别出现的次数)从大到小排序
    sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
    return sortedClassCount[0][0]
#文本向量化 32x32 -> 1x1024
def img2vecor(filename):
    returnVect=[]
    fr=open(filename)
    for i in range(32):
        lineStr=fr.readline()
        for j in range(32):
            returnVect.append(int(lineStr[j]))
    return returnVect
#从文件名中解析分类数字
def classnumCut(filename):
    fileStr=filename.split('.')[0]
    classNumStr=int(fileStr.split('_')[0])
    return classNumStr
#构建训练集数据向量,及对应分类标签向量
def trainingDataSet(file):
    hwLabels=[]
    trainingFileList=listdir(file) #获取目录内容
    m=len(trainingFileList)
    trainingMat=zeros((m,1024))  #获取m维向量的训练集
    for i in range(m):
        fileNameStr=trainingFileList
        hwLabels.append(classnumCut(fileNameStr))
        trainingMat[i,:]=img2vecor(file+'/%s' % fileNameStr)
    return hwLabels,trainingMat

def handwritingTest(file):
    hwLabels,trainingMat = trainingDataSet('/root/python_test/data/data/trainingDigits')    #构建训练集
    testFileList = listdir(file)        #获取测试集
    errorCount = 0.0                            #错误数
    mTest = len(testFileList)                   #测试集总样本数
    t1 = time.time()
    for i in range(mTest):
        fileNameStr = testFileList
        classNumStr = classnumCut(fileNameStr)
        vectorUnderTest=img2vecor(file+'/%s' % fileNameStr)
        classifierResult =classify(vectorUnderTest,trainingMat,hwLabels, 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if (classifierResult != classNumStr):
            errorCount += 1.0
        print ("\nthe total number of tests is: %d" % mTest)               #输出测试总样本数
        print ("the total number of errors is: %d" % errorCount)           #输出测试错误样本数
        print("the total error rate is: %f" % (errorCount/float(mTest)))  #输出错误率
    t2 = time.time()
    print("Cost time: %.2fmin, %.4fs."%((t2-t1)//60,(t2-t1)%60) )     #测试耗时

if __name__ == "__main__":
    handwritingTest('/root/python_test/data/data/testDigits')
---------------------
作者:wx_411180165
来源:CSDN
原文:https://blog.csdn.net/qq_24726509/article/details/84923274
版权声明:本文为博主原创文章,转载请附上博文链接!

1 个回复

倒序浏览
奈斯
回复 使用道具 举报
您需要登录后才可以回帖 登录 | 加入黑马