A股上市公司传智教育(股票代码 003032)旗下技术交流社区北京昌平校区

 找回密码
 加入黑马

QQ登录

只需一步,快速开始

本帖最后由 梦缠绕的时候 于 2018-6-26 09:34 编辑

数据来源:CalebA人脸数据集(官网链接)是香港中文大学的开放数据,包含10,177个名人身份的202,599张人脸图片,并且都做好了特征标记,这对人脸相关的训练是非常好用的数据集。共计40个特征,具体是哪些特征,可以去官网查询。话不多说,直接开始流程。

整个流程可以分为大致以下几个步骤:

1.图片预处理

2.构建网络

3.训练

4.测试

5.优化

一。图片加载,以为源数据没有经过处理,我们要重写torch.utils.data.Dataloader()处理图片,然后才能将图片用于加载。代码如下:




  • def default_loader(path):  
  •     try:  
  •         img = Image.open(path)  
  •         return img.convert('RGB')  
  •     except:  
  •         print("Can not open {0}".format(path))  
  • class myDataset(Data.DataLoader):  
  •     def __init__(self,img_dir,img_txt=img_txt,transform=None,loader=default_loader):  
  •         img_list = []  
  •         img_labels = []  
  •          
  •         fp = open(img_txt,'r')  
  •         for line in fp.readlines():  
  •             if len(line.split())!=41:  
  •                 continue  
  •             img_list.append(line.split()[0])  
  •             img_label_single = []  
  •             for value in line.split()[1:]:  
  •                 if value == '-1':  
  •                     img_label_single.append(0)  
  •                 if value == '1':  
  •                     img_label_single.append(1)  
  •             img_labels.append(img_label_single)  
  •         self.imgs = [os.path.join(img_dir,file) for file in img_list]  
  •         self.labels = img_labels  
  •         self.transform = transform   
  •         self.loader = loader   
  •     def __len__(self):  
  •         return len(self.imgs)  
  •     def __getitem__(self,index):  
  •         img_path = self.imgs[index]  
  •         label = torch.from_numpy(np.array(self.labels[index],dtype=np.int64))  
  •         img = self.loader(img_path)  
  •         if self.transform is not None:  
  •             try:  
  •                 img = self.transform(img)  
  •             except:  
  •                 print('Cannot transform image: {}'.format(img_path))  
  •         return img,label  

图片增强、归一化处理和加载:




  • transform = transforms.Compose([  
  •                                 transforms.Resize(40),  
  •                                 transforms.CenterCrop(32),  
  •                                 transforms.RandomHorizontalFlip(),  
  •                                 transforms.ToTensor(),  
  •                                 transforms.Normalize(mean=[0.5,0.5,0.5],  
  •                                                      std = [0.5,0.5,0.5])  
  •                                 ])  

[python] view plain copy



  • #训练集                                 

train_dataset = myDataset(img_dir=img_root,img_txt=train_txt,transform= transform)train_dataloader = Data.DataLoader(train_dataset,batch_size = batch_size,shuffle=True)



  • #测试集  





  • test_dataset = myDataset(img_dir=img_root,img_txt = test_txt,transform= transform)  
  • test_dataloader = Data.DataLoader(test_dataset,batch_size = batch_size,shuffle=True)  

