Loading a TensorFlow Lite model

interpreter = tflite_runtime.interpreter.Interpreter(model_path)

Loads the tflite model from the given path.

Parameters

model_path: path to the .tflite model file

Return value

interpreter: an Interpreter object for the loaded tflite model
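A minimal loading sketch (the file name model.tflite is only a placeholder, not a path taken from this document):

from tflite_runtime.interpreter import Interpreter

# Placeholder file name; substitute the actual .tflite file on your device.
interpreter = Interpreter(model_path='model.tflite')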

Allocating memory

Interpreter.allocate_tensors()

Allocates memory for all of the model's tensors. This must be called after loading the model and before any tensor is set or read.

Parameters

None

Return value

None
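Continuing the sketch above, allocate_tensors() is called once right after loading; a common next step is to read the input/output tensor details (index, shape, dtype) that the later calls rely on:

interpreter.allocate_tensors()

# Each entry is a dict describing one tensor: 'index', 'shape', 'dtype', ...
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details[0]['shape'], input_details[0]['dtype'])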

Setting the model input data

Interpreter.set_tensor(input_index, input_data)

Copies the specified data into the input tensor at the given index.

Parameters

input_index: index of the corresponding input tensor
input_data: input data; its shape and dtype must match the input tensor

Return value

None
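A sketch of setting the input, reusing the interpreter and input_details from the sketches above; the zero-filled array is only a stand-in for real, preprocessed image data:

import numpy as np

input_index = input_details[0]['index']
# Dummy data with exactly the shape and dtype the input tensor expects;
# a real application would pass a preprocessed image here instead.
dummy_input = np.zeros(input_details[0]['shape'], dtype=input_details[0]['dtype'])
interpreter.set_tensor(input_index, dummy_input)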

Running inference

Interpreter.invoke()

Runs inference on the input data that was previously set.

Parameters

None

Return value

None
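invoke() takes no arguments; it simply runs the graph on whatever was set with set_tensor(). The timing code below is purely illustrative:

import time

start = time.time()
interpreter.invoke()  # runs the model on the data set via set_tensor()
print('Inference took %.1f ms' % ((time.time() - start) * 1000.0))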

Getting the model output data

data = Interpreter.get_tensor(output_index)

Gets the contents of the output tensor at the given index.

Parameters

output_index: index of the corresponding output tensor

Return value

The data held in the corresponding output tensor (a numpy array)
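A sketch that reads back every output tensor listed by get_output_details(); how many outputs exist and what they contain depends on the model (for the SSD detector in the example below they are boxes, classes, scores, and detection count):

for detail in output_details:
    data = interpreter.get_tensor(detail['index'])
    print(detail['name'], data.shape, data.dtype)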

Usage example: open the camera and run object detection

#!coding: utf-8
import os
import cv2
from tflite_runtime.interpreter import Interpreter
import numpy as np

Key_Esc = 27


class ObjectDetector:

    def __init__(self):
        self.min_conf_threshold = 0.5

    def load_model(self):
        MODEL_PATH = os.path.expanduser(
            '~') + '/Lepi_Data/ros/object_detector/MobileDet_SSD'
        GRAPH_NAME = 'model.tflite'
        LABELMAP_NAME = 'labelmap.txt'
        # Path to .tflite file, which contains the model that is used for object detection
        PATH_TO_CKPT = os.path.join(MODEL_PATH, GRAPH_NAME)
        # Path to label map file
        PATH_TO_LABELS = os.path.join(MODEL_PATH, LABELMAP_NAME)
        self.interpreter = Interpreter(model_path=PATH_TO_CKPT)
        self.interpreter.allocate_tensors()
        # Load the label map
        with open(PATH_TO_LABELS, 'r') as f:
            self.labels = [line.strip() for line in f.readlines()]
        # Have to do a weird fix for label map if using the COCO "starter model" from
        # https://www.tensorflow.org/lite/models/object_detection/overview
        # First label is '???', which has to be removed.
        if self.labels[0] == '???':
            del self.labels[0]
        # Get model details
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        self.height = self.input_details[0]['shape'][1]
        self.width = self.input_details[0]['shape'][2]
        self.floating_model = (self.input_details[0]['dtype'] == np.float32)

    def detect(self, image):
        # Run detection on a single BGR image and return the raw model outputs
        # Resize to expected shape [1xHxWx3]
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_resized = cv2.resize(image_rgb, (self.width, self.height))
        input_data = np.expand_dims(image_resized, axis=0)
        # Normalize pixel values if using a floating model (i.e. if model is non-quantized)
        if self.floating_model:
            input_mean = 127.5
            input_std = 127.5
            input_data = (np.float32(input_data) - input_mean) / input_std
        # Perform the actual detection by running the model with the image as input
        self.interpreter.set_tensor(self.input_details[0]['index'], input_data)
        self.interpreter.invoke()
        # Retrieve detection results
        boxes = self.interpreter.get_tensor(
            self.output_details[0]['index'])[0]    # Bounding box coordinates of detected objects
        classes = self.interpreter.get_tensor(
            self.output_details[1]['index'])[0]    # Class index of detected objects
        scores = self.interpreter.get_tensor(
            self.output_details[2]['index'])[0]    # Confidence of detected objects
        # num = interpreter.get_tensor(output_details[3]['index'])[0]  # Total number of detected objects (inaccurate and not needed)
        return boxes, classes, scores

    def set_threshold(self, threshold):
        if threshold < 100:
            self.min_conf_threshold = threshold / 100.0

    def draw_labels(self, image, boxes, classes, scores):
        imH, imW, _ = image.shape
        # Loop over all detections and draw detection box if confidence is above minimum threshold
        for i in range(len(scores)):
            if (scores[i] > self.min_conf_threshold) and (scores[i] <= 1.0):
                # Get bounding box coordinates and draw box
                # Interpreter can return coordinates that are outside of image dimensions,
                # need to force them to be within image using max() and min()
                ymin, xmin, ymax, xmax = self.getRealBox(boxes[i], (imW, imH))
                cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (10, 255, 0), 4)
                # Draw label
                # Look up object name from "labels" array using class index
                object_name = self.labels[int(classes[i])]
                label = '%s: %d%%' % (object_name, int(scores[i] * 100))  # Example: 'person: 72%'
                image = cv2.putText(image, label, (xmin + 10, ymin + 25),
                                    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                                    fontScale=2, color=(0, 0, 255), thickness=4)  # Draw label text
        return image

    def getRealBox(self, box, size=(480, 360)):
        # Convert normalized [ymin, xmin, ymax, xmax] coordinates to pixel coordinates
        imW, imH = size
        ymin = int(max(1, (box[0] * imH)))
        xmin = int(max(1, (box[1] * imW)))
        ymax = int(min(imH, (box[2] * imH)))
        xmax = int(min(imW, (box[3] * imW)))
        return [ymin, xmin, ymax, xmax]


if __name__ == '__main__':
    detector = ObjectDetector()
    detector.load_model()
    cap = cv2.VideoCapture(0)
    while True:
        ret, image = cap.read()
        if not ret:  # stop if the camera did not return a frame
            break
        boxes, classes, scores = detector.detect(image)
        image = detector.draw_labels(image, boxes, classes, scores)
        cv2.imshow('Object detector', np.rot90(cv2.resize(image, (320, 240))))
        # Press Esc to quit
        if cv2.waitKey(1) == Key_Esc:
            break
    # Clean up
    cap.release()
    cv2.destroyAllWindows()