A股上市公司传智教育(股票代码 003032)旗下技术交流社区北京昌平校区

 找回密码
 加入黑马

QQ登录

只需一步,快速开始

    深度学习中,有时我们需要对数据集进行预处理,这样能够更好的读取数据。

一、png格式生成.npy格式
import numpy as np
import os
from PIL import Image

dir="C:/Users/Administrator/Desktop/trainA"

def getFileArr(dir):
    result_arr=[]
    label_list=[]
    map={}
    map_file_result={}
    map_file_label={}
    map_new={}
    count_label=0
    count=0

    file_list=os.listdir(dir)
    for file in file_list:
        file_path=os.path.join(dir,file)

        label=file.split(".")[0].split("_")[0]
        map[file]=label
        if label not in label_list:
            label_list.append(label)
            map_new[label]=count_label
            count_label=count_label+1
        img=Image.open(file_path)
        result=np.array([])
        r,g,b=img.split()

        r_arr=np.array(r).reshape(4096)
        g_arr=np.array(g).reshape(4096)
        b_arr=np.array(b).reshape(4096)
        img_arr=np.concatenate((r_arr,g_arr,b_arr))
        result=np.concatenate((result,img_arr))
        result=result.reshape((64,64,3))
        result=result/255.0
        map_file_result[file]=result
        result_arr.append(result)
        count=count+1
    for file in file_list:
        map_file_label[file]=map_new[map[file]]
        #map[file]=map_new[map[file]]

    ret_arr=[]
    for file in file_list:
        each_list=[]
        label_one_zero=np.zeros(count_label)
        result=map_file_result[file]
        label=map_file_label[file]
        label_one_zero[label]=1.0
        #print(label_one_zero)
        each_list.append(result)
        each_list.append(label_one_zero)
        ret_arr.append(each_list)
    os.makedirs("C:/Users/Administrator/Desktop/npy")
    np.save('C:/Users/Administrator/Desktop/npy/test_data.npy', ret_arr)
    return ret_arr
if __name__=="__main__":
    ret_arr=getFileArr(dir)
二、.npy格式生成png格式
import numpy as np
from PIL import Image
import os

dir="C:/Users/Administrator/Desktop/npy/"#npy文件路径
dest_dir="C:/Users/Administrator/Desktop/train/"
def npy2jpg(dir,dest_dir):
    if os.path.exists(dir)==False:
        os.makedirs(dir)
    if os.path.exists(dest_dir)==False:
        os.makedirs(dest_dir)
    file=dir+'test_data.npy'
    con_arr=np.load(file)
    count=0
    for con in con_arr:
        arr=con[0]
        label=con[1]
        print(np.argmax(label))
        arr=arr*255
        #arr=np.transpose(arr,(2,1,0))
        arr=np.reshape(arr,(3,64,64))
        r=Image.fromarray(arr[0]).convert("L")
        g=Image.fromarray(arr[1]).convert("L")
        b=Image.fromarray(arr[2]).convert("L")

        img=Image.merge("RGB",(r,g,b))

        label_index=np.argmax(label)
        img.save(dest_dir+str(label_index)+"_"+str(count)+".png")
        count=count+1

if __name__=="__main__":
    npy2jpg(dir,dest_dir)
三、注意
             根据自己的数据集需要改尺寸和维度以及改路径。
---------------------
作者:蹦跶的小羊羔
来源:CSDN
原文:https://blog.csdn.net/yql_617540298/article/details/82747789
版权声明:本文为博主原创文章,转载请附上博文链接!

2 个回复

倒序浏览
回复 使用道具 举报
回复 使用道具 举报
您需要登录后才可以回帖 登录 | 加入黑马