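Preface: While learning machine learning, I used to place little value on implementing algorithms by hand. Ready-made libraries exist, so writing an algorithm yourself seemed inefficient, and when an intermediate step goes wrong, debugging is tedious and hard. So why do it anyway? In my view, if all you can do is call a model, you are still at the beginner level. To really understand an algorithm you have to set the libraries aside and implement it yourself; that is where the biggest improvement comes from. Enough preamble, let's get started.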
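This example is the logistic regression case from Machine Learning in Action: predicting the mortality of horses with colic. The core of the algorithm is computing the coefficient in front of each feature, i.e. the weight vector w, and the method used is gradient ascent.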
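
For reference, the update applied by the gradient ascent routine in the code below can be written out. This is the standard logistic regression log-likelihood and its gradient; the symbols $x_i$, $y_i$, $h_i$, $m$ and $\alpha$ are notation chosen here for illustration and do not appear in the original post.

```latex
\begin{align}
h_i &= \sigma(w^\top x_i) = \frac{1}{1 + e^{-w^\top x_i}} \\
\ell(w) &= \sum_{i=1}^{m} \left[ y_i \log h_i + (1 - y_i) \log (1 - h_i) \right] \\
\nabla_w \ell(w) &= \sum_{i=1}^{m} (y_i - h_i)\, x_i \\
w &\leftarrow w + \alpha\, (y_i - h_i)\, x_i
   \qquad \text{(stochastic step on a single sample } i\text{)}
\end{align}
```

In `gradAscent`, `error` plays the role of $y_i - h_i$ and `alpha` the role of $\alpha$. The sign is positive because the log-likelihood is being maximized rather than a loss minimized.
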
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# author:lxy
import numpy as np


# Sigmoid function
def sigmoid(inX):
    return 1.0 / (1 + np.exp(-inX))


# Aside: .T only has an effect on a 2-D object such as a matrix.
#   weights = np.ones(10)           1-D array, behaves like a row vector
#   wgh = weights.T                 still the same 1-D array, weights == wgh
#   wgh1 = np.asmatrix(weights).T   only now a column vector


# Improved stochastic gradient ascent; the step size alpha changes on every update
def gradAscent(dataMatrix, classLables, iterNum):
    m, n = np.shape(dataMatrix)
    weights = np.ones(n)
    for j in range(iterNum):
        dataIndex = list(range(m))
        for i in range(m):
            # alpha shrinks as i and j grow, but the constant keeps it above zero
            alpha = 4 / (1.0 + i + j) + 0.0001
            randIndex = int(np.random.uniform(0, len(dataIndex)))
            # Map the random position to a sample index so that each sample
            # is used exactly once per pass over the data
            sampleIndex = dataIndex[randIndex]
            h = sigmoid(np.sum(dataMatrix[sampleIndex] * weights))
            error = classLables[sampleIndex] - h
            weights = weights + alpha * error * dataMatrix[sampleIndex]
            del dataIndex[randIndex]
    return weights


# Classify one sample: 1.0 if the predicted probability exceeds 0.5, else 0.0
def classifier(inX, weights):
    prob = sigmoid(np.sum(inX * weights))
    if prob > 0.5:
        return 1.0
    else:
        return 0.0


# Train on the training file, then report the error rate on the test file.
# Each line holds 21 feature values followed by the class label.
def colicTest():
    trainingSet = []
    trainingLables = []
    with open('./horseColicTraining.txt') as frtrain:
        for line in frtrain:
            currline = line.strip().split()
            lineArr = []
            for i in range(21):
                lineArr.append(float(currline[i]))
            trainingSet.append(lineArr)
            trainingLables.append(float(currline[21]))
    trainwieghts = gradAscent(np.array(trainingSet), trainingLables, 500)
    print("weight vector:", trainwieghts)
    errorcount = 0
    numTest = 0
    with open('./horseColicTest.txt') as frtest:
        for line in frtest:
            numTest += 1
            currLine = line.strip().split()
            lineArr = []
            for i in range(21):
                lineArr.append(float(currLine[i]))
            # int(float(...)) accepts labels written as "1" or "1.000000"
            if int(classifier(np.array(lineArr), trainwieghts)) != int(float(currLine[21])):
                errorcount += 1
    errorrate = errorcount / numTest
    print("the error rate of this test is:%f" % errorrate)
    return errorrate


# Train and test several times, because the weights differ from run to run
def multitest():
    numtime = 10
    errorSum = 0
    for k in range(numtime):
        errorSum += colicTest()
    print("after %d iterations the average error rate is:%f" % (numtime, errorSum / numtime))


multitest()
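
The aside in the code comments about transposing can be checked directly. The following is a minimal sketch, assuming NumPy is installed; the names `col` and `col_m` are made up for this illustration and are not part of the script above.

```python
import numpy as np

weights = np.ones(10)            # 1-D array, shape (10,)
wgh = weights.T                  # transposing a 1-D array changes nothing
col = weights.reshape(-1, 1)     # explicit 2-D column vector, shape (10, 1)
col_m = np.asmatrix(weights).T   # matrix route, also shape (10, 1)

print(weights.shape, wgh.shape, col.shape, col_m.shape)
# prints: (10,) (10,) (10, 1) (10, 1)
```

Because `gradAscent` only multiplies 1-D arrays element-wise and sums the result, it never needs a transpose; the distinction matters mainly when matrix products are used instead.
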
Output of multitest():
RuntimeWarning: overflow encountered in exp
return 1.0/(1+exp(-inX))
weight vector: [ 1.64792983e+01 -2.70878466e+00 1.83426151e+00 -1.61000262e+00
1.09190057e+00 -6.60004604e+00 5.55888286e+00 -7.10470009e+00
-8.51425980e+00 -9.03254657e+00 2.07403014e+01 -2.27597071e+01
2.77227297e+01 1.05961276e+01 -1.12490393e+01 4.15837596e+00
-3.85097874e+00 4.69931110e-03 -2.76205733e-02 -4.86466432e+00
-2.14334375e+00]
the error rate of this test is:0.343284
weight vector: [ 1.69696071e+01 -1.74080618e+00 1.87041830e+00 -1.25709650e+00
8.66389497e-01 -5.33353787e+00 4.98364275e+00 -7.01733212e+00
-9.93414215e+00 -9.06290710e+00 2.12621011e+01 -2.45597883e+01
2.74794228e+01 1.24443840e+01 -1.18094989e+01 4.10677322e+00
-3.91321041e+00 -1.14931186e-01 -1.52230207e-02 -4.99307328e+00
-2.12058899e+00]
the error rate of this test is:0.328358
weight vector: [ 15.41657995 -1.7775319 1.22496139 -1.37359964 0.97737564
-6.09300482 4.60908661 -7.7380833 -6.78073168 -9.14229183
20.08345854 -23.19000154 25.8841097 11.16700552 -11.16666343
3.94740454 -3.29618226 0.34338709 0.40671626 -6.01183123
-1.92945607]
the error rate of this test is:0.328358
weight vector: [ 16.01083293 -2.66867704 1.44619268 -1.4758387 1.23934799
-5.07424509 3.99105886 -6.72766373 -7.85744122 -8.26270927
19.44129549 -24.03073301 27.64012095 10.37589911 -10.5526026
3.79551252 -3.68283135 0.13239338 -0.35005111 -6.34665997
-2.1934238 ]
the error rate of this test is:0.417910
weight vector: [ 13.73345524 -2.04977492 1.43323051 -1.87623337 0.94828083
-4.73969282 4.45417792 -7.57892025 -7.54221517 -8.3387618
20.21284442 -23.87435702 26.93161674 11.3585979 -11.72080447
3.70013146 -4.24441599 0.45971015 -0.27127358 -5.33174129
-2.11152087]
the error rate of this test is:0.402985
weight vector: [ 16.4536017 -2.71192141 1.51957008 -1.03381621 0.8125723
-5.07876785 4.4157913 -7.23451982 -6.11025706 -9.21167923
20.86843755 -23.16594286 27.20612196 10.73479633 -10.47205887
4.24299353 -3.54023733 -0.2715715 -0.59786935 -4.33051587
-2.30642361]
the error rate of this test is:0.388060
weight vector: [ 16.51313194 -2.49072519 2.04414683 -1.14788396 0.92670114
-5.26760991 3.85309816 -7.21619825 -9.11456637 -8.63518754
20.60040907 -23.94924438 27.93833898 10.76430779 -11.31338169
4.6565549 -3.83701588 -0.4700557 -0.03909479 -5.5659964
-2.66949478]
the error rate of this test is:0.343284
weight vector: [ 18.55529534 -1.87912569 1.83719755 -1.52667051 1.37774381
-4.89643812 3.86031024 -7.12933551 -8.484749 -7.8020516
19.70134522 -24.15868032 27.33220816 11.75108273 -11.42585573
3.48514573 -3.37322907 -0.10790695 0.47341688 -5.27229335
-1.62493064]
the error rate of this test is:0.343284
weight vector: [ 14.92840677 -2.25301854 1.91621959 -1.57895156 1.10141376
-5.39616131 5.23520837 -6.77424035 -9.3037617 -8.91582214
20.88464734 -24.57218471 27.23249816 12.42044801 -11.71615474
3.38022571 -2.8415815 -0.71553992 0.53838514 -4.63944673
-2.46458178]
the error rate of this test is:0.388060
weight vector: [ 16.79923572 -2.1115749 1.35687685 -1.22022731 1.42524975
-5.28729152 4.28135548 -6.55468089 -7.73119328 -9.15710281
20.16741212 -22.87745152 28.19544042 10.42505585 -10.16404172
3.75571354 -3.7527875 -0.4665408 0.04066664 -4.72233482
-2.31501088]
the error rate of this test is:0.373134
after 10 iterations the average error rate is:0.365672
Process finished with exit code 0
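
The `RuntimeWarning: overflow encountered in exp` at the top of the output appears when `inX` is a large negative number, so that `exp(-inX)` exceeds the floating-point range. The naive formula still returns the correct limit in that case (`1.0/(1+inf)` is `0.0`), so the results are not affected. One possible way to avoid the warning is sketched below; `stable_sigmoid` is a hypothetical helper written for this note, not something from the original post or from NumPy.

```python
import numpy as np


def stable_sigmoid(x):
    """Sigmoid that only ever calls exp() on non-positive arguments."""
    x = np.asarray(x, dtype=float)
    z = np.exp(-np.abs(x))          # always in (0, 1], so it cannot overflow
    # x >= 0: 1 / (1 + e^-x);  x < 0: e^x / (1 + e^x), the same value rewritten
    return np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))


print(stable_sigmoid(np.array([-1000.0, 0.0, 1000.0])))
```

Swapping this in for `sigmoid` in the script would be a drop-in change under the assumption that only the returned value matters. If SciPy is available, `scipy.special.expit` computes the same function.
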
---------------------
[Reposted]
Author: 不曾走远~
Original: https://blog.csdn.net/qq_20412595/article/details/82501280