TFLite模型转换及使用
1. tflite模型转换
keras 模型转换
import tensorflow as tf

# Convert a Keras HDF5 model file to TFLite flatbuffer format.
converter = tf.lite.TFLiteConverter.from_keras_model_file("keras_model.h5")
tflite_model = converter.convert()
# Use a context manager so the file handle is closed deterministically
# (the original `open(...).write(...)` leaks the handle on CPython refcount edge cases).
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)
SavedModel
import tensorflow as tf

# Convert a TensorFlow SavedModel directory to TFLite flatbuffer format.
# `saved_model_dir` is the path to the SavedModel export directory.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
# Context manager guarantees the output file is flushed and closed.
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)
PB:需要先将meta和checkpoint通过freeze_graph输出为pb格式,再进行转换
import tensorflow as tf

# Convert a frozen GraphDef (.pb) to TFLite flatbuffer format.
# Input/output tensor names must match the graph; these are the standard
# names for the MobilenetV1 release.
graph_def_file = "/path/to/Downloads/mobilenet_v1_1.0_224/frozen_graph.pb"
input_arrays = ["input"]
output_arrays = ["MobilenetV1/Predictions/Softmax"]
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
# Context manager guarantees the output file is flushed and closed.
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)
2. tflite轻量化模型推理
导入包:
import tflite_runtime.interpreter as tflite
这里使用tflite_runtime 库来实现模型的推理,该库独立于tensorflow,可以单独运行。截至20210826,官网上提供的包在mac和linux上支持到py3.8,而win10下仅支持到python3.7。此外,要成功打包还需要对interpreter.py中的函数进行修改。
# Excerpt from tflite_runtime/interpreter.py — the branch taken when the file
# ships inside the tensorflow package. The upstream tensorflow-dependent
# statements are shown commented out for reference; this is the code the
# patch below replaces.
if not __file__.endswith(os.path.join('tflite_runtime', 'interpreter.py')):
    # This file is part of tensorflow package.
    # from tensorflow.python.util.lazy_loader import LazyLoader
    # from tensorflow.python.util.tf_export import tf_export as _tf_export
    # Lazy load since some of the performance benchmark skylark rules
    # break dependencies. Must use double quotes to match code internal rewrite
    # rule.
    # pylint: disable=g-inconsistent-quotes
    # _interpreter_wrapper = LazyLoader(
    #     "_interpreter_wrapper", globals(),
    #     "tensorflow.lite.python.interpreter_wrapper."
    #     "tensorflow_wrap_interpreter_wrapper")
    # pylint: enable=g-inconsistent-quotes
这里调用了tensorflow函数在程序打包后会提示找不到tensorflow依赖,修改后如下:
# Patched replacement for the tensorflow-dependent branch: import the C++
# interpreter wrapper directly from tflite_runtime and provide a no-op
# stand-in for _tf_export, so the packaged app needs no tensorflow install.
from tflite_runtime import tensorflow_wrap_interpreter_wrapper as _interpreter_wrapper

def _tf_export(*x, **kwargs):
    # Decorator stub: discard all arguments and return the target unchanged.
    del x, kwargs
    return lambda x: x
本质上是将else中的部分在if下复制一份,让这两个导入的函数均来自tflite_runtime而不是tensorflow。
构建模型类如下:
class ModelProcess:
    """Wrap a tflite_runtime Interpreter for single-input/single-output inference.

    Loads a .tflite model (from an explicit path, or from the bundled default
    model when frozen with PyInstaller) and exposes per-sample and batch
    inference helpers.
    """

    def __init__(self, model_path=None) -> None:
        self.interpreter = None
        self.result = None
        self.__preProcess(model_path)

    def __preProcess(self, path=None):
        """Load the model and cache the input/output tensor indices."""
        # `is None` (identity), not `== None` (equality) — PEP 8.
        if path is None:
            # When bundled by PyInstaller the model is unpacked into a
            # temporary directory (sys._MEIPASS); resolve the default model
            # path relative to it. See `pyinstaller --add-data`.
            filename = self.resource_path(
                os.path.join("model_dir", "full_model.tflite"))
            self.interpreter = tflite.Interpreter(model_path=filename)
        else:
            self.interpreter = tflite.Interpreter(model_path=path)
        self.interpreter.allocate_tensors()
        # Interpreter tensors are addressed by integer index; assumes a
        # single-input, single-output model (details()[0]).
        self.input_idx = self.interpreter.get_input_details()[0]['index']
        self.output_idx = self.interpreter.get_output_details()[0]['index']

    def compute(self, input_data):
        """Run inference on one sample; return the squeezed output tensor."""
        self.interpreter.set_tensor(self.input_idx, input_data)
        self.interpreter.invoke()
        res = self.interpreter.get_tensor(self.output_idx)
        return np.squeeze(res)

    def getResult(self, datas):
        """Run inference over an iterable of samples.

        Results are stacked into a numpy array and stored in ``self.result``
        (returns None, matching the original in-place API).
        """
        self.result = np.array([self.compute(data) for data in datas])

    @staticmethod
    def resource_path(relative_path):
        """Resolve a resource path both in dev and inside a PyInstaller bundle."""
        if getattr(sys, 'frozen', False):  # True when running from a bundle
            base_path = sys._MEIPASS
        else:
            base_path = os.path.abspath(".")
        return os.path.join(base_path, relative_path)
