TFLite Model Conversion and Usage

1. TFLite model conversion

Keras model conversion

```python
import tensorflow as tf

# Note: from_keras_model_file is the TF 1.x API; under TF 2.x it lives at
# tf.compat.v1.lite.TFLiteConverter.from_keras_model_file.
converter = tf.lite.TFLiteConverter.from_keras_model_file("keras_model.h5")
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
```
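Under TF 2.x the converter takes a Keras model object rather than an `.h5` path. A minimal sketch, assuming TensorFlow 2 is installed; the tiny model below is a hypothetical stand-in:

```python
import tensorflow as tf

# Hypothetical tiny model built in memory; from_keras_model (TF 2.x)
# takes the model object directly instead of a file path.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(2),
])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()  # returns the flatbuffer as bytes
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)
```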

SavedModel

```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
```
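To make the SavedModel path concrete, here is a minimal end-to-end sketch: export a trivial `tf.Module` as a SavedModel and convert it, optionally enabling post-training quantization. The module, directory name, and shapes are all illustrative assumptions:

```python
import tensorflow as tf

# Hypothetical module standing in for a real model.
class AddOne(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec([1, 3], tf.float32)])
    def __call__(self, x):
        return x + 1.0

tf.saved_model.save(AddOne(), "saved_add_one")

converter = tf.lite.TFLiteConverter.from_saved_model("saved_add_one")
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # optional post-training quantization
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
```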

PB: the meta graph and checkpoint must first be exported to a frozen `.pb` file with freeze_graph, and only then converted:

```python
import tensorflow as tf

graph_def_file = "/path/to/Downloads/mobilenet_v1_1.0_224/frozen_graph.pb"
input_arrays = ["input"]
output_arrays = ["MobilenetV1/Predictions/Softmax"]
# Note: from_frozen_graph is the TF 1.x API; under TF 2.x it lives at
# tf.compat.v1.lite.TFLiteConverter.from_frozen_graph.
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
```
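The freeze step itself can be sketched with `convert_variables_to_constants`, which folds the checkpoint's variables into constants so the whole graph fits in a single GraphDef. All names below ("input", "output", the trivial graph) are illustrative; under TF 2.x this API lives in the `tf.compat.v1` namespace:

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

# Illustrative freeze step: build a trivial TF1-style graph with one
# variable, then convert the variable into a constant node.
graph = tf1.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, [1, 3], name="input")
    w = tf1.get_variable("w", initializer=tf1.ones([3, 2]))
    y = tf1.identity(tf1.matmul(x, w), name="output")
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        frozen_graph_def = tf1.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), ["output"])

# The serialized GraphDef is what from_frozen_graph expects as a .pb file.
with open("frozen_graph.pb", "wb") as f:
    f.write(frozen_graph_def.SerializeToString())
```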

2. TFLite lightweight model inference

Import the package:

```python
import tflite_runtime.interpreter as tflite
```

Inference here uses the tflite_runtime package, which is independent of TensorFlow and can run standalone. As of 2021-08-26, the official wheels support up to Python 3.8 on macOS and Linux, but only up to Python 3.7 on Windows 10. In addition, packaging the program successfully requires modifying a function in interpreter.py.

```python
if not __file__.endswith(os.path.join('tflite_runtime', 'interpreter.py')):
    # This file is part of tensorflow package.
    # from tensorflow.python.util.lazy_loader import LazyLoader
    # from tensorflow.python.util.tf_export import tf_export as _tf_export
    # Lazy load since some of the performance benchmark skylark rules
    # break dependencies. Must use double quotes to match code internal rewrite
    # rule.
    # pylint: disable=g-inconsistent-quotes
    # _interpreter_wrapper = LazyLoader(
    #     "_interpreter_wrapper", globals(),
    #     "tensorflow.lite.python.interpreter_wrapper."
    #     "tensorflow_wrap_interpreter_wrapper")
    # pylint: enable=g-inconsistent-quotes
```

These lines call into TensorFlow, so after the program is packaged it complains that the tensorflow dependency cannot be found. After the modification, the branch reads:

```python
    from tflite_runtime import tensorflow_wrap_interpreter_wrapper as _interpreter_wrapper

    def _tf_export(*x, **kwargs):
        del x, kwargs
        return lambda x: x
```

In essence, this copies the body of the else branch into the if branch, so that both imported names come from tflite_runtime rather than tensorflow.

The model wrapper class is built as follows:

```python
import os
import sys

import numpy as np
import tflite_runtime.interpreter as tflite


class ModelProcess:
    def __init__(self, model_path=None) -> None:
        self.interpreter = None
        self.result = None
        self.__preProcess(model_path)

    def __preProcess(self, path=None):
        if path is None:
            # Resolve the model path so that, after the program is packaged,
            # the model can be loaded from the system temporary directory;
            # see pyinstaller --add-data for details.
            filename = self.resource_path(os.path.join("model_dir", "full_model.tflite"))
            self.interpreter = tflite.Interpreter(model_path=filename)
        else:
            self.interpreter = tflite.Interpreter(model_path=path)
        self.interpreter.allocate_tensors()
        # The interpreter's inputs and outputs are addressed by tensor index.
        self.input_idx = self.interpreter.get_input_details()[0]['index']
        self.output_idx = self.interpreter.get_output_details()[0]['index']

    def compute(self, input_data):
        """Run the model on a single sample."""
        self.interpreter.set_tensor(self.input_idx, input_data)
        self.interpreter.invoke()
        res = self.interpreter.get_tensor(self.output_idx)
        res = np.squeeze(res)
        return res

    def getResult(self, datas):
        """Collect the results for all samples."""
        self.result = np.array([self.compute(data) for data in datas])

    @staticmethod
    def resource_path(relative_path):
        if getattr(sys, 'frozen', False):  # running as a bundled executable?
            base_path = sys._MEIPASS
        else:
            base_path = os.path.abspath(".")
        return os.path.join(base_path, relative_path)
```
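Note that `compute()` hands `input_data` straight to `set_tensor`, so each sample must already match the shape and dtype reported by `get_input_details()` (typically a leading batch dimension and float32). A sketch of the preparation step, with a hypothetical `(2, 3)` sample:

```python
import numpy as np

# Hypothetical raw sample; TFLite models usually expect a batch dimension
# and a dtype matching get_input_details()[0]['shape'] and ['dtype'].
sample = np.arange(6.0).reshape(2, 3)
input_data = np.expand_dims(sample, 0).astype(np.float32)
# input_data is now shaped (1, 2, 3), ready for interpreter.set_tensor(...)
```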