环境要求
支持TensorFlow1.X,在TensorFlow1.X最后一个版本1.15.3上测试通过;TensorFlow2.0暂未测试
大致算法原理
仅对权值进行int8量化,一般权值量化并不需要finetune,直接用MNN的转换工具的“—weightQuantBits”进行转换即可,但也可使用本工具进行测试精度,或者finetune到更低的bit数;可与剪枝工具叠加使用;
支持的op,使用建议
权值量化8bit情况下一般不会损失精度,不需要训练,而模型参数压缩4倍左右,推理速度和float一致
- 目前支持Conv2D, DepthwiseConv2dNative,带参数MatMul
-
使用方法
读取已经训练好的float模型的checkpoint,然后插入权值量化节点,注意添加第22行代码,示例代码: ```python from mnncompress.tensorflow.weight_quantizer import WeightQuantizer, strip_wq_ops
构建 前向 网络模型
build_model_architecture()
在定义反向计算之前构建weight quantizer,向图中插入相关节点
graph = tf.get_default_graph() weight_quantizer = WeightQuantizer(graph, bits=4, debug_info=True)
定义optimizer
opt = …
恢复原模型中的变量
saver = tf.train.Saver()
sess = tf.Session()
给原模型中的变量用checkpoint赋值
saver.restore(sess, ‘save/model.ckpt’)
训练之前初始化相关变量,放在restore之后
weight_quantizer.init(sess)
2. 正常训练,**注意添加第5行代码**,示例代码:
```python
for data in training_dataset:
feed_dict = {x: batch_x, labels: batch_y}
sess.run([train_step], feed_dict=feed_dict)
weight_quantizer.update(sess)
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=output_names) f = open(“frozen.pb”, ‘wb’) f.write(output_graph_def.SerializeToString())
4. 使用MNN转换工具的“--weightQuantBits numBits”选项将frozen pb转换成MNN模型,其中numBits为WeightQuantizer中的bit数,得到的MNN模型的精度和frozen.pb一致
```bash
mnnconvert --modelFile frozen.pb --MNNModel weight_quant_model.mnn --framework TF --bizCode MNNTest --compressionParams compress_params.bin --weightQuantBits 8
相关API
WeightQuantizer
WeightQuantizer(graph, bits=8, debug_info=False)
参数
graph: tensorflow模型图
bits: 权值量化的bit数
debug_info: bool,是否输出debug信息
方法和属性
init(sess): 训练之前初始化内部用到的相关变量
update(sess): 更新内部状态
save_compress_params(filename, append=False):
用于保存MNN转换时需要用的模型压缩信息
filename:str,MNN模型压缩参数将保存到这个文件名指定的文件中
append:bool,是否将量化参数追加到filename文件中。如果进行量化的模型有剪枝,请将剪枝时通过save_compress_params生成的剪枝信息文件通过此参数传入,并将 append 设置为True
strip_wq_ops()
去掉权值量化相关算子