该项目的evaluate文件夹下的脚本eval_metric.py定义了测试过程中使用的评价函数。这个脚本主要涉及两个类:MApMetric和VOC07MApMetric,后者继承前者并重写了部分方法,因此MApMetric类是核心。这两个类都用来计算object detection算法中的mAP(mean average precision)。

    1. import mxnet as mx
    2. import numpy as np
    3. import os
    4. import matplotlib
    5. matplotlib.use('Agg')
    6. import matplotlib.pyplot as plt
    class MApMetric(mx.metric.EvalMetric):
        """
        Calculate mean AP for object detection task

        Detections and ground truths are buffered per class by ``update()``;
        the actual AP / mAP numbers are computed lazily in ``get()``.

        Parameters:
        ---------
        ovp_thresh : float
            overlap threshold for TP
        use_difficult : boolean
            use difficult ground-truths if applicable, otherwise just ignore
        class_names : list or tuple of str
            optional, if provided, will print out AP for each class
        pred_idx : int
            prediction index in network output list
        roc_output_path
            optional, if provided, a precision/recall plot is saved for each
            class into this directory (file name prefix ``roc_``)
        tensorboard_path
            optional; only stored on the instance, nothing in this class
            reads it
        """

        def __init__(self, ovp_thresh=0.5, use_difficult=False, class_names=None,
                     pred_idx=0, roc_output_path=None, tensorboard_path=None):
            """Store the configuration and reset the internal buffers."""
            super(MApMetric, self).__init__('mAP')
            if class_names is None:
                self.num = None
            else:
                assert isinstance(class_names, (list, tuple))
                for name in class_names:
                    assert isinstance(name, str), "must provide names as str"
                num = len(class_names)
                # list(...) so that a tuple of names (accepted by the assert
                # above) can be concatenated with the trailing 'mAP' entry.
                self.name = list(class_names) + ['mAP']
                self.num = num + 1
            self.reset()
            self.ovp_thresh = ovp_thresh
            self.use_difficult = use_difficult
            self.class_names = class_names
            self.pred_idx = int(pred_idx)
            self.roc_output_path = roc_output_path
            self.tensorboard_path = tensorboard_path

        def save_roc_graph(self, recall=None, prec=None, classkey=1, path=None, ap=None):
            """
            Plot precision against recall for one class and save the figure.

            Params:
            ----------
            recall, prec : numpy.array
                cumulated recall / precision as returned by _recall_prec()
            classkey : int
                index into self.class_names, used for the title and file name
            path : str
                output directory, created (including parents) if missing
            ap : float
                average precision shown in the legend
            """
            if not os.path.exists(path):
                # makedirs instead of mkdir so nested output paths also work
                os.makedirs(path)
            plot_path = os.path.join(path, 'roc_'+self.class_names[classkey])
            # drop a stale file at exactly this path, if there is one
            if os.path.exists(plot_path):
                os.remove(plot_path)
            fig = plt.figure()
            plt.title(self.class_names[classkey])
            plt.plot(recall, prec, 'b', label='AP = %0.2f' % ap)
            plt.legend(loc='lower right')
            plt.xlim([0, 1])
            plt.ylim([0, 1])
            plt.ylabel('Precision')
            plt.xlabel('Recall')
            plt.savefig(plot_path)
            plt.close(fig)

        def reset(self):
            """Clear the internal statistics to initial state."""
            # getattr: tolerate being called before `num` has been assigned
            if getattr(self, 'num', None) is None:
                self.num_inst = 0
                self.sum_metric = 0.0
            else:
                self.num_inst = [0] * self.num
                self.sum_metric = [0.0] * self.num
            # class id -> stacked (score, tp/fp flag) rows
            self.records = dict()
            # class id -> number of ground-truth boxes counted so far
            self.counts = dict()

        def get(self):
            """Get the current evaluation result.

            Calls _update() first, so the AP values are recomputed from the
            buffered records every time the metric is read.

            Returns
            -------
            name : str or list of str
                Name of the metric (one entry per class plus 'mAP' when
                class names were provided).
            value : float or list of float
                Value of the evaluation, NaN where nothing was recorded.
            """
            self._update()  # update metric at this time
            if self.num is None:
                if self.num_inst == 0:
                    return (self.name, float('nan'))
                else:
                    return (self.name, self.sum_metric / self.num_inst)
            else:
                # sum_metric / num_inst are lists here, so build the result
                # element by element
                names = ['%s' % (self.name[i]) for i in range(self.num)]
                values = [x / y if y != 0 else float('nan')
                          for x, y in zip(self.sum_metric, self.num_inst)]
                return (names, values)

        @staticmethod
        def _iou(x, ys):
            """
            Calculate intersection-over-union overlap

            Params:
            ----------
            x : numpy.array
                single box [xmin, ymin ,xmax, ymax]
            ys : numpy.array
                multiple box [[xmin, ymin, xmax, ymax], [...], ]
            Returns:
            -----------
            numpy.array
                [iou1, iou2, ...], size == ys.shape[0]
            """
            ixmin = np.maximum(ys[:, 0], x[0])
            iymin = np.maximum(ys[:, 1], x[1])
            ixmax = np.minimum(ys[:, 2], x[2])
            iymax = np.minimum(ys[:, 3], x[3])
            iw = np.maximum(ixmax - ixmin, 0.)
            ih = np.maximum(iymax - iymin, 0.)
            inters = iw * ih
            uni = (x[2] - x[0]) * (x[3] - x[1]) + (ys[:, 2] - ys[:, 0]) * \
                (ys[:, 3] - ys[:, 1]) - inters
            # degenerate boxes give a zero union; silence the division
            # warning because those entries are overwritten right below
            with np.errstate(divide='ignore', invalid='ignore'):
                ious = inters / uni
            ious[uni < 1e-12] = 0  # in case bad boxes
            return ious

        def _count_gts(self, gts):
            """Number of ground-truth rows that count towards recall.

            Rows whose optional 6th column (difficult flag) is >= 1 are
            excluded unless use_difficult is set.
            """
            if not self.use_difficult and gts.shape[1] >= 6:
                return int(np.sum(gts[:, 5] < 1))
            return gts.shape[0]

        def update(self, labels, preds):
            """
            Update internal records. This function now only update internal buffer,
            sum_metric and num_inst are updated in _update() function instead when
            get() is called to return results.

            Params:
            ----------
            labels: mx.nd.array (n * 6) or (n * 5), difficult column is optional
                2-d array of ground-truths, n objects(id-xmin-ymin-xmax-ymax-[difficult])
            preds: mx.nd.array (m * 6)
                2-d array of detections, m objects(id-score-xmin-ymin-xmax-ymax)
            """
            # independent execution for each image of the batch; rows with a
            # negative class id are treated as padding/background and skipped
            for i in range(labels[0].shape[0]):
                # get as numpy arrays
                label = labels[0][i].asnumpy()
                pred = preds[self.pred_idx][i].asnumpy()
                # calculate for each class: every pass takes the class id of
                # the first remaining detection, handles all detections of
                # that class and removes them from `pred`
                while pred.shape[0] > 0:
                    cid = int(pred[0, 0])
                    indices = np.where(pred[:, 0].astype(int) == cid)[0]
                    if cid < 0:
                        # background: just drop these rows
                        pred = np.delete(pred, indices, axis=0)
                        continue
                    dets = pred[indices]
                    pred = np.delete(pred, indices, axis=0)
                    # sort by score, descending, so that higher-scoring
                    # detections claim a ground truth first (the result of
                    # the indexing must be assigned back to take effect)
                    dets = dets[dets[:, 1].argsort()[::-1]]
                    # first column: score, second column: tp/fp flag
                    # (0 = not set, 1 = tp, 2 = fp)
                    records = np.hstack((dets[:, 1][:, np.newaxis],
                                         np.zeros((dets.shape[0], 1))))
                    # ground-truths of the same class in this image
                    label_indices = np.where(label[:, 0].astype(int) == cid)[0]
                    gts = label[label_indices, :]
                    label = np.delete(label, label_indices, axis=0)
                    if gts.size > 0:
                        found = [False] * gts.shape[0]
                        for j in range(dets.shape[0]):
                            # compute overlaps and keep the best match
                            ious = self._iou(dets[j, 2:], gts[:, 1:5])
                            ovargmax = np.argmax(ious)
                            ovmax = ious[ovargmax]
                            if ovmax > self.ovp_thresh:
                                if (not self.use_difficult and
                                        gts.shape[1] >= 6 and
                                        gts[ovargmax, 5] > 0):
                                    # matched an ignored difficult box:
                                    # leave the flag at 0 (neither tp nor fp)
                                    pass
                                else:
                                    if not found[ovargmax]:
                                        records[j, -1] = 1  # tp
                                        found[ovargmax] = True
                                    else:
                                        # duplicate
                                        records[j, -1] = 2  # fp
                            else:
                                # no ground truth overlaps enough
                                records[j, -1] = 2  # fp
                    else:
                        # no gt, mark all fp
                        records[:, -1] = 2
                    # ground truth count
                    gt_count = self._count_gts(gts)
                    # now we push records to buffer, keeping only rows that
                    # were flagged tp or fp
                    records = records[np.where(records[:, -1] > 0)[0], :]
                    if records.size > 0:
                        self._insert(cid, records, gt_count)
                    elif gt_count > 0:
                        # every detection matched an ignored difficult box;
                        # still account for the ground truths with a
                        # placeholder row (flag 0 is dropped in _recall_prec)
                        self._insert(cid, np.array([[0, 0]]), gt_count)
                # add missing class if not present in prediction
                while label.shape[0] > 0:
                    cid = int(label[0, 0])
                    label_indices = np.where(label[:, 0].astype(int) == cid)[0]
                    gts = label[label_indices, :]
                    label = np.delete(label, label_indices, axis=0)
                    if cid < 0:
                        continue
                    # same difficult handling as for classes with detections
                    gt_count = self._count_gts(gts)
                    if gt_count > 0:
                        self._insert(cid, np.array([[0, 0]]), gt_count)

        def _update(self):
            """ update num_inst and sum_metric

            Internal helper used by get() (not to be confused with update()):
            turns the per-class records buffered by update() into per-class AP
            values and their mean.
            """
            aps = []
            for k, v in self.records.items():
                recall, prec = self._recall_prec(v, self.counts[k])
                ap = self._average_precision(recall, prec)
                if self.roc_output_path is not None:
                    self.save_roc_graph(recall=recall, prec=prec, classkey=k,
                                        path=self.roc_output_path, ap=ap)
                aps.append(ap)
                # slot k holds the AP of class k, so each class is reported
                # individually
                if self.num is not None and k < (self.num - 1):
                    self.sum_metric[k] = ap
                    self.num_inst[k] = 1
            # NaN without a numpy warning when nothing has been recorded yet
            mean_ap = np.mean(aps) if aps else float('nan')
            if self.num is None:
                self.num_inst = 1
                self.sum_metric = mean_ap
            else:
                # the last slot holds the mean over all classes
                self.num_inst[-1] = 1
                self.sum_metric[-1] = mean_ap

        def _recall_prec(self, record, count):
            """ get recall and precision from internal records

            record is an array of (score, flag) rows and count the number of
            ground truths; returns cumulative recall and precision arrays
            ordered by descending score.
            """
            # drop placeholder rows (flag 0)
            record = np.delete(record, np.where(record[:, 1].astype(int) == 0)[0], axis=0)
            sorted_records = record[record[:, 0].argsort()[::-1]]
            tp = np.cumsum(sorted_records[:, 1].astype(int) == 1)
            fp = np.cumsum(sorted_records[:, 1].astype(int) == 2)
            if count <= 0:
                recall = tp * 0.0
            else:
                recall = tp / float(count)
            prec = tp.astype(float) / (tp + fp)
            return recall, prec

        def _average_precision(self, rec, prec):
            """
            calculate average precision

            Params:
            ----------
            rec : numpy.array
                cumulated recall
            prec : numpy.array
                cumulated precision
            Returns:
            ----------
            ap as float
            """
            # append sentinel values at both ends
            mrec = np.concatenate(([0.], rec, [1.]))
            mpre = np.concatenate(([0.], prec, [0.]))
            # compute precision integration ladder
            for i in range(mpre.size - 1, 0, -1):
                mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
            # look for recall value changes
            i = np.where(mrec[1:] != mrec[:-1])[0]
            # sum (\delta recall) * prec
            ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
            return ap

        def _insert(self, key, records, count):
            """ Insert records according to key

            Appends the rows to self.records[key] and adds count to
            self.counts[key], creating both entries on first use.
            """
            if key not in self.records:
                assert key not in self.counts
                self.records[key] = records
                self.counts[key] = count
            else:
                self.records[key] = np.vstack((self.records[key], records))
                assert key in self.counts
                self.counts[key] += count
    class VOC07MApMetric(MApMetric):
        """ Mean average precision metric for PASCAL VOC 07 dataset

        Identical to MApMetric except that the average precision is computed
        with the 11-point interpolation rule.
        """

        def __init__(self, *args, **kwargs):
            """Forward all arguments unchanged to MApMetric."""
            super(VOC07MApMetric, self).__init__(*args, **kwargs)

        def _average_precision(self, rec, prec):
            """
            calculate average precision, override the default one,
            special 11-point metric

            For each recall threshold 0.0, 0.1, ..., 1.0 the best precision
            among the points whose recall reaches the threshold is taken
            (0 when no point does) and the 11 values are averaged.

            Params:
            ----------
            rec : numpy.array
                cumulated recall
            prec : numpy.array
                cumulated precision
            Returns:
            ----------
            ap as float
            """
            total = 0.
            for threshold in np.arange(0., 1.1, 0.1):
                reached = rec >= threshold
                if np.sum(reached) != 0:
                    best = np.max(prec[reached])
                else:
                    best = 0
                total += best / 11.
            return total