黑马程序员技术交流社区
标题:
【上海校区】Pytorch打怪路(一)pytorch进行CIFAR-10分类(4)...
[打印本页]
作者:
不二晨
时间:
2018-9-5 10:14
标题:
【上海校区】Pytorch打怪路(一)pytorch进行CIFAR-10分类(4)...
1、简述
经过前面的数据加载和网络定义后,就可以开始训练了,这里会看到前面遇到的一些东西究竟在后面会有什么用,所以这一步希望各位也能仔细研究一下
2、代码
for
epoch
in
range(
2
):
# loop over the dataset multiple times 指定训练一共要循环几个epoch
running_loss =
0.0
#定义一个变量方便我们对loss进行输出
for
i, data
in
enumerate(trainloader,
0
):
# 这里我们遇到了第一步中出现的trailoader,代码传入数据
# enumerate是python的内置函数,既获得索引也获得数据,详见下文
# get the inputs
inputs, labels = data
# data是从enumerate返回的data,包含数据和标签信息,分别赋值给inputs和labels
# wrap them in Variable
inputs, labels = Variable(inputs), Variable(labels)
# 将数据转换成Variable,第二步里面我们已经引入这个模块
# 所以这段程序里面就直接使用了,下文会分析
# zero the parameter gradients
optimizer.zero_grad()
# 要把梯度重新归零,因为反向传播过程中梯度会累加上一次循环的梯度
# forward + backward + optimize
outputs = net(inputs)
# 把数据输进网络net,这个net()在第二步的代码最后一行我们已经定义了
loss = criterion(outputs, labels)
# 计算损失值,criterion我们在第三步里面定义了
loss.backward()
# loss进行反向传播,下文详解
optimizer.step()
# 当执行反向传播之后,把优化器的参数进行更新,以便进行下一轮
# print statistics # 这几行代码不是必须的,为了打印出loss方便我们看而已,不影响训练过程
running_loss += loss.data[
0
]
# 从下面一行代码可以看出它是每循环0-1999共两千次才打印一次
if
i %
2000
==
1999
:
# print every 2000 mini-batches 所以每个2000次之类先用running_loss进行累加
print(
'[%d, %5d] loss: %.3f'
%
(epoch +
1
, i +
1
, running_loss /
2000
))
# 然后再除以2000,就得到这两千次的平均损失值
running_loss =
0.0
# 这一个2000次结束后,就把running_loss归零,下一个2000次继续使用
print(
'Finished Training'
)
[python]
view plain
copy
3、分析
①autograd
在第二步中我们定义网络时定义了前向传播函数,但是并没有定义反向传播函数,可是深度学习是需要反向传播求导的,
Pytorch其实利用的是Autograd模块来进行自动求导,反向传播。
Autograd中最核心的类就是Variable了,它封装了Tensor,并几乎支持所有Tensor的操作,这里可以参考官方给的详细解释:
http://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autograd-tutorial-py
以上链接详细讲述了variable究竟是怎么能够实现自动求导的,怎么用它来实现反向传播的。
这里涉及到计算图的相关概念,这里我不详细讲,后面会写相关博文来讨论这个东西,暂时不会对我们理解这个程序造成影响
只说一句,
想要计算各个variable的梯度,只需调用根节点的backward方法,Autograd就会自动沿着整个计算图进行反向计算
而在此例子中,根节点就是我们的loss
,所以:
程序中的loss.backward()代码就是在实现反向传播,自动计算所有的梯度。
所以训练部分的代码其实比较简单:
running_loss和后面负责打印损失值的那部分并不是必须的,所以关键行不多,总得来说分成三小节
第一节:把最开始放在trainloader里面的数据给转换成variable,然后指定为网络的输入;
第二节:每次循环新开始的时候,要确保梯度归零
第三节:forward+backward,就是调用我们在第三步里面实例化的net()实现前传,loss.backward()实现后传
每结束一次循环,要确保梯度更新
作者:
不二晨
时间:
2018-9-6 11:09
奈斯
作者:
魔都黑马少年梦
时间:
2018-11-1 16:49
欢迎光临 黑马程序员技术交流社区 (http://bbs.itheima.com/)
黑马程序员IT技术论坛 X3.2