Runtime environments (TensorFlow version: import statement)

  • tf-1.12.3: import tensorflow as tf
  • tf-1.15.2: import tensorflow as tf
  • tf-2.1.0: import tensorflow.compat.v1 as tf

Functions

  • Load a ckpt (checkpoint) model (加载ckpt模型)
  • ckpt->SavedModel
  • ckpt->Frozen Graph
  import os

  # Rebuild the inference graph, restore trained weights from a checkpoint and
  # export them as (1) a SavedModel and (2) a frozen GraphDef (.pb + .pbtxt).
  # NOTE(review): `ckpt_path`, `height`, `width` and `build_model` are defined
  # outside this snippet; `build_model` is assumed to return a sequence of
  # output tensors -- TODO confirm.
  #
  # A Graph *instance* is required: `tf.Graph.as_default()` (called on the
  # class) raises TypeError. The Session is created after the graph has been
  # made default, so it is bound to that graph.
  with tf.Graph().as_default(), tf.Session() as sess:
      input_shape = [1, height, width, 3]
      inputs = tf.placeholder(name='input', dtype=tf.float32, shape=input_shape)
      outputs = build_model(inputs, is_training=False)

      # Restore the most recent checkpoint in `ckpt_path`. To pin a specific
      # checkpoint instead, assign its prefix to `checkpoint` here.
      checkpoint = tf.train.latest_checkpoint(ckpt_path)
      if checkpoint is None:
          raise ValueError('No checkpoint found in %s' % ckpt_path)
      print('Loading checkpoint: %s' % checkpoint)
      # The Saver must be created after the model variables exist.
      saver = tf.train.Saver()
      saver.restore(sess, checkpoint)

      # Graph node names, i.e. tensor names without the ":<index>" suffix.
      output_node_names = [node.name.split(':')[0] for node in outputs]

      ## export SavedModel
      # NOTE: simple_save refuses to write into an already existing directory.
      saved_model_name = './tmp/saved_model'
      tf.saved_model.simple_save(sess, saved_model_name,
                                 inputs={'input': inputs},
                                 outputs=dict(zip(output_node_names, outputs)))

      ## export frozen graph
      pbtxt_name = './tmp/pb_model/efficientdet_d0.pbtxt'
      pb_model_name = './tmp/pb_model/efficientdet_d0.pb'
      pb_dir = os.path.dirname(pb_model_name)
      if not os.path.isdir(pb_dir):
          os.makedirs(pb_dir)
      # Human-readable dump of the (non-frozen) graph structure.
      tf.train.write_graph(sess.graph_def, '', pbtxt_name, as_text=True)
      # Replace variables with constants holding the restored values, keeping
      # only the subgraph needed to compute `output_node_names`.
      output_graph_def = tf.graph_util.convert_variables_to_constants(
          sess, sess.graph_def, output_node_names)
      with tf.io.gfile.GFile(pb_model_name, "wb") as f:
          f.write(output_graph_def.SerializeToString())

Runtime environments (TensorFlow version: import statement)

  • tf-1.12.3: import tensorflow as tf
  • tf-1.15.2: import tensorflow as tf
  • tf-2.1.0: import tensorflow.compat.v1 as tf

