一、TFLite转换
1、运行环境
2、转换器功能说明
- 支持TF 1.X、TF 2.X版本的SavedModel、Keras、concrete function
- 支持TF 1.X Frozen Graph
tf.compat.v1.lite.TFLiteConverter - 支持量化感知训练(QAT)模型转换,待转换的模型格式
Keras - 支持训练后训练转换:动态量化、全整量化、float16量化、16x8量化
QAT版本要求:TensorFlow Version 2.3.0,TensorFlow-Model-Optimization Version 0.4.1; 训练后量化16x8模式,实验性16-bit激活+8-bit权重量化,可能暂时还不支持部署;
参考链接
[1]. TFLite 8-bit量化规范参考
[2]. TFLite指南
[3]. TensorFlow 模型优化指南
3、转换器脚本说明
## 传参说明--SavedModel # SavedModel Path--KerasModel # Keras Model(.h5) Path--ConcreteFuncs # concrete function 目前还没使用过...--FrozenModel # Frozen Graph Model Path--post_training_quantization # 训练后量化模式,‘dynamic_range’、‘full_integer’、‘float16’、‘16x8’;默认False模式,执行Float32格式转换--QAT # 待转换的模型是否是QAT模型,默认False--input_shapes # 列表,输入张量维度;转换SavedModel时若需要指定输入向量维度,需要此参数;转换FrozenModel时,需要此参数;默认None--data_path # 全整量化模式下,representative_dataset path,参考数量100-1000--network_shape # 列表,模型维度--input_arrays # 输入张量结点名称,默认None,转换FrozenModel时需要--output_arrays # 输出张量结点名称,默认None,转换FrozenModel时需要--tflite_path # tflite模型文件名## 函数说明# 转换函数def convert_func(self):# SavedModelif self.SavedModel:if not self.input_shapes:# 不支持指定输入向量维度converter = tf.lite.TFLiteConverter.from_saved_model(self.SavedModel)else:model = tf.saved_model.load(self.SavedModel)# DEFAULT_SERVING_SIGNATURE_DEF_KEY = 'serving_default'concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]concrete_func.inputs[0].set_shape([self.input_shapes[0], self.input_shapes[2],self.input_shapes[3], self.input_shapes[1]])converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])# KerasModelif self.KerasModel:if not self.QAT:model = tf.keras.models.load_model(self.KerasModel)converter = tf.lite.TFLiteConverter.from_keras_model(model)else:# 转换QAT模型,需要添加tfmot.quantization.keras.quantize_scope(),用于解决QAT模型with tfmot.quantization.keras.quantize_scope():model = tf.keras.models.load_model(self.KerasModel)converter = tf.lite.TFLiteConverter.from_keras_model(model)# ConcreteFuncsif self.ConcreteFuncs:# 目前仅支持每次调用时仅接受一个concrete functionconverter = tf.lite.TFLiteConverter.from_concrete_functions([self.ConcreteFuncs])# FrozenModelif self.FrozenModel:# Converting a GraphDef from file.converter = tf.compat.v1.TFLiteConverter.from_frozen_graph(self.FrozenModel,self.input_arrays,self.output_arrays,self.input_shapes)# 量化处理def quan_process(self, converter):# representative_dataset数据预处理部分,需要与模型预处理方法一致def pre_process(self, img_path):
二、TFLite部署
1、环境配置
- Bazel-2.0.0安装步骤
- Android NDK r21b
- TensorFlow-2.3.0源码
sudo install adb
2、步骤
2.1 编写TFLite Inference C++代码
参考Tflitewrapper代码。TFLite Inference API参考网址,python版仅用于验证,TensorFlow源码例程TensorFlow-2.3.0/tensorflow/lite/examples/label_image
2.2 Bazel编译
- Ubuntu本地编译
- Arm64编译
- Arm32编译
2.3 Adb push
# Enter Phoneadb rootadb shell# push 手机端adb push xx.tflite /data/local/tmpadb push demo /data/local/tmpadb push libTfliteWrapper.so /data/local/tmp
2.4 Execution
# 路径export path# 执行./demo
2.5 adb pull
# 将输出pull 本地adb pull xx path
