环境要求

  1. 支持TensorFlow1.X,在TensorFlow1.X最后一个版本1.15.3上测试通过;TensorFlow2.0暂未测试

    大致算法原理

    仅对权值进行int8量化,一般权值量化并不需要finetune,直接用MNN的转换工具的“—weightQuantBits”进行转换即可,但也可使用本工具进行测试精度,或者finetune到更低的bit数;可与剪枝工具叠加使用;

    支持的op,使用建议

  2. 权值量化8bit情况下一般不会损失精度,不需要训练,而模型参数压缩4倍左右,推理速度和float一致

  3. 目前支持Conv2D, DepthwiseConv2dNative,带参数MatMul
  4. 可以结合剪枝一起使用,以获得更大的压缩倍数

    使用方法

  5. 读取已经训练好的float模型的checkpoint,然后插入权值量化节点,注意添加第22行代码,示例代码: ```python from mnncompress.tensorflow.weight_quantizer import WeightQuantizer, strip_wq_ops

构建 前向 网络模型

build_model_architecture()

在定义反向计算之前构建weight quantizer,向图中插入相关节点

graph = tf.get_default_graph() weight_quantizer = WeightQuantizer(graph, bits=4, debug_info=True)

定义optimizer

opt = …

恢复原模型中的变量

saver = tf.train.Saver()

sess = tf.Session()

给原模型中的变量用checkpoint赋值

saver.restore(sess, ‘save/model.ckpt’)

训练之前初始化相关变量,放在restore之后

weight_quantizer.init(sess)

  1. 2. 正常训练,**注意添加第5行代码**,示例代码:
  2. ```python
  3. for data in training_dataset:
  4. feed_dict = {x: batch_x, labels: batch_y}
  5. sess.run([train_step], feed_dict=feed_dict)
  6. weight_quantizer.update(sess)
  1. 训练完成,去掉插入的权值量化算子(第2行代码),保存frozen pb,示例代码: ```python

    去掉插入的权值量化算子

    strip_wq_ops()

output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=output_names) f = open(“frozen.pb”, ‘wb’) f.write(output_graph_def.SerializeToString())

  1. 4. 使用MNN转换工具的“--weightQuantBits numBits”选项将frozen pb转换成MNN模型,其中numBitsWeightQuantizer中的bit数,得到的MNN模型的精度和frozen.pb一致
  2. ```bash
  3. mnnconvert --modelFile frozen.pb --MNNModel weight_quant_model.mnn --framework TF --bizCode MNNTest --compressionParams compress_params.bin --weightQuantBits 8

相关API

WeightQuantizer

  1. WeightQuantizer(graph, bits=8, debug_info=False)
  1. 参数
  2. graph: tensorflow模型图
  3. bits: 权值量化的bit
  4. debug_info: bool,是否输出debug信息
  1. 方法和属性
  2. init(sess): 训练之前初始化内部用到的相关变量
  3. update(sess): 更新内部状态
  4. save_compress_params(filename, append=False):
  5. 用于保存MNN转换时需要用的模型压缩信息
  6. filenamestrMNN模型压缩参数将保存到这个文件名指定的文件中
  7. appendbool,是否将量化参数追加到filename文件中。如果进行量化的模型有剪枝,请将剪枝时通过save_compress_params生成的剪枝信息文件通过此参数传入,并将 append 设置为True

strip_wq_ops()

  1. 去掉权值量化相关算子