Functions

  • SavedModel->float32 tflite
  • SavedModel->int8 tflite
  • Frozen graph (.pb)->float32 tflite
  # Convert the exported SavedModel into a float32 TFLite model.
  # Run this under a TF version that supports every op in the model: it
  # depends on which TF version the model was built with, and newer versions
  # emit some ops that older ones do not support.
  import sys
  if sys.version_info.major >= 3:
      import pathlib
  else:
      import pathlib2 as pathlib
  # Must match the directory the SavedModel was exported to.
  saved_model_dir = './tmp/saved_model'
  tflite_model_dir = pathlib.Path('./tmp/tflite/efficentdet_d0/')
  tflite_model_dir.mkdir(exist_ok=True, parents=True)
  tflite_model_name = 'efficentdet_d0_flot32.tflite'
  tflite_model_file = tflite_model_dir/tflite_model_name
  converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  # Deliberately no `converter.optimizations`: OPTIMIZE_FOR_SIZE (an alias of
  # Optimize.DEFAULT) enables post-training weight quantization, which would
  # make this "float32" model store quantized weights.
  tflite_model = converter.convert()
  tflite_model_file.write_bytes(tflite_model)
  ## export tflite: SavedModel to tflite_int8
  # tf_1.15.2: import tensorflow as tf
  # tf_2.1: import tensorflow as tf
  # tf_2.1: import tensorflow.compat.v1 as tf
  # Run this under a TF version that supports every op in the model: it
  # depends on which TF version the model was built with, and newer versions
  # emit some ops that older ones do not support.
  import glob
  import sys

  import cv2
  import numpy as np
  if sys.version_info.major >= 3:
      import pathlib
  else:
      import pathlib2 as pathlib
  # Must match the directory the SavedModel was exported to.
  saved_model_dir = './tmp/saved_model'
  tflite_model_dir = pathlib.Path('./tmp/tflite/efficentdet_d0/')
  tflite_model_dir.mkdir(exist_ok=True, parents=True)
  tflite_model_name = 'efficentdet_d0_int8.tflite'
  tflite_model_file = tflite_model_dir/tflite_model_name
  converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  # Calibration images; sorted so the calibration order is reproducible.
  data_path = './sample_imgs/*'
  data = sorted(glob.glob(data_path))
  print('Convert using full integer quantization ...')
  print("Calibration data size : ", len(data))
  if not data:
      # Fail early: an empty representative dataset cannot calibrate anything.
      raise ValueError('No calibration images found at %s' % data_path)
  def aspect_preserving_resize(img, _network_w, _network_h):
      """Scale `img` without distortion and pad it towards the network size.

      If the image is wider than tall, its width becomes `_network_w`;
      otherwise its height becomes `_network_h`. The other side is scaled by
      the same ratio (INTER_AREA). The result is then padded on the bottom
      and right with the constant 127.5 until it reaches at least
      `_network_h` x `_network_w`. Used by representative_data_gen.
      """
      src_h, src_w, _ = img.shape
      if orig_is_landscape(src_w, src_h):
          scale = 1.0 * _network_w / src_w
          new_w, new_h = _network_w, int(scale * src_h)
      else:
          scale = 1.0 * _network_h / src_h
          new_w, new_h = int(scale * src_w), _network_h
      resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

      # Pad to dimensions >= [network_h, network_w] with the mean pixel value
      # (127.5 maps to 0 under the [-1, 1] scaling applied afterwards).
      pad_bottom = max(_network_h - new_h, 0)
      pad_right = max(_network_w - new_w, 0)
      mean_pixel = [127.5, 127.5, 127.5]
      return cv2.copyMakeBorder(resized, 0, pad_bottom, 0, pad_right,
                                cv2.BORDER_CONSTANT, value=mean_pixel)


  def orig_is_landscape(width, height):
      """Return True when the image is strictly wider than it is tall."""
      return width > height
  def representative_data_gen():
      """Yield calibration samples with pixel values scaled to [-1, 1].

      Alternative preprocessing. For each path in `data`: read with OpenCV,
      convert BGR->RGB, resize/pad with aspect_preserving_resize(..., 512,
      512), then map [0, 255] to [-1, 1] via (x - 127.5) * 2 / 255. Each
      yield is a one-element list holding a float32 array with a leading
      batch axis of size 1.
      """
      for path in data:
          rgb = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
          padded = aspect_preserving_resize(rgb, 512, 512)
          normalized = np.float32(padded - 127.5) * 2.0 / 255.0
          yield [normalized[np.newaxis, ...]]
  image_size = [512, 512]
  def representative_dataset_gen():
      """Yield calibration samples for post-training integer quantization.

      Each file in `data` is read with OpenCV and converted BGR->RGB. It is
      normalized with the per-channel mean/std below, resized (INTER_LINEAR)
      to fit inside `image_size` while keeping its aspect ratio, and
      zero-padded on the bottom/right to exactly `image_size`. Each yield is
      a one-element list holding a float32 array of shape (1, H, W, 3).
      Files that OpenCV cannot read are skipped.
      """
      # assumes image_size is [height, width]; both are 512 here, so the order
      # does not matter yet -- TODO confirm if they ever differ.
      target_h, target_w = image_size
      mean = [0.485, 0.456, 0.406]
      std = [0.229, 0.224, 0.225]
      for path in data:
          image = cv2.imread(path)
          if image is None:
              # cv2.imread returns None for unreadable / non-image files.
              print('Skipping unreadable calibration file: %s' % path)
              continue
          image_ori = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
          # Normalize before resizing: a padded value of 0 then corresponds to
          # the per-channel mean pixel.
          image = np.float32((image_ori / 255.0 - mean) / std)
          height, width = image.shape[0], image.shape[1]
          image_scale = min(target_h * 1.0 / height, target_w * 1.0 / width)
          scaled_height = int(height * image_scale)
          scaled_width = int(width * image_scale)
          # The interpolation flag must be passed by keyword: the third
          # positional parameter of cv2.resize is `dst`.
          scaled_image = cv2.resize(image, (scaled_width, scaled_height),
                                    interpolation=cv2.INTER_LINEAR)
          bottom = max(target_h - scaled_height, 0)
          right = max(target_w - scaled_width, 0)
          padding_img = cv2.copyMakeBorder(scaled_image,
                                           0, bottom, 0, right,
                                           cv2.BORDER_CONSTANT,
                                           value=(0.0, 0.0, 0.0))
          yield [np.expand_dims(padding_img, 0)]
  # Post-training quantization: Optimize.DEFAULT plus a representative
  # dataset generator (used to calibrate value ranges).
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  converter.representative_dataset = representative_dataset_gen
  # To force an integer-only model, additionally set:
  #   converter.inference_input_type = tf.uint8
  #   converter.inference_output_type = tf.uint8
  #   converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
  # Without these, the converter may keep float inputs/outputs and fall back
  # to float kernels for ops it cannot quantize.
  tflite_model = converter.convert()
  tflite_model_file.write_bytes(tflite_model)
  ## export tflite: pb model to tflite_flot32
  # tf_1.15.2: import tensorflow as tf
  # tf_2.1: import tensorflow as tf
  # tf_2.1: import tensorflow.compat.v1 as tf
  import os

  import tensorflow.compat.v1 as tf
  pb_model_name = './tmp/pb_model/efficientdet_d0.pb'
  tflite_model_name = './tmp/tflite/efficientdet_d0_flot32.tflite'
  inputs=["input"]
  # Output node names of the frozen graph (class- and box-prediction heads).
  # NOTE(review): must match the node names the freezing step produced.
  classes=["class_net/class-predict/BiasAdd", 'class_net/class-predict_1/BiasAdd',
           'class_net/class-predict_2/BiasAdd', 'class_net/class-predict_3/BiasAdd',
           'class_net/class-predict_4/BiasAdd', 'box_net/box-predict/BiasAdd',
           'box_net/box-predict_1/BiasAdd', 'box_net/box-predict_2/BiasAdd',
           'box_net/box-predict_3/BiasAdd', 'box_net/box-predict_4/BiasAdd']
  # tf.contrib.lite.TocoConverter is deprecated in favour of
  # tf.lite.TFLiteConverter, and tf.contrib does not exist under TF 2.x /
  # tensorflow.compat.v1. from_frozen_graph takes the same arguments.
  converter = tf.lite.TFLiteConverter.from_frozen_graph(pb_model_name,
                                                        inputs,
                                                        classes)
  tflite_model=converter.convert()
  tflite_dir = os.path.dirname(tflite_model_name)
  if not os.path.isdir(tflite_dir):
      os.makedirs(tflite_dir)
  # Context manager so the output file is flushed and closed deterministically.
  with open(tflite_model_name, 'wb') as f:
      f.write(tflite_model)