A股上市公司传智教育(股票代码 003032)旗下技术交流社区北京昌平校区

 找回密码
 加入黑马

QQ登录

只需一步,快速开始

前言:在学机器学习的过程中,我曾经不是很看重用纯手工来实现算法,认为都有现成的库可以给你调用了,你还自己来实现算法,不仅效率低,而且中间环节出问题了,十分的烦人,调试起来,难度比较大,但为什么我还要这样来做,我个人觉得如果你只是会简单的调用一下模型,这可能只是入门级的,要真正的理解算法,只有自己亲手动手实现,抛开现成的算法库,这样提高才会更大,好了,写了这么多的废话,下面开始吧。

这个案例是来自机器学习实战中的逻辑回归的案例--从疝气症病预测病马的死忙率,其中这个算法的核心是计算出各特征前的系数,即w,利用的方法是梯度上升法。

# !/usr/bin/env python
# -*- coding:utf-8 -*-
# author:lxy

from numpy import *

# 定义sigmoid函数
def sigmoid(inX):
    return 1.0/(1+exp(-inX))

# 穿插,转置前要保证对象是mat(矩阵)格式,不然效果无效
# weights = ones(10)行向量
# wgh = weights.T 依然是行向量
# weights==wgh
# wgh1 = mat(weights).T  此时才是列向量

# 定义梯度上升的算法,这里是改进的随机梯度,alpha是会改变的
def gradAscent(dataMatrix,classLables,iterNum):
    m,n = shape(dataMatrix)
    weights = ones(n)
    for j in range(iterNum):
        dataIndex = list(range(m))
        for i in range(m):
            alpha = 4/(1.0+i+j)+0.0001
            randIndex = int(random.uniform(0,len(dataIndex)))
            h = sigmoid(sum(dataMatrix[randIndex]*weights))
            error  = classLables[randIndex]-h
            weights = weights+alpha*error*dataMatrix[randIndex]
            del dataIndex[randIndex]
    return weights

# 定义分类的结果
def classifier(inX,weights):
    prob = sigmoid(sum(inX*weights))
    if prob>0.5:
        return 1.0
    else:
        return 0.0

def colicTest():
    frtrain = open('./horseColicTraining.txt')
    frtest = open('./horseColicTest.txt')
    trainingSet = []
    trainingLables = []
    for line in frtrain.readlines():
        currline = line.strip().split('        ')
        lineArr = []
        for i in range(21):
            lineArr.append(float(currline))
        trainingSet.append(lineArr)
        trainingLables.append(float(currline[21]))
    trainwieghts = gradAscent(array(trainingSet),trainingLables,500)
    print("权重向量:",trainwieghts)
    errorcount = 0;numTest = 0
    for line in frtest.readlines():
        numTest+=1
        currLine = line.strip().split('        ')
        lineArr = []
        for i in range(21):
            lineArr.append(float(currLine))
        if int(classifier(array(lineArr),trainwieghts))!=int(currLine[21]):
            errorcount+=1
    errorrate = errorcount/numTest
    print("the error rate of  this test is:%f"%errorrate)
    return errorrate

# 这里尝试多次训练,由于每次得到的weights不是恒定不变的
def multitest():
    numtime = 10;errorSum = 0
    for k in range(numtime):
        errorSum+=colicTest()
    print("after %d iterations the average error rate is:%f"%(numtime,errorSum/numtime))

multitest()

运行结果:

RuntimeWarning: overflow encountered in exp
  return 1.0/(1+exp(-inX))
权重向量: [  1.64792983e+01  -2.70878466e+00   1.83426151e+00  -1.61000262e+00
   1.09190057e+00  -6.60004604e+00   5.55888286e+00  -7.10470009e+00
  -8.51425980e+00  -9.03254657e+00   2.07403014e+01  -2.27597071e+01
   2.77227297e+01   1.05961276e+01  -1.12490393e+01   4.15837596e+00
  -3.85097874e+00   4.69931110e-03  -2.76205733e-02  -4.86466432e+00
  -2.14334375e+00]
