1.直接上代码代码第一部分
dataiter = iter(testloader) # 创建一个python迭代器,读入的是我们第一步里面就已经加载好的testloader
images, labels = dataiter.next() # 返回一个batch_size的图片,根据第一步的设置,应该是4张
# print images
imshow(torchvision.utils.make_grid(images)) # 展示这四张图片
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4))) # python字符串格式化 ' '.join表示用空格来连接后面的字符串,参考python的join()方法
这一部分代码就是先随机读取4张图片,让我们看看这四张图片是什么并打印出相应的label信息, 因为第一步里面设置了是shuffle了数据的,也就是顺序是打乱的,所以各自出现的图像不一定相同, 代码第二部分
outputs = net(Variable(images)) # 注意这里的images是我们从上面获得的那四张图片,所以首先要转化成variable
_, predicted = torch.max(outputs.data, 1)
# 这个 _ , predicted是python的一种常用的写法,表示后面的函数其实会返回两个值
# 但是我们对第一个值不感兴趣,就写个_在那里,把它赋值给_就好,我们只关心第二个值predicted
# 比如 _ ,a = 1,2 这中赋值语句在python中是可以通过的,你只关心后面的等式中的第二个位置的值是多少
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4))) # python的字符串格式化
这里用到了torch.max( ), 它是属于Tensor的一个方法:
注意到注释中第一句话,是说返回返回输入Tensor中每行的最大值,并转换成指定的dim(维度), 所以我们程序中的 torch.max(outputs.data, 1) ,很明显就是返回一个列,列元素是输入的outputs.data的每行最大值 而这里很明显,这个返回的列的第一个元素是image data,第二个元素是label, 我们只需要label, 所以就会有 _ , predicted这样的赋值语句,我在注释中也说明了这是什么意思
代码第三部分
correct = 0 # 定义预测正确的图片数,初始化为0
total = 0 # 总共参与测试的图片数,也初始化为0
for data in testloader: # 循环每一个batch
images, labels = data
outputs = net(Variable(images)) # 输入网络进行测试
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0) # 更新测试图片的数量
correct += (predicted == labels).sum() # 更新正确分类的图片的数量
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total)) # 最后打印结果
tutorial给的结果是53%
代码第四部分来测试一下每一类的分类正确率
class_correct = list(0. for i in range(10)) # 定义一个存储每类中测试正确的个数的 列表,初始化为0
class_total = list(0. for i in range(10)) # 定义一个存储每类中测试总数的个数的 列表,初始化为0
for data in testloader: # 以一个batch为单位进行循环
images, labels = data
outputs = net(Variable(images))
_, predicted = torch.max(outputs.data, 1)
c = (predicted == labels).squeeze()
for i in range(4): # 因为每个batch都有4张图片,所以还需要一个4的小循环
label = labels # 对各个类的进行各自累加
class_correct[label] += c
class_total[label] += 1
for i in range(10):
print('Accuracy of %5s : %2d %%' % (
classes, 100 * class_correct / class_total))
|