构建网络:我使用的网络结构是每种属性使用3层卷积加上3层fc。网络结构比较简单,导致准确率不会有太高的表现,如果有兴趣可以做下优化,文末有优化的思路供大家讨论。好了,先上代码:




  • def make_conv():  
  •     return nn.Sequential(  
  •             nn.Conv2d(3,16,3,1,1),  
  •             nn.ReLU(),  
  •             nn.MaxPool2d(2),  
  •             nn.Conv2d(16,32,3,1,1),  
  •             nn.ReLU(),  
  •             nn.MaxPool2d(2),  
  •             nn.Conv2d(32,64,3,1,1),  
  •             nn.ReLU(),  
  •             #nn.Dropout(0.5),  
  •             nn.MaxPool2d(2)  
  •             )  
  • def make_fc():  
  •     return nn.Sequential(  
  •             nn.Linear(64*4*4,128),  
  •             nn.ReLU(),  
  •             #nn.Dropout(0.5),  
  •             nn.Linear(128,64),  
  •             nn.ReLU(),  
  •             nn.Dropout(0.5),#Dropout()可以一定程度上防止过拟合,放在不同位置或许会有意想不到的结果,有条件可以多尝试几次  
  •             nn.Linear(64,2)  
  •             )  
  • class face_attr(nn.Module):  
  •     def __init__(self):  
  •         super(face_attr,self).__init__()  
  •         #attr0  
  •         self.attr0_layer1 = make_conv()  
  •         self.attr0_layer2 = make_fc()  
  •         #attr1  
  •         self.attr1_layer1 = make_conv()  
  •         self.attr1_layer2 = make_fc()  
  •         ...#每一中属性的计算都是相同的,在文中省略,  
  •         #attr38  
  •         self.attr38_layer1 = make_conv()  
  •         self.attr38_layer2 = make_fc()  
  •         #attr39  
  •         self.attr39_layer1 = make_conv()  
  •         self.attr39_layer2 = make_fc()  
  •     def forward(self,x):  
  •         out_list = []  
  •         #out0  
  •         out0 = self.attr0_layer1(x)  
  •         out0 = out0.view(out0.size(0),-1)  
  •         out0 = self.attr0_layer2(out0)  
  •         out_list.append(out0)  
  •         ...  
  •         #out39  
  •         out39 = self.attr39_layer1(x)  
  •         out39 = out39.view(out39.size(0),-1)  
  •         out39 = self.attr39_layer2(out39)  
  •         out_list.append(out39)  
  •          
  •         return out_list  

接下来就可以开始训练网络了,定义优化器的时候可以设置一下weight_decay=1e-8,也可以在一定程度上防止过拟合。





  • module = face_attr()  
  • #print(module)  
  •   
  •   
  • optimizer = optim.Adam(module.parameters(),lr = 0.001,weight_decay=1e-8)  
  •   
  • loss_list = []  
  • for i in range(40):  
  •     loss_func = nn.CrossEntropyLoss()  
  •     loss_list.append(loss_func)  
  • #loss_func = nn.CrossEntropyLoss()  
  • for Epoch in range(50):  
  •     all_correct_num = 0  
  •     for ii,(img,label) in enumerate(train_dataloader):  
  •          
  •         img = Variable(img)  
  •         label = Variable(label)  
  •         output = module(img)  
  •         optimizer.zero_grad()  
  •         for i in range(40):  
  •             loss = loss_list(output,label[:,i])  
  •             loss.backward()  
  •             _,predict = torch.max(output,1)  
  •             correct_num = sum(predict==label[:,i])  
  •             all_correct_num += correct_num.data[0]  
  •         optimizer.step()         
  •          
  •       
  •     Accuracy =  all_correct_num *1.0/(len(train_dataset)*40.0)  
  •     print('Epoch ={0},all_correct_num={1},Accuracy={2}'.format(Epoch,all_correct_num,Accuracy))  
  •   
  •     torch.save(module,'W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')#每跑一个epoch就保存一次模型  


测试网络:和训练类似,只是不用优化和做反向传播。





  • module = torch.load('W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')#加载刚刚保存的网络  
  • module.eval()#改成测试模式  
  • all_correct_num = 0  
  • for ii,(img,label) in enumerate(test_dataloader):  
  •          
  •     img = Variable(img)  
  •     label = Variable(label)  
  •     output = module(img)  
  •     for i in range(40):  
  •         _,predict = torch.max(output,1)  
  •         correct_num = sum(predict==label[:,i])  
  •         all_correct_num += correct_num.data[0]            
  • Accuracy =  all_correct_num *1.0/(len(test_dataset)*40.0)  
  • print('all_correct_num={0},Accuracy={1}'.format(all_correct_num,Accuracy))  


总结:我因为是笔记本电脑,没有GPU,所以只给了大约5000个数据用于训练(即使是5000个数据我的电脑也跑了2天才跑完50个epoch),1000个数据用于测试。测试的准确率在90%左右。有条件的同学可以做些优化,下面提供一些可以优化的方面:

1.图片增强,我因为电脑不给力,无法处理较大的数据,所以将原始图片缩放到40*40,然后截取了32*32作为输入,如果有GPU等条件,可以考虑128*128输入

