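After introducing GBDT for regression, this article covers GBDT for classification. As you will see, the GBDT procedure is practically identical for regression and classification. The biggest difference is that the different loss function leads to a different initialization and different leaf-node values.
Main text: the basic principles of gradient boosting (GB) were introduced in the previous article, so let's go straight to the topic.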
Below is the GBDT algorithm for classification, using logloss as the loss function:
$$L(y_i, F_m(x_i)) = -\{y_i \log p_i + (1 - y_i)\log(1 - p_i)\}$$
where $p_i = \frac{1}{1 + e^{-F_m(x_i)}}$.
Here is a quick derivation of the commonly used simplified form of logloss, starting from:
$$L(y_i, F_m(x_i)) = -\{y_i \log p_i + (1 - y_i)\log(1 - p_i)\}$$
(leaving out the minus sign for now)
Substituting $p_i$:
$$\Rightarrow y_i \log\left(\frac{1}{1 + e^{-F_m(x_i)}}\right) + (1 - y_i)\log\left(\frac{e^{-F_m(x_i)}}{1 + e^{-F_m(x_i)}}\right)$$
$$\Rightarrow -y_i \log\left(1 + e^{-F_m(x_i)}\right) + (1 - y_i)\left\{\log\left(e^{-F_m(x_i)}\right) - \log\left(1 + e^{-F_m(x_i)}\right)\right\}$$
$$\Rightarrow -y_i \log\left(1 + e^{-F_m(x_i)}\right) + \log\left(e^{-F_m(x_i)}\right) - \log\left(1 + e^{-F_m(x_i)}\right) - y_i \log\left(e^{-F_m(x_i)}\right) + y_i \log\left(1 + e^{-F_m(x_i)}\right)$$
The first and last terms cancel, $\log(e^{-F_m(x_i)}) = -F_m(x_i)$, and $-F_m(x_i) - \log(1 + e^{-F_m(x_i)}) = -\log(1 + e^{F_m(x_i)})$, so:
$$\Rightarrow y_i F_m(x_i) - \log\left(1 + e^{F_m(x_i)}\right)$$
Finally, restoring the minus sign gives:
$$L(y_i, F_m(x_i)) = -\{y_i \log p_i + (1 - y_i)\log(1 - p_i)\} = -\{y_i F_m(x_i) - \log(1 + e^{F_m(x_i)})\}$$
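As a quick numerical check of the derivation, the sketch below evaluates both forms of the loss for a few arbitrary scores $F$ and both labels and prints them side by side. The function names are illustrative only.

```python
import math

def logloss_original(y, F):
    """-{y log p + (1 - y) log(1 - p)} with p = 1 / (1 + e^{-F})."""
    p = 1.0 / (1.0 + math.exp(-F))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def logloss_simplified(y, F):
    """-{y F - log(1 + e^F)}, the simplified form derived above."""
    return -(y * F - math.log(1.0 + math.exp(F)))

# Both forms should agree up to floating-point error, for either label.
for y in (0, 1):
    for F in (-2.0, -0.4054, 0.0, 1.5):
        print(y, F, logloss_original(y, F), logloss_simplified(y, F))
```

The simplified form is the one that makes the gradient in the following steps easy to take.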
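Algorithm 3 is the GBDT procedure for classification when logloss is chosen as the loss function.
As you can see, it is the same as for regression; there is no special step.
(In fact, in the sklearn source code, the regression model is GradientBoostingRegressor() and the classification model is GradientBoostingClassifier(), but they are separated only for the user's convenience. Both ultimately inherit from BaseGradientBoosting(), which carries out the steps of Algorithm 3; GradientBoostingRegressor() and GradientBoostingClassifier() only take care of configuring the learner's parameters.)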
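The steps of Algorithm 3 can be sketched in plain Python as below. This is a minimal illustration under stated assumptions, not sklearn's implementation: a single numeric feature, depth-1 trees (stumps) whose split is chosen by squared error, Newton-step leaf values, and a fixed learning rate. All function names are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def fit_stump(x, r):
    """Depth-1 regression tree on one feature: pick the threshold t that
    minimizes the squared error of r; left region x <= t, right region x > t."""
    best_t, best_sse = None, float("inf")
    for t in sorted(set(x))[:-1]:  # skip the max value so the right side is non-empty
        sse = 0.0
        for part in ([ri for xi, ri in zip(x, r) if xi <= t],
                     [ri for xi, ri in zip(x, r) if xi > t]):
            mean = sum(part) / len(part)
            sse += sum((v - mean) ** 2 for v in part)
        if sse < best_sse:
            best_t, best_sse = t, sse
    return best_t

def leaf_value(y_leaf, r_leaf):
    """Newton step for logloss: sum(r) / sum((y - r)(1 - y + r)) = sum(r) / sum(p(1 - p))."""
    den = sum((yi - ri) * (1 - yi + ri) for yi, ri in zip(y_leaf, r_leaf))
    return 0.0 if abs(den) < 1e-150 else sum(r_leaf) / den

def gbdt_fit(x, y, n_trees=2, lr=0.1):
    F0 = math.log(sum(y) / (len(y) - sum(y)))          # step 1: log-odds initialization
    F = [F0] * len(x)
    trees = []
    for _ in range(n_trees):
        r = [yi - sigmoid(Fi) for yi, Fi in zip(y, F)]  # negative gradient
        t = fit_stump(x, r)                             # fit a tree to the residuals
        left = [i for i, xi in enumerate(x) if xi <= t]
        right = [i for i, xi in enumerate(x) if xi > t]
        g_left = leaf_value([y[i] for i in left], [r[i] for i in left])
        g_right = leaf_value([y[i] for i in right], [r[i] for i in right])
        trees.append((t, g_left, g_right))
        # shrinkage: each sample's score moves by lr * the value of its leaf
        F = [Fi + lr * (g_left if xi <= t else g_right) for xi, Fi in zip(x, F)]
    return F0, trees

def gbdt_predict_proba(x_new, F0, trees, lr=0.1):
    """Sum the scores of all trees, then map the total score to P(y = 1)."""
    F = F0 + sum(lr * (gl if x_new <= t else gr) for t, gl, gr in trees)
    return sigmoid(F)
```

For instance, with a hypothetical dataset x = 1, ..., 10 and labels [0, 0, 0, 1, 1, 0, 0, 0, 1, 1] (chosen only to match the counts used in the walkthrough: 4 positives, 6 negatives, both samples with x > 8 positive), `gbdt_fit(x, y)` should produce a first tree that splits at x <= 8 with leaf values of about -0.625 and 2.5.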
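Below, we again use a simple dataset to walk through the GBDT process.
Parameter settings:
1. logloss as the loss function
2. MSE as the splitting criterion
3. tree depth of 1
4. learning rate of 0.1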
Step 1 of Algorithm 3: initialization.
$$F_0(x) = \log\left(\frac{\sum_{i=1}^{N} y_i}{\sum_{i=1}^{N}(1 - y_i)}\right) = \log\left(\frac{4}{6}\right) = -0.4054$$
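The initialization is the log odds of the positive class. It is the constant score that minimizes the total logloss: setting $\sum_i (y_i - \sigma(F)) = 0$ gives $\sigma(F) = \bar{y}$, i.e. $F = \log(\text{pos}/\text{neg})$. A small sketch; the helper name init_log_odds is hypothetical, and only the 4/6 label counts come from the example:

```python
import math

def init_log_odds(y):
    """F0 = log(sum(y_i) / sum(1 - y_i)) for labels encoded as 0/1."""
    pos = sum(y)
    neg = len(y) - pos
    return math.log(pos / neg)

# Only the label counts matter: 4 positives and 6 negatives, as in the example.
F0 = init_log_odds([1] * 4 + [0] * 6)
print(round(F0, 6))                 # -0.405465 (the text truncates this to -0.4054)
print(1.0 / (1.0 + math.exp(-F0)))  # ≈ 0.4, the positive rate 4/10
```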
Fit the first tree ($m = 1$).
Compute the negative gradient. Differentiating the simplified loss $-\{y_i F - \log(1 + e^{F})\}$ with respect to $F$ and negating gives $y_i - \frac{e^{F}}{1 + e^{F}} = y_i - \frac{1}{1 + e^{-F}}$, so:
$$\tilde{y}_i = -\left[\frac{\partial L(y_i, F(x_i))}{\partial F(x_i)}\right]_{F(x) = F_{m-1}(x)} = y_i - \frac{1}{1 + e^{-F_{m-1}(x_i)}} = y_i - \frac{1}{1 + e^{-F_0(x_i)}}$$
For example, for the first sample ($i = 1$):
$$\tilde{y}_1 = 0 - \frac{1}{1 + e^{0.4054}} = -0.400$$
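A minimal sketch of the first-round negative gradient. With $F_0 = \log(4/6)$ every sample has predicted probability 0.4, so a negative sample gets $-0.4$ and a positive sample gets $0.6$. The helper name is hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def negative_gradient(y, F):
    """Pseudo-residuals for logloss: y_i - sigmoid(F_{m-1}(x_i))."""
    return [yi - sigmoid(Fi) for yi, Fi in zip(y, F)]

F0 = math.log(4 / 6)
r = negative_gradient([0, 1], [F0, F0])  # one negative and one positive sample
print([round(v, 3) for v in r])  # [-0.4, 0.6]
```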
The other samples are computed in the same way; the results are listed in the table below:
Next, we fit a tree with $\tilde{y}_i$ as the target.
The tree-fitting procedure was covered in detail in the previous article, so we won't repeat it here. The fitted result is as follows:
From the fitted tree we obtain the leaf regions:
$R_{11}$ is $x_i \le 8$, and $R_{21}$ is $x_i > 8$.
Next, we compute the leaf values $\gamma_{jm}$ using the formula:
$$\gamma_{jm} = \frac{\sum_{x_i \in R_{jm}} \tilde{y}_i}{\sum_{x_i \in R_{jm}} (y_i - \tilde{y}_i)(1 - y_i + \tilde{y}_i)}$$
Because $\tilde{y}_i = y_i - p_i$, we have $y_i - \tilde{y}_i = p_i$ and $1 - y_i + \tilde{y}_i = 1 - p_i$, so the denominator is $\sum p_i(1 - p_i)$, the second derivative of the loss. The leaf value is therefore a single Newton step rather than the plain mean of the residuals used with squared loss.
For region $R_{11}$:
$$\sum_{x_i \in R_{11}} \tilde{y}_i = \tilde{y}_1 + \tilde{y}_2 + \tilde{y}_3 + \tilde{y}_4 + \tilde{y}_5 + \tilde{y}_6 + \tilde{y}_7 + \tilde{y}_8 = -1.2$$
$$\sum_{x_i \in R_{11}} (y_i - \tilde{y}_i)(1 - y_i + \tilde{y}_i) = (y_1 - \tilde{y}_1)(1 - y_1 + \tilde{y}_1) + (y_2 - \tilde{y}_2)(1 - y_2 + \tilde{y}_2) + \cdots + (y_8 - \tilde{y}_8)(1 - y_8 + \tilde{y}_8) = 1.92$$
For region $R_{21}$:
$$\sum_{x_i \in R_{21}} \tilde{y}_i = \tilde{y}_9 + \tilde{y}_{10} = 1.2$$
$$\sum_{x_i \in R_{21}} (y_i - \tilde{y}_i)(1 - y_i + \tilde{y}_i) = (y_9 - \tilde{y}_9)(1 - y_9 + \tilde{y}_9) + (y_{10} - \tilde{y}_{10})(1 - y_{10} + \tilde{y}_{10}) = 0.48$$
This gives the values of the two leaves:
$$\gamma_{11} = \frac{-1.2}{1.92} = -0.625, \qquad \gamma_{21} = \frac{1.2}{0.480} = 2.5$$
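The leaf computation can be sketched in a few lines. The per-sample residuals are not listed in the text, but in the first round each residual is either $0.6$ ($y = 1$) or $-0.4$ ($y = 0$). A sum of $-1.2$ over the 8 samples of $R_{11}$ therefore implies 2 positives and 6 negatives, and a sum of $1.2$ over the 2 samples of $R_{21}$ implies both are positive. leaf_value is a hypothetical helper name.

```python
def leaf_value(y_leaf, r_leaf):
    """Newton-step leaf value for logloss:
    sum(r) / sum((y - r) * (1 - y + r)), where y - r equals p."""
    numerator = sum(r_leaf)
    denominator = sum((y - r) * (1 - y + r) for y, r in zip(y_leaf, r_leaf))
    return numerator / denominator

# Residual composition implied by the sums in the text (see above).
y_R11 = [1] * 2 + [0] * 6
r_R11 = [0.6] * 2 + [-0.4] * 6
y_R21 = [1, 1]
r_R21 = [0.6, 0.6]
print(round(leaf_value(y_R11, r_R11), 4))  # -0.625
print(round(leaf_value(y_R21, r_R21), 4))  # 2.5
```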
Finally, update $F_1(x)$ using $F_m(x) = F_{m-1}(x) + \sum_{j=1}^{J} \gamma_{jm} I(x \in R_{jm})$. Note that shrinkage is applied here as well, i.e. the update is multiplied by a learning rate $\eta$:
$$F_m(x) = F_{m-1}(x) + \eta \sum_{j=1}^{J} \gamma_{jm} I(x \in R_{jm})$$
Taking $x_1$ as an example:
$$F_1(x_1) = F_0(x_1) + 0.1 \times (-0.625) = -0.4054 - 0.0625 = -0.4679$$
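The shrinkage update for a depth-1 tree can be sketched as follows. update_scores is a hypothetical helper, and the feature values 1 and 9 are assumed stand-ins for one sample in each region, since the original data table is not reproduced here. Unrounded, the sketch gives $-0.467965$ for a sample in $R_{11}$ (the text truncates to $-0.4679$) and $-0.155465$ for a sample in $R_{21}$.

```python
import math

def update_scores(x, F_prev, threshold, gamma_left, gamma_right, lr=0.1):
    """F_m(x) = F_{m-1}(x) + lr * gamma of the leaf that x falls into (depth-1 tree)."""
    return [F + lr * (gamma_left if xi <= threshold else gamma_right)
            for xi, F in zip(x, F_prev)]

F0 = math.log(4 / 6)
# Hypothetical feature values: 1 lies in R11 (x <= 8), 9 lies in R21 (x > 8).
F1 = update_scores([1, 9], [F0, F0], threshold=8, gamma_left=-0.625, gamma_right=2.5)
print([round(v, 4) for v in F1])  # [-0.468, -0.1555]
```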
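The results for the other samples are listed in the table below for reference:
At this point the first tree has been trained. Once again, its training process is essentially no different from regression.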
Next, a brief look at fitting the second tree ($m = 2$).
Compute the negative gradient.
For example, for $x_1$:
$$\tilde{y}_1 = y_1 - \frac{1}{1 + e^{-F_1(x_1)}} = 0 - 0.38509 = -0.38509$$
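This is the same negative-gradient formula, now evaluated at $F_1(x_1)$ instead of $F_0$. A minimal sketch using the values from the text; to more digits the result is $-0.385098$, which the text truncates to $-0.38509$:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

F0 = math.log(4 / 6)           # initial score from step 1
F1_x1 = F0 + 0.1 * (-0.625)    # x_1 lies in R11, so it receives eta * gamma_11
y1 = 0                         # x_1 is a negative sample
r1 = y1 - sigmoid(F1_x1)       # pseudo-residual used as the target of tree 2
print(round(r1, 6))  # -0.385098
```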
The other samples follow in the same way, giving the table below:
After that, as before, fit a regression tree to the new $\tilde{y}_i$ and compute the leaf regions and leaf values.
With only 2 trees, the prediction process is the same as in the figure below.
Compared with regression, classification additionally has to convert the final accumulated result $F_m(x)$ into a probability ($F_m(x)$ can be thought of as a score). Specifically:
With logloss as the loss function, $p_i = \frac{1}{1 + e^{-F_m(x_i)}}$.
With exponential loss as the loss function, $p_i = \frac{1}{1 + e^{-2F_m(x_i)}}$.
Here $p_i$ is the probability of the positive class.
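A sketch of the two score-to-probability mappings, with hypothetical function names. The factor 2 for exponential loss comes from the fact that the score minimizing the expected exponential loss is half the log odds, $F = \frac{1}{2}\log\frac{p}{1 - p}$; inverting this gives $p = 1/(1 + e^{-2F})$.

```python
import math

def proba_logloss(score):
    """logloss: P(y = 1) = 1 / (1 + e^{-F})."""
    return 1.0 / (1.0 + math.exp(-score))

def proba_exponential(score):
    """exponential loss: P(y = 1) = 1 / (1 + e^{-2F})."""
    return 1.0 / (1.0 + math.exp(-2.0 * score))

for F in (-1.0, 0.0, 0.5):
    print(F, round(proba_logloss(F), 4), round(proba_exponential(F), 4))
```

A score of 0 maps to probability 0.5 under both losses; the exponential-loss mapping is simply the logloss mapping applied to $2F$.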
To make this concrete: in the example above, after fitting the second tree, $F_2(x)$ is as shown in the table below:
The corresponding probabilities are then as shown in the table below:
(The probabilities in the table are for the positive class, i.e. the probability that $y_i = 1$.)
A brief analysis of the sklearn source code
A note up front: I will add more to this sklearn source analysis when I have time; for now, here is a quick look at the core code of GBDT classification.
When logloss is chosen as the loss function, it corresponds to loss='deviance' in sklearn (newer sklearn releases name this option loss='log_loss').
Computing the negative gradient, initialization, updating the leaf nodes, and converting to probabilities are all handled in a class named BinomialDeviance().
The following is used to compute the negative gradient. Note that the function expit is $\frac{1}{1 + e^{-x}}$.
In the code, y_pred or pred denotes $F_{m-1}(x)$.
Updating the leaf nodes comes down to computing numerator and denominator.
Also, residual in the code denotes the negative gradient.
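The sklearn snippets for these two steps are not reproduced here. The following is a pure-Python sketch of the logic just described, not sklearn's code: residual, numerator, denominator and expit mirror the text, while the function names, the sample-weight handling and the small-denominator guard are assumptions about how older sklearn versions structure this and should be checked against the source of the version you use.

```python
import math

def expit(x):
    # Logistic sigmoid 1 / (1 + e^{-x}); sklearn itself imports expit from scipy.special
    return 1.0 / (1.0 + math.exp(-x))

def negative_gradient(y, pred):
    """residual = y - expit(pred), where pred holds F_{m-1}(x)."""
    return [yi - expit(p) for yi, p in zip(y, pred)]

def update_terminal_region(y, residual, sample_weight):
    """One Newton-Raphson step for a single leaf (hypothetical standalone version):
    sum(w * residual) / sum(w * (y - residual) * (1 - y + residual)).
    Since residual = y - prob, the denominator equals sum(w * prob * (1 - prob))."""
    numerator = sum(w * r for w, r in zip(sample_weight, residual))
    denominator = sum(w * (yi - r) * (1 - yi + r)
                      for w, yi, r in zip(sample_weight, y, residual))
    # Assumed guard: avoid dividing by (near) zero when all probabilities are ~0 or ~1
    if abs(denominator) < 1e-150:
        return 0.0
    return numerator / denominator
```

For the $R_{21}$ leaf of the example (two positive samples, unit weights, $F_0 = \log(4/6)$), this returns about 2.5, matching $\gamma_{21}$. Scaling all weights in a leaf by the same factor leaves the result unchanged.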
The initialization class:
```python
import numpy as np
from sklearn.utils.validation import check_is_fitted


class LogOddsEstimator(object):
    """An estimator predicting the log odds ratio."""
    scale = 1.0

    def fit(self, X, y, sample_weight=None):
        # pre-cond: pos, neg are encoded as 1, 0
        if sample_weight is None:
            pos = np.sum(y)
            neg = y.shape[0] - pos
        else:
            pos = np.sum(sample_weight * y)
            neg = np.sum(sample_weight * (1 - y))

        if neg == 0 or pos == 0:
            raise ValueError('y contains non binary labels.')
        self.prior = self.scale * np.log(pos / neg)

    def predict(self, X):
        check_is_fitted(self, 'prior')

        y = np.empty((X.shape[0], 1), dtype=np.float64)
        y.fill(self.prior)
        return y
```
Within this class, the fit method performs the initialization. Note the factor self.scale: sklearn provides two loss functions for classification, logloss and exponential loss, and their initializations differ only in this coefficient, 1.0 for the former and 0.5 for the latter.
Finally, the conversion to probabilities. One detail here: the positive-class probability is placed in the second column (counting from 1), i.e. proba[:, 1].
```python
import numpy as np
from scipy.special import expit


# Method of BinomialDeviance
def _score_to_proba(self, score):
    proba = np.ones((score.shape[0], 2), dtype=np.float64)
    proba[:, 1] = expit(score.ravel())  # column 1: P(y = 1)
    proba[:, 0] -= proba[:, 1]          # column 0: 1 - P(y = 1)
    return proba
```
With that, both the regression and the classification uses of GBDT have been covered. What may be lacking is a deeper treatment of the source code; time constraints kept me from going further for now, so I will analyze the code in more depth later and add it here.
Welcome to the 黑马程序员 (itheima) Technical Exchange Community (http://bbs.itheima.com/) | 黑马程序员 IT Technical Forum X3.2 |