TensorFlow Lite 模型加载
interpreter = tflite_runtime.interpreter.Interpreter(model_path)
参数说明
model_path:tflite模型文件的路径
返回值
interpreter:表示对应tflite模型的解释器
分配内存
Interpreter.allocate_tensors()
参数说明
无
返回值
无
设置模型输入数据
Interpreter.set_tensor(input_index, input_data)
参数说明
input_index:对应输入层的索引
input_data:输入数据,其形状(shape)和数据类型(dtype)需要与输入层相匹配
返回值
无
模型推理
Interpreter.invoke()
参数说明
无
返回值
无
获取模型输出数据
data = Interpreter.get_tensor(output_index)
参数说明
output_index:对应输出层的索引
返回值
对应输出层的数据
使用示例:打开摄像头,执行目标检测
# coding: utf-8
import os

import cv2
from tflite_runtime.interpreter import Interpreter
import numpy as np

# Key code compared against cv2.waitKey() to detect the Esc key.
Key_Esc = 27


class ObjectDetector:
    """Run a TensorFlow Lite object-detection model on OpenCV (BGR) images.

    Call :meth:`load_model` once before calling :meth:`detect`.
    """

    def __init__(self, min_conf_threshold=0.5):
        """Create a detector.

        Args:
            min_conf_threshold: detections whose score is not strictly greater
                than this value are not drawn by :meth:`draw_labels`.
        """
        self.min_conf_threshold = min_conf_threshold

    def load_model(self, model_dir=None, graph_name='model.tflite',
                   labelmap_name='labelmap.txt'):
        """Load the .tflite model and label map and cache the model's I/O details.

        Args:
            model_dir: directory holding the model and label map. Defaults to
                ``~/Lepi_Data/ros/object_detector/MobileDet_SSD``.
            graph_name: file name of the .tflite model inside ``model_dir``.
            labelmap_name: file name of the label map (one label per line)
                inside ``model_dir``.
        """
        if model_dir is None:
            model_dir = os.path.join(os.path.expanduser('~'), 'Lepi_Data',
                                     'ros', 'object_detector', 'MobileDet_SSD')
        # Path to the .tflite file containing the detection model.
        path_to_ckpt = os.path.join(model_dir, graph_name)
        # Path to the label map file.
        path_to_labels = os.path.join(model_dir, labelmap_name)

        self.interpreter = Interpreter(model_path=path_to_ckpt)
        self.interpreter.allocate_tensors()

        with open(path_to_labels, 'r', encoding='utf-8') as f:
            self.labels = [line.strip() for line in f]
        # The COCO "starter model" label map from
        # https://www.tensorflow.org/lite/models/object_detection/overview
        # starts with a '???' placeholder that must be removed so class
        # indices line up. Guard against an empty label file.
        if self.labels and self.labels[0] == '???':
            del self.labels[0]

        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        # detect() feeds a [1 x H x W x 3] tensor, so shape[1] is the height
        # and shape[2] is the width.
        self.height = self.input_details[0]['shape'][1]
        self.width = self.input_details[0]['shape'][2]
        # Float models need normalized input; quantized models take raw pixels.
        self.floating_model = (self.input_details[0]['dtype'] == np.float32)

    def detect(self, image):
        """Run the model on one BGR image.

        Args:
            image: BGR image, e.g. a frame from ``cv2.VideoCapture.read``.

        Returns:
            ``(boxes, classes, scores)``: the first batch element of output
            tensors 0, 1 and 2 respectively.
        """
        # Convert to RGB and resize to the model's expected [1 x H x W x 3].
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_resized = cv2.resize(image_rgb, (self.width, self.height))
        input_data = np.expand_dims(image_resized, axis=0)

        # Normalize pixel values to [-1, 1] for non-quantized (float) models.
        if self.floating_model:
            input_mean = 127.5
            input_std = 127.5
            input_data = (np.float32(input_data) - input_mean) / input_std

        self.interpreter.set_tensor(self.input_details[0]['index'], input_data)
        self.interpreter.invoke()

        # NOTE(review): assumes the outputs are ordered boxes, classes, scores
        # (SSD post-processing layout); confirm for any replacement model.
        boxes = self.interpreter.get_tensor(self.output_details[0]['index'])[0]
        classes = self.interpreter.get_tensor(self.output_details[1]['index'])[0]
        scores = self.interpreter.get_tensor(self.output_details[2]['index'])[0]
        # Output 3 (detection count) is intentionally not read.
        return boxes, classes, scores

    def set_threshold(self, threshold):
        """Set the minimum confidence from a percentage.

        Args:
            threshold: confidence in percent (e.g. 50 -> 0.5). Values of 100 or
                more are ignored and leave the current threshold unchanged.
        """
        if threshold < 100:
            self.min_conf_threshold = threshold / 100.0

    def _label_for(self, class_id):
        """Return the label for a class index, or the index as text if unknown."""
        index = int(class_id)
        if 0 <= index < len(self.labels):
            return self.labels[index]
        return str(index)

    def draw_labels(self, image, boxes, classes, scores):
        """Draw a box and "name: NN%" label for each confident detection.

        Args:
            image: BGR image to draw on (modified in place by OpenCV).
            boxes, classes, scores: parallel sequences as returned by
                :meth:`detect`.

        Returns:
            The annotated image.
        """
        imH, imW, _ = image.shape
        for box, class_id, score in zip(boxes, classes, scores):
            # Only draw scores above the threshold; values above 1.0 are
            # treated as invalid and skipped.
            if not (self.min_conf_threshold < score <= 1.0):
                continue
            ymin, xmin, ymax, xmax = self.getRealBox(box, (imW, imH))
            cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (10, 255, 0), 4)
            # Example label: 'person: 72%'
            label = '%s: %d%%' % (self._label_for(class_id), int(score * 100))
            image = cv2.putText(image, label, (xmin + 10, ymin + 25),
                                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                                fontScale=2, color=(0, 0, 255), thickness=4)
        return image

    def getRealBox(self, box, size=(480, 360)):
        """Convert a normalized box to clamped pixel coordinates.

        Args:
            box: ``[ymin, xmin, ymax, xmax]``; the y values are scaled by the
                image height and the x values by the image width.
            size: ``(width, height)`` of the target image.

        Returns:
            ``[ymin, xmin, ymax, xmax]`` as ints.
        """
        imW, imH = size
        # The model can return coordinates outside the image. The minimums are
        # clamped to 1 and the maximums to the image size.
        ymin = int(max(1, (box[0] * imH)))
        xmin = int(max(1, (box[1] * imW)))
        ymax = int(min(imH, (box[2] * imH)))
        xmax = int(min(imW, (box[3] * imW)))
        return [ymin, xmin, ymax, xmax]


if __name__ == '__main__':
    detector = ObjectDetector()
    detector.load_model()
    cap = cv2.VideoCapture(0)
    try:
        while True:
            ret, image = cap.read()
            if not ret:
                # No frame (camera missing or disconnected): stop instead of
                # passing None into cv2.cvtColor.
                break
            boxes, classes, scores = detector.detect(image)
            image = detector.draw_labels(image, boxes, classes, scores)
            cv2.imshow('Object detector',
                       np.rot90(cv2.resize(image, (320, 240))))
            # Press Esc to quit. Mask to the low byte because some backends
            # set extra bits in waitKey's return value.
            if (cv2.waitKey(1) & 0xFF) == Key_Esc:
                break
    finally:
        # Release the camera and close windows even if detection raised.
        cap.release()
        cv2.destroyAllWindows()