2.最开始没加Dropout()的时候出现了过拟合的情况,当然这也和训练集较小有关系,建议训练集给到60000个样本。上述提到的Dropout()多尝试几个位置,我是放在了最后一层输出之前。期待大家尝试之后分享一下结果。

3.我现在每种属性都是使用相同的网络,但是这种网络有可能不是对于每种属性都是最优的选择,可以针对每一种属性单独写一层网络。例如:某个属性使用全fc层就可以达到很高的准确率,某个属性或许需要4层卷积+2层fc可以达到很好的效果,这种情况只有靠多尝试,算法什么的可能并不能给出哪种才是最适合的网络模型。建议输出每种属性的准确率,然后针对准确率较低的属性做相应的网络优化。

4.多准备几块GPU做训练和测试吧,没GPU真不给力。希望有小伙伴在这个网络的基础上能达到更高的准确率。

附上整体代码如下:





  • # -*- coding: utf-8 -*-  
  • """
  • Created on Sun Jun 17 11:54:36 2018
  • @author: sky-hole
  • """  
  •   
  • import torch  
  • import torch.nn as nn  
  • from torch.autograd import Variable  
  • import torch.optim as optim  
  • import torchvision.transforms as transforms  
  • import torch.utils.data as Data  
  • from PIL import Image  
  • import numpy as np  
  • import os  
  •   
  • img_root = 'W:/pic_data/face/CelebA/Img/img_align_celeba'  
  • train_txt = 'W:/pic_data/face/CelebA/Img/train10000.txt'  
  • batch_size = 2  
  •   
  • def default_loader(path):  
  •     try:  
  •         img = Image.open(path)  
  •         return img.convert('RGB')  
  •     except:  
  •         print("Can not open {0}".format(path))  
  • class myDataset(Data.DataLoader):  
  •     def __init__(self,img_dir,img_txt,transform=None,loader=default_loader):  
  •         img_list = []  
  •         img_labels = []  
  •          
  •         fp = open(img_txt,'r')  
  •         for line in fp.readlines():  
  •             if len(line.split())!=41:  
  •                 continue  
  •             img_list.append(line.split()[0])  
  •             img_label_single = []  
  •             for value in line.split()[1:]:  
  •                 if value == '-1':  
  •                     img_label_single.append(0)  
  •                 if value == '1':  
  •                     img_label_single.append(1)  
  •             img_labels.append(img_label_single)  
  •         self.imgs = [os.path.join(img_dir,file) for file in img_list]  
  •         self.labels = img_labels  
  •         self.transform = transform   
  •         self.loader = loader   
  •     def __len__(self):  
  •         return len(self.imgs)  
  •     def __getitem__(self,index):  
  •         img_path = self.imgs[index]  
  •         label = torch.from_numpy(np.array(self.labels[index],dtype=np.int64))  
  •         img = self.loader(img_path)  
  •         if self.transform is not None:  
  •             try:  
  •                 img = self.transform(img)  
  •             except:  
  •                 print('Cannot transform image: {}'.format(img_path))  
  •         return img,label  
  • transform = transforms.Compose([  
  •                                 transforms.Resize(40),  
  •                                 transforms.CenterCrop(32),  
  •                                 transforms.RandomHorizontalFlip(),  
  •                                 transforms.ToTensor(),  
  •                                 transforms.Normalize(mean=[0.5,0.5,0.5],  
  •                                                      std = [0.5,0.5,0.5])  
  •                                 ])  
  •                                  
  • train_dataset = myDataset(img_dir=img_root,img_txt=train_txt,transform= transform)  
  • train_dataloader = Data.DataLoader(train_dataset,batch_size = batch_size,shuffle=True)  
  •   
  • #print(len(train_dataset))  
  • #print(len(train_dataloader))  
  • def make_conv():  
  •     return nn.Sequential(  
  •             nn.Conv2d(3,16,3,1,1),  
  •             nn.ReLU(),  
  •             nn.MaxPool2d(2),  
  •             nn.Conv2d(16,32,3,1,1),  
  •             nn.ReLU(),  
  •             nn.MaxPool2d(2),  
  •             nn.Conv2d(32,64,3,1,1),  
  •             nn.ReLU(),  
  •             #nn.Dropout(0.5),  
  •             nn.MaxPool2d(2)  
  •             )  
  • def make_fc():  
  •     return nn.Sequential(  
  •             nn.Linear(64*4*4,128),  
  •             nn.ReLU(),  
  •             #nn.Dropout(0.5),  
  •             nn.Linear(128,64),  
  •             nn.ReLU(),  
  •             nn.Dropout(0.5),  
  •             nn.Linear(64,2)  
  •             )  
  • class face_attr(nn.Module):  
  •     def __init__(self):  
  •         super(face_attr,self).__init__()  
  •         #attr0  
  •         self.attr0_layer1 = make_conv()  
  •         self.attr0_layer2 = make_fc()  
  •         #attr1  
  •         self.attr1_layer1 = make_conv()  
  •         self.attr1_layer2 = make_fc()  
  •         #attr2  
  •         self.attr2_layer1 = make_conv()  
  •         self.attr2_layer2 = make_fc()  
  •         #attr3  
  •         self.attr3_layer1 = make_conv()  
  •         self.attr3_layer2 = make_fc()  
  •         #attr4  
  •         self.attr4_layer1 = make_conv()  
  •         self.attr4_layer2 = make_fc()  
  •         #attr5  
  •         self.attr5_layer1 = make_conv()  
  •         self.attr5_layer2 = make_fc()  
  •         #attr6  
  •         self.attr6_layer1 = make_conv()  
  •         self.attr6_layer2 = make_fc()  
  •         #attr7  
  •         self.attr7_layer1 = make_conv()  
  •         self.attr7_layer2 = make_fc()  
  •         #attr8  
  •         self.attr8_layer1 = make_conv()  
  •         self.attr8_layer2 = make_fc()  
  •         #attr9  
  •         self.attr9_layer1 = make_conv()  
  •         self.attr9_layer2 = make_fc()  
  •         #attr10  
  •         self.attr10_layer1 = make_conv()  
  •         self.attr10_layer2 = make_fc()  
  •         #attr11  
  •         self.attr11_layer1 = make_conv()  
  •         self.attr11_layer2 = make_fc()  
  •         #attr12  
  •         self.attr12_layer1 = make_conv()  
  •         self.attr12_layer2 = make_fc()  
  •         #attr13  
  •         self.attr13_layer1 = make_conv()  
  •         self.attr13_layer2 = make_fc()  
  •         #attr14  
  •         self.attr14_layer1 = make_conv()  
  •         self.attr14_layer2 = make_fc()  
  •         #attr15  
  •         self.attr15_layer1 = make_conv()  
  •         self.attr15_layer2 = make_fc()  
  •         #attr16  
  •         self.attr16_layer1 = make_conv()  
  •         self.attr16_layer2 = make_fc()  
  •         #attr17  
  •         self.attr17_layer1 = make_conv()  
  •         self.attr17_layer2 = make_fc()  
  •         #attr18  
  •         self.attr18_layer1 = make_conv()  
  •         self.attr18_layer2 = make_fc()  
  •         #attr19  
  •         self.attr19_layer1 = make_conv()  
  •         self.attr19_layer2 = make_fc()  
  •         #attr20  
  •         self.attr20_layer1 = make_conv()  
  •         self.attr20_layer2 = make_fc()  
  •         #attr21  
  •         self.attr21_layer1 = make_conv()  
  •         self.attr21_layer2 = make_fc()  
  •         #attr22  
  •         self.attr22_layer1 = make_conv()  
  •         self.attr22_layer2 = make_fc()  
  •         #attr23  
  •         self.attr23_layer1 = make_conv()  
  •         self.attr23_layer2 = make_fc()  
  •         #attr24  
  •         self.attr24_layer1 = make_conv()  
  •         self.attr24_layer2 = make_fc()  
  •         #attr25  
  •         self.attr25_layer1 = make_conv()  
  •         self.attr25_layer2 = make_fc()  
  •         #attr26  
  •         self.attr26_layer1 = make_conv()  
  •         self.attr26_layer2 = make_fc()  
  •         #attr27  
  •         self.attr27_layer1 = make_conv()  
  •         self.attr27_layer2 = make_fc()  
  •         #attr28  
  •         self.attr28_layer1 = make_conv()  
  •         self.attr28_layer2 = make_fc()  
  •         #attr29  
  •         self.attr29_layer1 = make_conv()  
  •         self.attr29_layer2 = make_fc()  
  •         #attr30  
  •         self.attr30_layer1 = make_conv()  
  •         self.attr30_layer2 = make_fc()  
  •         #attr31  
  •         self.attr31_layer1 = make_conv()  
  •         self.attr31_layer2 = make_fc()  
  •         #attr32  
  •         self.attr32_layer1 = make_conv()  
  •         self.attr32_layer2 = make_fc()  
  •         #attr33  
  •         self.attr33_layer1 = make_conv()  
  •         self.attr33_layer2 = make_fc()  
  •         #attr34  
  •         self.attr34_layer1 = make_conv()  
  •         self.attr34_layer2 = make_fc()  
  •         #attr35  
  •         self.attr35_layer1 = make_conv()  
  •         self.attr35_layer2 = make_fc()  
  •         #attr36  
  •         self.attr36_layer1 = make_conv()  
  •         self.attr36_layer2 = make_fc()  
  •         #attr37  
  •         self.attr37_layer1 = make_conv()  
  •         self.attr37_layer2 = make_fc()  
  •         #attr38  
  •         self.attr38_layer1 = make_conv()  
  •         self.attr38_layer2 = make_fc()  
  •         #attr39  
  •         self.attr39_layer1 = make_conv()  
  •         self.attr39_layer2 = make_fc()  
  •     def forward(self,x):  
  •         out_list = []  
  •         #out0  
  •         out0 = self.attr0_layer1(x)  
  •         out0 = out0.view(out0.size(0),-1)  
  •         out0 = self.attr0_layer2(out0)  
  •         out_list.append(out0)  
  •         #out1  
  •         out1 = self.attr1_layer1(x)  
  •         out1 = out1.view(out1.size(0),-1)  
  •         out1 = self.attr1_layer2(out1)  
  •         out_list.append(out1)  
  •         #out2  
  •         out2 = self.attr2_layer1(x)  
  •         out2 = out2.view(out2.size(0),-1)  
  •         out2 = self.attr2_layer2(out2)  
  •         out_list.append(out2)  
  •         #out3  
  •         out3 = self.attr3_layer1(x)  
  •         out3 = out3.view(out3.size(0),-1)  
  •         out3 = self.attr3_layer2(out3)  
  •         out_list.append(out3)  
  •         #out4  
  •         out4 = self.attr4_layer1(x)  
  •         out4 = out4.view(out4.size(0),-1)  
  •         out4 = self.attr4_layer2(out4)  
  •         out_list.append(out4)  
  •         #out5  
  •         out5 = self.attr5_layer1(x)  
  •         out5 = out5.view(out5.size(0),-1)  
  •         out5 = self.attr5_layer2(out5)  
  •         out_list.append(out5)  
  •         #out6  
  •         out6 = self.attr6_layer1(x)  
  •         out6 = out6.view(out6.size(0),-1)  
  •         out6 = self.attr6_layer2(out6)  
  •         out_list.append(out6)  
  •         #out7  
  •         out7 = self.attr7_layer1(x)  
  •         out7 = out7.view(out7.size(0),-1)  
  •         out7 = self.attr7_layer2(out7)  
  •         out_list.append(out7)  
  •         #out8  
  •         out8 = self.attr8_layer1(x)  
  •         out8 = out8.view(out8.size(0),-1)  
  •         out8 = self.attr8_layer2(out8)  
  •         out_list.append(out8)  
  •         #out9  
  •         out9 = self.attr9_layer1(x)  
  •         out9 = out9.view(out9.size(0),-1)  
  •         out9 = self.attr9_layer2(out9)  
  •         out_list.append(out9)  
  •         #out10  
  •         out10 = self.attr10_layer1(x)  
  •         out10 = out10.view(out10.size(0),-1)  
  •         out10 = self.attr10_layer2(out10)  
  •         out_list.append(out10)  
  •         #out11  
  •         out11 = self.attr11_layer1(x)  
  •         out11 = out11.view(out11.size(0),-1)  
  •         out11 = self.attr11_layer2(out11)  
  •         out_list.append(out11)  
  •         #out12  
  •         out12 = self.attr12_layer1(x)  
  •         out12 = out12.view(out12.size(0),-1)  
  •         out12 = self.attr12_layer2(out12)  
  •         out_list.append(out12)  
  •         #out13  
  •         out13 = self.attr13_layer1(x)  
  •         out13 = out13.view(out13.size(0),-1)  
  •         out13 = self.attr13_layer2(out13)  
  •         out_list.append(out13)  
  •         #out14  
  •         out14 = self.attr14_layer1(x)  
  •         out14 = out14.view(out14.size(0),-1)  
  •         out14 = self.attr14_layer2(out14)  
  •         out_list.append(out14)  
  •         #out15  
  •         out15 = self.attr15_layer1(x)  
  •         out15 = out15.view(out15.size(0),-1)  
  •         out15 = self.attr15_layer2(out15)  
  •         out_list.append(out15)  
  •         #out16  
  •         out16 = self.attr16_layer1(x)  
  •         out16 = out16.view(out16.size(0),-1)  
  •         out16 = self.attr16_layer2(out16)  
  •         out_list.append(out16)  
  •         #out17  
  •         out17 = self.attr17_layer1(x)  
  •         out17 = out17.view(out17.size(0),-1)  
  •         out17 = self.attr17_layer2(out17)  
  •         out_list.append(out17)  
  •         #out18  
  •         out18 = self.attr18_layer1(x)  
  •         out18 = out18.view(out18.size(0),-1)  
  •         out18 = self.attr18_layer2(out18)  
  •         out_list.append(out18)  
  •         #out19  
  •         out19 = self.attr19_layer1(x)  
  •         out19 = out19.view(out19.size(0),-1)  
  •         out19 = self.attr19_layer2(out19)  
  •         out_list.append(out19)  
  •         #out20  
  •         out20 = self.attr20_layer1(x)  
  •         out20 = out20.view(out20.size(0),-1)  
  •         out20 = self.attr20_layer2(out20)  
  •         out_list.append(out20)  
  •         #out21  
  •         out21 = self.attr21_layer1(x)  
  •         out21 = out21.view(out21.size(0),-1)  
  •         out21 = self.attr21_layer2(out21)  
  •         out_list.append(out21)  
  •         #out22  
  •         out22 = self.attr22_layer1(x)  
  •         out22 = out22.view(out22.size(0),-1)  
  •         out22 = self.attr22_layer2(out22)  
  •         out_list.append(out22)  
  •         #out23  
  •         out23 = self.attr23_layer1(x)  
  •         out23 = out23.view(out23.size(0),-1)  
  •         out23 = self.attr23_layer2(out23)  
  •         out_list.append(out23)  
  •         #out24  
  •         out24 = self.attr24_layer1(x)  
  •         out24 = out24.view(out24.size(0),-1)  
  •         out24 = self.attr24_layer2(out24)  
  •         out_list.append(out24)  
  •         #out25  
  •         out25 = self.attr25_layer1(x)  
  •         out25 = out25.view(out25.size(0),-1)  
  •         out25 = self.attr25_layer2(out25)  
  •         out_list.append(out25)  
  •         #out26  
  •         out26 = self.attr26_layer1(x)  
  •         out26 = out26.view(out26.size(0),-1)  
  •         out26 = self.attr26_layer2(out26)  
  •         out_list.append(out26)  
  •         #out27  
  •         out27 = self.attr27_layer1(x)  
  •         out27 = out27.view(out27.size(0),-1)  
  •         out27 = self.attr27_layer2(out27)  
  •         out_list.append(out27)  
  •         #out28  
  •         out28 = self.attr28_layer1(x)  
  •         out28 = out28.view(out28.size(0),-1)  
  •         out28 = self.attr28_layer2(out28)  
  •         out_list.append(out28)  
  •         #out29  
  •         out29 = self.attr29_layer1(x)  
  •         out29 = out29.view(out29.size(0),-1)  
  •         out29 = self.attr29_layer2(out29)  
  •         out_list.append(out29)  
  •         #out30  
  •         out30 = self.attr30_layer1(x)  
  •         out30 = out30.view(out30.size(0),-1)  
  •         out30 = self.attr30_layer2(out30)  
  •         out_list.append(out30)  
  •         #out31  
  •         out31 = self.attr31_layer1(x)  
  •         out31 = out31.view(out31.size(0),-1)  
  •         out31 = self.attr31_layer2(out31)  
  •         out_list.append(out31)  
  •         #out32  
  •         out32 = self.attr32_layer1(x)  
  •         out32 = out32.view(out32.size(0),-1)  
  •         out32 = self.attr32_layer2(out32)  
  •         out_list.append(out32)  
  •         #out33  
  •         out33 = self.attr33_layer1(x)  
  •         out33 = out33.view(out33.size(0),-1)  
  •         out33 = self.attr33_layer2(out33)  
  •         out_list.append(out33)  
  •         #out34  
  •         out34 = self.attr34_layer1(x)  
  •         out34 = out34.view(out34.size(0),-1)  
  •         out34 = self.attr34_layer2(out34)  
  •         out_list.append(out34)  
  •         #out35  
  •         out35 = self.attr35_layer1(x)  
  •         out35 = out35.view(out35.size(0),-1)  
  •         out35 = self.attr35_layer2(out35)  
  •         out_list.append(out35)  
  •         #out36  
  •         out36 = self.attr36_layer1(x)  
  •         out36 = out36.view(out36.size(0),-1)  
  •         out36 = self.attr36_layer2(out36)  
  •         out_list.append(out36)  
  •         #out37  
  •         out37 = self.attr37_layer1(x)  
  •         out37 = out37.view(out37.size(0),-1)  
  •         out37 = self.attr37_layer2(out37)  
  •         out_list.append(out37)  
  •         #out38  
  •         out38 = self.attr38_layer1(x)  
  •         out38 = out38.view(out38.size(0),-1)  
  •         out38 = self.attr38_layer2(out38)  
  •         out_list.append(out38)  
  •         #out39  
  •         out39 = self.attr39_layer1(x)  
  •         out39 = out39.view(out39.size(0),-1)  
  •         out39 = self.attr39_layer2(out39)  
  •         out_list.append(out39)  
  •          
  •         return out_list  
  •      
  • module = face_attr()  
  • #print(module)  
  •   
  •   
  • optimizer = optim.Adam(module.parameters(),lr = 0.001,weight_decay=1e-8)  
  •   
  • loss_list = []  
  • for i in range(40):  
  •     loss_func = nn.CrossEntropyLoss()  
  •     loss_list.append(loss_func)  
  • #loss_func = nn.CrossEntropyLoss()  
  • for Epoch in range(50):  
  •     all_correct_num = 0  
  •     for ii,(img,label) in enumerate(train_dataloader):  
  •          
  •         img = Variable(img)  
  •         label = Variable(label)  
  •     #    optimizer.zero_grad()  
  •         output = module(img)  
  •         optimizer.zero_grad()  
  •         for i in range(40):  
  •             loss = loss_list(output,label[:,i])  
  •             loss.backward()  
  •             _,predict = torch.max(output,1)  
  •             correct_num = sum(predict==label[:,i])  
  •             all_correct_num += correct_num.data[0]  
  •         optimizer.step()         
  •          
  •       
  •     Accuracy =  all_correct_num *1.0/(len(train_dataset)*40.0)  
  •     print('Epoch ={0},all_correct_num={1},Accuracy={2}'.format(Epoch,all_correct_num,Accuracy))  
  •   
  •     torch.save(module,'W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')  
  • '''''
  • test_txt = 'W:/pic_data/face/CelebA/Img/test1000.txt'
  • test_dataset = myDataset(img_dir=img_root,img_txt = test_txt,transform= transform)
  • test_dataloader = Data.DataLoader(test_dataset,batch_size = batch_size,shuffle=True)
  • module = torch.load('W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')
  • module.eval()
  • all_correct_num = 0
  • for ii,(img,label) in enumerate(test_dataloader):
  •          
  •     img = Variable(img)
  •     label = Variable(label)
  •     output = module(img)
  •     for i in range(40):
  •         _,predict = torch.max(output,1)
  •         correct_num = sum(predict==label[:,i])
  •         all_correct_num += correct_num.data[0]           
  • Accuracy =  all_correct_num *1.0/(len(test_dataset)*40.0)
  • print('all_correct_num={0},Accuracy={1}'.format(all_correct_num,Accuracy))
  • '''   

2 个回复

倒序浏览
回复 使用道具 举报
回复 使用道具 举报
您需要登录后才可以回帖 登录 | 加入黑马