A股上市公司传智教育(股票代码 003032)旗下技术交流社区北京昌平校区

 找回密码
 加入黑马

QQ登录

只需一步,快速开始

介绍
最近在做模型的量化,量化的模型是人脸检测网络mtcnn,我从Onet开始入手,原先这个模型使用的权重文件是ckpt,这种存储格式适合训练,如果要做量化的话,需要先转化为pb文件,把其中的变量都持久化。再进一步做量化
生成的思路是给加载ckpt文件的onet网络导入一张48x48的人头图像,输出softmax值和box数值,再把网络加载方式换成生成的pb文件,再送一样的一幅图进去,查看输出结果,一样则转化成功。然后接下来就可以在生成的pb文件上做int8量化。

pb文件
pb是protocol(协议) buffer(缓冲)的缩写。TensorFlow训练模型后存成的pb文件,是一种表示模型(神经网络)结构的二进制文件,不带有源代码。
pb文件中可以只存参数,也可以存参数加网络结构,我们这里要生成的是存参数+网络结构,这样在推断的时候,可以不用重新在代码中定义网络结构,直接送入图像就可以输出结果,很方便。google现在也推荐这种文件格式。

把模型保存成pb文件
我们在原网络中加载ckpt模型,然后回复成sess,再从sess保存到pb文件
代码如下:

import sys
import argparse
import time
import os  
os.environ['CUDA_VISIBLE_DEVICES']='3'
import tensorflow as tf
import cv2
import numpy as np
from tensorflow.python.framework import graph_util
from src.mtcnn import PNet, RNet, ONet
from tools import detect_face, get_model_filenames

def main(args):
    out_pb_path="onet_trained2.pb"
    img = cv2.imread(args.image_path)
    img48 = (img - 127.5) * (1. / 128.0)
    img_x = np.expand_dims(img48, 0)
    file_paths = get_model_filenames(args.model_dir)
    with tf.device('/gpu:3'):
        with tf.Graph().as_default():
            config = tf.ConfigProto(allow_soft_placement=True)
                # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
            with tf.Session(config=config) as sess:
                if len(file_paths) == 3:
                    image_onet = tf.placeholder(tf.float32, [None, 48, 48, 3])
                    onet = ONet({'data': image_onet}, mode='test')
                    out_tensor_onet = onet.get_all_output()
                    saver_onet = tf.train.Saver(
                                    [v for v in tf.global_variables()
                                     if v.name[0:5] == "onet/"])
                    saver_onet.restore(sess, file_paths[2])
                    sess.run(out_tensor_onet, feed_dict={image_onet: img_x})
                    graph = tf.get_default_graph() # 获得默认的图
                    input_graph_def = graph.as_graph_def()  # 返回一个序列化的图代表当前的图
#                    for op in graph.get_operations():
#                        print(op.name, op.values())

                    output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
                                       sess = sess,
                                       input_graph_def = input_graph_def,# 等于:sess.graph_def input_graph_def
                                       output_node_names = ['softmax/softmax','onet/conv6-2/onet/conv6-2'])# 如果有多个输出节点,以逗号隔开
                    with tf.gfile.GFile(out_pb_path, "wb") as f: #保存模型
                        f.write(output_graph_def.SerializeToString()) #序列化输出
                    print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点

代码中最关键的是输出节点名称的确定,只要写对了程序基本没有问题,我在这一块卡了好久。查节点的方法有直接看原网络的输出节点名称、可视化工具tensorflow、netron。我使用的是netron,很方便,在网页中上次模型文件即可


1 个回复

正序浏览
奈斯
回复 使用道具 举报
您需要登录后才可以回帖 登录 | 加入黑马