黑马程序员 (Heima Programmer) Technical Exchange Community
Title: [Shanghai Campus] Implementing Multi-Attribute Face Recognition in PyTorch
Author: 梦缠绕的时候  Time: 2018-6-26 09:32
This post was last edited by 梦缠绕的时候 at 2018-6-26 09:34
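Data source: the CelebA face dataset (see its official website) is an open dataset released by The Chinese University of Hong Kong. It contains 202,599 face images of 10,177 celebrity identities, all annotated with attribute labels, which makes it a very convenient dataset for face-related training. There are 40 attributes in total; the full list is on the official website. Without further ado, let's go through the process.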
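The whole process can be roughly divided into the following steps:
1. Image preprocessing
2. Building the network
3. Training
4. Testing
5. Optimization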
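1. Image loading: because the source data has not been preprocessed, we write our own subclass of torch.utils.data.Dataset that reads each image and its labels; only then can the images be fed to a DataLoader. The code is as follows: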
```python
import os

import numpy as np
import torch
import torch.utils.data as Data
from PIL import Image


def default_loader(path):
    try:
        img = Image.open(path)
        return img.convert('RGB')
    except (IOError, OSError):
        print("Can not open {0}".format(path))
        raise  # fail loudly instead of silently returning None


# A custom dataset must subclass Data.Dataset (not Data.DataLoader)
class myDataset(Data.Dataset):
    def __init__(self, img_dir, img_txt, transform=None, loader=default_loader):
        img_list = []
        img_labels = []
        with open(img_txt, 'r') as fp:
            for line in fp:
                parts = line.split()
                # keep only rows of "filename + 40 attribute values";
                # this skips the image-count line and the attribute-name header
                if len(parts) != 41:
                    continue
                img_list.append(parts[0])
                # CelebA marks attributes with 1 / -1; map them to class indices 1 / 0
                img_labels.append([1 if value == '1' else 0 for value in parts[1:]])
        self.imgs = [os.path.join(img_dir, file) for file in img_list]
        self.labels = img_labels
        self.transform = transform
        self.loader = loader

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        img_path = self.imgs[index]
        label = torch.from_numpy(np.array(self.labels[index], dtype=np.int64))
        img = self.loader(img_path)
        if self.transform is not None:
            img = self.transform(img)
        return img, label
```
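To check the label-parsing rule on its own, here is a minimal sketch that applies the same logic as `myDataset.__init__` to a few in-memory lines. It assumes the attribute file follows the layout of CelebA's `list_attr_celeba.txt` (a count line, a header line of 40 attribute names, then one row per image: a filename followed by 40 values of `1` or `-1`). The function name `parse_attr_lines` and the sample rows are hypothetical, made up for illustration.

```python
def parse_attr_lines(lines, num_attrs=40):
    """Return (filenames, labels) using the same rules as myDataset.__init__.

    Rows that do not have exactly 1 + num_attrs tokens (the count line and
    the attribute-name header) are skipped; 1 becomes class 1, -1 becomes class 0.
    """
    filenames, labels = [], []
    for line in lines:
        parts = line.split()
        if len(parts) != num_attrs + 1:
            continue
        filenames.append(parts[0])
        labels.append([1 if value == '1' else 0 for value in parts[1:]])
    return filenames, labels


# Hypothetical sample: a count line, a header of 40 made-up names, two image rows
header = " ".join("attr{}".format(i) for i in range(40))
row1 = "000001.jpg " + " ".join(["1", "-1"] * 20)
row2 = "000002.jpg " + " ".join(["-1"] * 40)
names, labels = parse_attr_lines(["2", header, row1, row2])
print(names)           # ['000001.jpg', '000002.jpg']
print(labels[0][:4])   # [1, 0, 1, 0]
print(sum(labels[1]))  # 0
```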
Image augmentation, normalization, and loading:
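```python
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize(40),              # scale the shorter edge to 40 px, keeping the aspect ratio
    transforms.CenterCrop(32),          # take the central 32x32 patch
    transforms.RandomHorizontalFlip(),  # augmentation: random left-right flip
    transforms.ToTensor(),              # PIL image -> float tensor in [0, 1], shape (3, H, W)
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])  # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
])
```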
```python
# img_root, train_txt, test_txt and batch_size are set in the full script at the end of the post
train_dataset = myDataset(img_dir=img_root, img_txt=train_txt, transform=transform)
train_dataloader = Data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = myDataset(img_dir=img_root, img_txt=test_txt, transform=transform)
test_dataloader = Data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)  # order does not matter for evaluation
```
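Because each sample's label is a length-40 vector, the DataLoader stacks the labels of a batch into a `(batch_size, 40)` tensor, and `label[:, i]` (used in the training and test loops) selects attribute `i` for the whole batch. The sketch below shows this with random stand-in tensors in a `TensorDataset` instead of real CelebA images.

```python
import torch
import torch.utils.data as Data

torch.manual_seed(0)
# 10 fake samples shaped like the real ones: 3x32x32 images, 40 binary labels each
images = torch.randn(10, 3, 32, 32)
labels = torch.randint(0, 2, (10, 40))
loader = Data.DataLoader(Data.TensorDataset(images, labels), batch_size=4, shuffle=False)

img, label = next(iter(loader))
print(img.shape)          # torch.Size([4, 3, 32, 32])
print(label.shape)        # torch.Size([4, 40])
print(label[:, 5].shape)  # torch.Size([4]): attribute 5 for every image in the batch
```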
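Building the network: for each attribute I use 3 convolutional layers followed by 3 fully connected layers. Because the network is fairly simple, the accuracy will not be especially high; if you're interested you can optimize it, and some ideas for doing so are listed at the end of the post for discussion. Here's the code: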
```python
import torch.nn as nn


def make_conv():
    # 3 conv blocks; each MaxPool2d(2) halves the spatial size: 32 -> 16 -> 8 -> 4
    return nn.Sequential(
        nn.Conv2d(3, 16, 3, 1, 1),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Conv2d(16, 32, 3, 1, 1),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Conv2d(32, 64, 3, 1, 1),
        nn.ReLU(),
        # nn.Dropout(0.5),
        nn.MaxPool2d(2)
    )


def make_fc():
    return nn.Sequential(
        nn.Linear(64 * 4 * 4, 128),  # 64 channels x 4 x 4 feature map
        nn.ReLU(),
        # nn.Dropout(0.5),
        nn.Linear(128, 64),
        nn.ReLU(),
        # Dropout() helps against overfitting; other positions may work better, so try a few
        nn.Dropout(0.5),
        nn.Linear(64, 2)  # two logits per attribute: absent / present
    )


class face_attr(nn.Module):
    def __init__(self, num_attrs=40):
        super(face_attr, self).__init__()
        # Every attribute gets its own independent conv branch and fc head; this is
        # equivalent to writing attr0_layer1/attr0_layer2 ... attr39_layer1/attr39_layer2 by hand
        self.conv_layers = nn.ModuleList([make_conv() for _ in range(num_attrs)])
        self.fc_layers = nn.ModuleList([make_fc() for _ in range(num_attrs)])

    def forward(self, x):
        out_list = []
        for conv, fc in zip(self.conv_layers, self.fc_layers):
            out = conv(x)
            out = out.view(out.size(0), -1)  # flatten to (batch, 64*4*4)
            out_list.append(fc(out))         # (batch, 2) logits for this attribute
        return out_list
```
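The `64*4*4` input size of the first `Linear` layer follows from the shapes: a 3×3 convolution with stride 1 and padding 1 keeps the spatial size, since (n + 2·1 − 3)/1 + 1 = n, and each `MaxPool2d(2)` halves it. The small helpers below (hypothetical names, not part of the original code) do this arithmetic, which is handy if you change the input size as suggested at the end of the post.

```python
def conv_out(n, kernel=3, stride=1, padding=1):
    # standard output-size formula for Conv2d with dilation 1
    return (n + 2 * padding - kernel) // stride + 1


def pool_out(n, kernel=2):
    # MaxPool2d(kernel) uses stride=kernel and no padding by default
    return (n - kernel) // kernel + 1


def flat_features(input_size, channels=64, num_blocks=3):
    """Flattened feature size after make_conv() for a square input."""
    n = input_size
    for _ in range(num_blocks):
        n = pool_out(conv_out(n))
    return channels * n * n


print(flat_features(32))   # 1024, i.e. 64*4*4 as in make_fc()
print(flat_features(128))  # 16384, i.e. 64*16*16
```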
Now we can start training the network. When defining the optimizer you can set weight_decay=1e-8, which can also help reduce overfitting to some extent.
```python
import torch
import torch.nn as nn
import torch.optim as optim

module = face_attr()
# print(module)

optimizer = optim.Adam(module.parameters(), lr=0.001, weight_decay=1e-8)

# CrossEntropyLoss keeps no state, so one instance can serve all 40 attribute heads
loss_func = nn.CrossEntropyLoss()

for Epoch in range(50):
    module.train()
    all_correct_num = 0
    for ii, (img, label) in enumerate(train_dataloader):
        # Variable is no longer needed: plain tensors carry autograd information
        output = module(img)  # list of 40 tensors, each (batch, 2)
        optimizer.zero_grad()
        loss = 0
        for i in range(40):
            # head i is compared with column i of the (batch, 40) label tensor
            loss = loss + loss_func(output[i], label[:, i])
            _, predict = torch.max(output[i], 1)
            all_correct_num += (predict == label[:, i]).sum().item()
        loss.backward()  # one backward pass over the sum of the 40 losses
        optimizer.step()

    Accuracy = all_correct_num * 1.0 / (len(train_dataset) * 40.0)
    print('Epoch ={0},all_correct_num={1},Accuracy={2}'.format(Epoch, all_correct_num, Accuracy))

    torch.save(module, 'W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')  # save the model after every epoch
```
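The training loop adds up the 40 per-attribute losses and calls `backward()` once. Because gradients accumulate in `.grad`, this yields the same gradients as calling `backward()` on each loss separately, which is what the original loop intended. The sketch below checks that on a tiny stand-in model (two `Linear` heads sharing one input, not the real `face_attr`).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in for face_attr: two independent heads, each producing 2 logits
heads = nn.ModuleList([nn.Linear(8, 2), nn.Linear(8, 2)])
loss_func = nn.CrossEntropyLoss()
x = torch.randn(5, 8)
label = torch.randint(0, 2, (5, 2))  # (batch, num_attrs)

# (a) sum the per-head losses, then a single backward pass
heads.zero_grad()
total = sum(loss_func(head(x), label[:, i]) for i, head in enumerate(heads))
total.backward()
grads_sum = [p.grad.clone() for p in heads.parameters()]

# (b) one backward pass per head; gradients accumulate in .grad
heads.zero_grad()
for i, head in enumerate(heads):
    loss_func(head(x), label[:, i]).backward()
grads_sep = [p.grad.clone() for p in heads.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(grads_sum, grads_sep)))  # True
```

Summing first is simpler and lets you weight individual attributes later if you want to.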
Testing the network: this is similar to training, except that there is no optimizer step and no backpropagation.
```python
import torch

# Load the model saved above (a full pickled module). weights_only=False is needed on
# PyTorch 2.6+, where torch.load defaults to weights_only=True and rejects full models.
module = torch.load('W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl', weights_only=False)
module.eval()  # switch to evaluation mode (disables Dropout)
all_correct_num = 0
with torch.no_grad():  # no gradients are needed for testing
    for ii, (img, label) in enumerate(test_dataloader):
        output = module(img)
        for i in range(40):
            _, predict = torch.max(output[i], 1)
            all_correct_num += (predict == label[:, i]).sum().item()
Accuracy = all_correct_num * 1.0 / (len(test_dataset) * 40.0)
print('all_correct_num={0},Accuracy={1}'.format(all_correct_num, Accuracy))
```
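The post saves the whole pickled model with `torch.save(module, ...)`. PyTorch's documentation recommends saving only the `state_dict()` (the parameter and buffer tensors) instead; such a file loads on current PyTorch without `weights_only=False` and does not depend on pickling the class itself. A minimal sketch, using a small stand-in network and a temporary file rather than the post's `face_attr` and `.pkl` path (the file name is hypothetical):

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # small stand-in for face_attr(); any nn.Module works the same way
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


torch.manual_seed(0)
model = make_model()
path = os.path.join(tempfile.mkdtemp(), 'face_attr_state.pt')  # hypothetical file name

# Save only the parameters and buffers
torch.save(model.state_dict(), path)

# To load: rebuild the architecture, then copy the saved tensors into it
restored = make_model()
restored.load_state_dict(torch.load(path))
restored.eval()

x = torch.randn(2, 4)
with torch.no_grad():
    print(torch.allclose(model(x), restored(x)))  # True
```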
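Summary: because I was working on a laptop without a GPU, I used only about 5,000 images for training (even with 5,000 images, my computer needed two days to finish 50 epochs) and 1,000 images for testing. Test accuracy, averaged over all 40 attributes, was around 90%. If you have the resources, you can try some optimizations; here are a few directions: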
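1. Image augmentation and input size: because my computer isn't powerful enough to handle large inputs, I resized the images so that the shorter edge is 40 pixels and then cropped the central 32×32 region as the input. With a GPU, consider 128×128 inputs (the input size of the first Linear layer in make_fc() must then be changed to match).
2. Before I added Dropout(), the model overfit, which is also related to the small training set; I suggest a training set of around 60,000 samples. As mentioned above, try a few different positions for Dropout(); I placed it just before the final output layer. I look forward to seeing your results if you try it.
3. At the moment every attribute uses the same network, but that network is not necessarily the best choice for every attribute, so you could design a separate network for each one. For example, one attribute might reach high accuracy with fully connected layers alone, while another might need 4 conv layers + 2 fc layers. The only way to find out is to experiment; there may be no algorithm that tells you which network suits a given attribute best. I suggest printing the accuracy of each attribute and then optimizing the networks of the attributes with lower accuracy.
4. Get a few GPUs for training and testing if you can; without a GPU it is really slow. I hope someone can reach a higher accuracy by building on this network.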
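Suggestion 3 recommends printing each attribute's accuracy. The sketch below assumes you have collected the predicted classes and true labels as two `(N, 40)` integer arrays (for example by stacking `predict` and `label[:, i]` from the test loop). Next to each accuracy it reports the majority-class baseline, i.e. the accuracy of always predicting the more frequent class; many CelebA attributes are imbalanced, so a high average accuracy on its own can be misleading. The function name and toy arrays are illustrative only.

```python
import numpy as np


def per_attribute_report(predictions, labels):
    """Return per-attribute accuracy and majority-class baseline.

    predictions, labels: integer arrays of shape (N, num_attrs) with values 0/1.
    """
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    accuracy = (predictions == labels).mean(axis=0)
    positive_rate = labels.mean(axis=0)
    # always predicting the majority class is right max(p, 1 - p) of the time
    baseline = np.maximum(positive_rate, 1.0 - positive_rate)
    return accuracy, baseline


# Toy example with 4 images and 2 attributes
labels = np.array([[1, 0], [0, 0], [0, 0], [1, 0]])
preds = np.array([[1, 0], [0, 0], [1, 0], [1, 1]])
acc, base = per_attribute_report(preds, labels)
for i, (a, b) in enumerate(zip(acc, base)):
    print('attr{0}: accuracy={1:.2f}, majority baseline={2:.2f}'.format(i, a, b))
# attr0: accuracy=0.75, majority baseline=0.50
# attr1: accuracy=0.75, majority baseline=1.00
```

In the toy data both attributes reach 75%, but attribute 1 is below its 100% baseline, so it is the one that needs work.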
The complete code is attached below:
```python
# -*- coding: utf-8 -*-
"""
Created on Sun Jun 17 11:54:36 2018

@author: sky-hole
"""

import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
import torchvision.transforms as transforms
from PIL import Image

img_root = 'W:/pic_data/face/CelebA/Img/img_align_celeba'
train_txt = 'W:/pic_data/face/CelebA/Img/train10000.txt'
batch_size = 2


def default_loader(path):
    try:
        img = Image.open(path)
        return img.convert('RGB')
    except (IOError, OSError):
        print("Can not open {0}".format(path))
        raise  # fail loudly instead of returning None


class myDataset(Data.Dataset):  # a dataset subclasses Dataset, not DataLoader
    def __init__(self, img_dir, img_txt, transform=None, loader=default_loader):
        img_list = []
        img_labels = []
        with open(img_txt, 'r') as fp:
            for line in fp:
                parts = line.split()
                if len(parts) != 41:  # skip the count line and the attribute-name header
                    continue
                img_list.append(parts[0])
                # CelebA uses 1 / -1; CrossEntropyLoss needs class indices 1 / 0
                img_labels.append([1 if value == '1' else 0 for value in parts[1:]])
        self.imgs = [os.path.join(img_dir, file) for file in img_list]
        self.labels = img_labels
        self.transform = transform
        self.loader = loader

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        img_path = self.imgs[index]
        label = torch.from_numpy(np.array(self.labels[index], dtype=np.int64))
        img = self.loader(img_path)
        if self.transform is not None:
            img = self.transform(img)
        return img, label


transform = transforms.Compose([
    transforms.Resize(40),
    transforms.CenterCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])

train_dataset = myDataset(img_dir=img_root, img_txt=train_txt, transform=transform)
train_dataloader = Data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# print(len(train_dataset))
# print(len(train_dataloader))


def make_conv():
    return nn.Sequential(
        nn.Conv2d(3, 16, 3, 1, 1),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Conv2d(16, 32, 3, 1, 1),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Conv2d(32, 64, 3, 1, 1),
        nn.ReLU(),
        # nn.Dropout(0.5),
        nn.MaxPool2d(2)
    )


def make_fc():
    return nn.Sequential(
        nn.Linear(64 * 4 * 4, 128),
        nn.ReLU(),
        # nn.Dropout(0.5),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(64, 2)
    )


class face_attr(nn.Module):
    def __init__(self, num_attrs=40):
        super(face_attr, self).__init__()
        # attr0 ... attr39: one independent conv branch and fc head per attribute
        self.conv_layers = nn.ModuleList([make_conv() for _ in range(num_attrs)])
        self.fc_layers = nn.ModuleList([make_fc() for _ in range(num_attrs)])

    def forward(self, x):
        out_list = []
        for conv, fc in zip(self.conv_layers, self.fc_layers):
            out = conv(x)
            out = out.view(out.size(0), -1)
            out_list.append(fc(out))
        return out_list


module = face_attr()
# print(module)

optimizer = optim.Adam(module.parameters(), lr=0.001, weight_decay=1e-8)
loss_func = nn.CrossEntropyLoss()  # stateless, so one instance serves all 40 heads

for Epoch in range(50):
    module.train()
    all_correct_num = 0
    for ii, (img, label) in enumerate(train_dataloader):
        output = module(img)  # list of 40 tensors, each (batch, 2)
        optimizer.zero_grad()
        loss = 0
        for i in range(40):
            loss = loss + loss_func(output[i], label[:, i])
            _, predict = torch.max(output[i], 1)
            all_correct_num += (predict == label[:, i]).sum().item()
        loss.backward()
        optimizer.step()

    Accuracy = all_correct_num * 1.0 / (len(train_dataset) * 40.0)
    print('Epoch ={0},all_correct_num={1},Accuracy={2}'.format(Epoch, all_correct_num, Accuracy))

    torch.save(module, 'W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')

# Test section: disabled as in the original; remove the quotes to run it after training
'''
test_txt = 'W:/pic_data/face/CelebA/Img/test1000.txt'
test_dataset = myDataset(img_dir=img_root, img_txt=test_txt, transform=transform)
test_dataloader = Data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

module = torch.load('W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl', weights_only=False)
module.eval()
all_correct_num = 0
with torch.no_grad():
    for ii, (img, label) in enumerate(test_dataloader):
        output = module(img)
        for i in range(40):
            _, predict = torch.max(output[i], 1)
            all_correct_num += (predict == label[:, i]).sum().item()
Accuracy = all_correct_num * 1.0 / (len(test_dataset) * 40.0)
print('all_correct_num={0},Accuracy={1}'.format(all_correct_num, Accuracy))
'''
```
Author: 梦缠绕的时候  Time: 2018-7-5 10:19
Author: 吴琼老师  Time: 2018-7-5 16:41
Welcome to the 黑马程序员 (Heima Programmer) Technical Exchange Community (http://bbs.itheima.com/) | Heima Programmer IT Technical Forum X3.2