the error rate of  this test is:0.343284
权重向量: [  1.69696071e+01  -1.74080618e+00   1.87041830e+00  -1.25709650e+00
   8.66389497e-01  -5.33353787e+00   4.98364275e+00  -7.01733212e+00
  -9.93414215e+00  -9.06290710e+00   2.12621011e+01  -2.45597883e+01
   2.74794228e+01   1.24443840e+01  -1.18094989e+01   4.10677322e+00
  -3.91321041e+00  -1.14931186e-01  -1.52230207e-02  -4.99307328e+00
  -2.12058899e+00]
the error rate of  this test is:0.328358
权重向量: [ 15.41657995  -1.7775319    1.22496139  -1.37359964   0.97737564
  -6.09300482   4.60908661  -7.7380833   -6.78073168  -9.14229183
  20.08345854 -23.19000154  25.8841097   11.16700552 -11.16666343
   3.94740454  -3.29618226   0.34338709   0.40671626  -6.01183123
  -1.92945607]
the error rate of  this test is:0.328358
权重向量: [ 16.01083293  -2.66867704   1.44619268  -1.4758387    1.23934799
  -5.07424509   3.99105886  -6.72766373  -7.85744122  -8.26270927
  19.44129549 -24.03073301  27.64012095  10.37589911 -10.5526026
   3.79551252  -3.68283135   0.13239338  -0.35005111  -6.34665997
  -2.1934238 ]
the error rate of  this test is:0.417910
权重向量: [ 13.73345524  -2.04977492   1.43323051  -1.87623337   0.94828083
  -4.73969282   4.45417792  -7.57892025  -7.54221517  -8.3387618
  20.21284442 -23.87435702  26.93161674  11.3585979  -11.72080447
   3.70013146  -4.24441599   0.45971015  -0.27127358  -5.33174129
  -2.11152087]
the error rate of  this test is:0.402985
权重向量: [ 16.4536017   -2.71192141   1.51957008  -1.03381621   0.8125723
  -5.07876785   4.4157913   -7.23451982  -6.11025706  -9.21167923
  20.86843755 -23.16594286  27.20612196  10.73479633 -10.47205887
   4.24299353  -3.54023733  -0.2715715   -0.59786935  -4.33051587
  -2.30642361]
the error rate of  this test is:0.388060
权重向量: [ 16.51313194  -2.49072519   2.04414683  -1.14788396   0.92670114
  -5.26760991   3.85309816  -7.21619825  -9.11456637  -8.63518754
  20.60040907 -23.94924438  27.93833898  10.76430779 -11.31338169
   4.6565549   -3.83701588  -0.4700557   -0.03909479  -5.5659964
  -2.66949478]
the error rate of  this test is:0.343284
权重向量: [ 18.55529534  -1.87912569   1.83719755  -1.52667051   1.37774381
  -4.89643812   3.86031024  -7.12933551  -8.484749    -7.8020516
  19.70134522 -24.15868032  27.33220816  11.75108273 -11.42585573
   3.48514573  -3.37322907  -0.10790695   0.47341688  -5.27229335
  -1.62493064]
the error rate of  this test is:0.343284
权重向量: [ 14.92840677  -2.25301854   1.91621959  -1.57895156   1.10141376
  -5.39616131   5.23520837  -6.77424035  -9.3037617   -8.91582214
  20.88464734 -24.57218471  27.23249816  12.42044801 -11.71615474
   3.38022571  -2.8415815   -0.71553992   0.53838514  -4.63944673
  -2.46458178]
the error rate of  this test is:0.388060
权重向量: [ 16.79923572  -2.1115749    1.35687685  -1.22022731   1.42524975
  -5.28729152   4.28135548  -6.55468089  -7.73119328  -9.15710281
  20.16741212 -22.87745152  28.19544042  10.42505585 -10.16404172
   3.75571354  -3.7527875   -0.4665408    0.04066664  -4.72233482
  -2.31501088]
the error rate of  this test is:0.373134
after 10 iterations the average error rate is:0.365672

Process finished with exit code 0

---------------------
【转载】
作者:不曾走远~
原文:https://blog.csdn.net/qq_20412595/article/details/82501280


2 个回复

倒序浏览
回复 使用道具 举报
回复 使用道具 举报
您需要登录后才可以回帖 登录 | 加入黑马