
预测过程中的 detect_image 函数,其主要代码如下(原文以三张截图展示):
# Keep the original image size and an untouched copy of the input image.
# NOTE(review): image_shape is defined outside this excerpt; the indexing below
# implies it is ordered (height, width, ...) — confirm against the caller.
old_width = image_shape[1]
old_height = image_shape[0]
old_image = copy.deepcopy(image)
# Resize to the network input size and scale pixel values by 1/255
# (assumes 8-bit pixel data, i.e. values end up in [0, 1] — confirm).
width, height = get_new_img_size(old_width, old_height)
image = image.resize((width, height))  # PIL expects a (width, height) 2-tuple
photo = np.array(image, dtype=np.float32) / 255
# (H, W, C) -> (C, H, W): channel-first layout for the PyTorch model
# (assumes `image` is a PIL RGB image, which np.array turns into H x W x C).
photo = np.transpose(photo, (2, 0, 1))
# Add a leading batch dimension of size 1, then hand the batch to the model.
images = np.expand_dims(photo, 0)
images = torch.from_numpy(images)
# Only move to the GPU when one exists; an unconditional .cuda() crashes on
# CPU-only machines.
if torch.cuda.is_available():
    images = images.cuda()


该函数中调用了 DecodeBox() 函数,其相关代码如下(原文以两张截图展示):
# Clip the decoded boxes so they do not extend past the image border.
# Columns 0 and 2 are clamped to [0, width] and columns 1 and 3 to
# [0, height], so the last axis presumably holds (x1, y1, x2, y2) — confirm.
cls_bbox[..., 0] = (cls_bbox[..., 0]).clamp(min=0, max=width)
cls_bbox[..., 2] = (cls_bbox[..., 2]).clamp(min=0, max=width)
cls_bbox[..., 1] = (cls_bbox[..., 1]).clamp(min=0, max=height)
cls_bbox[..., 3] = (cls_bbox[..., 3]).clamp(min=0, max=height)
# Softmax over dim 1 gives a probability for every class, background included
# (e.g. a dataset with 4 object classes yields 5 probabilities per box).
# NOTE(review): if roi_scores is already a tensor, torch.tensor() copies it and
# emits a UserWarning; roi_scores' type is not visible in this excerpt.
prob = F.softmax(torch.tensor(roi_scores), dim=1)
raw_cls_bbox = cls_bbox.cpu().numpy()
raw_prob = prob.cpu().numpy()
# Take the boxes and probabilities that belong to class index `l`.
# NOTE(review): `l` comes from an enclosing per-class loop that is not shown in
# this excerpt.
cls_bbox_l = raw_cls_bbox[:, l, :]
prob_l = raw_prob[:, l]
# Keep only the boxes whose probability for this class exceeds score_thresh.
mask = prob_l > score_thresh
cls_bbox_l = cls_bbox_l[mask]
# NOTE(review): prob_l is not filtered by `mask` in this excerpt; if it is later
# paired with cls_bbox_l (e.g. for NMS) it also needs prob_l[mask] — confirm.


画框的过程(在图片上绘制检测框和“类别 置信度”标签):
for i, c in enumerate(label):predicted_class = self.class_names[int(c)]score = conf[i]left, top, right, bottom = bbox[i]top = top - 5left = left - 5bottom = bottom + 5right = right + 5top = max(0, np.floor(top + 0.5).astype('int32'))left = max(0, np.floor(left + 0.5).astype('int32'))bottom = min(np.shape(image)[0], np.floor(bottom + 0.5).astype('int32'))right = min(np.shape(image)[1], np.floor(right + 0.5).astype('int32'))# 画框框label = '{} {:.2f}'.format(predicted_class, score)draw = ImageDraw.Draw(image)label_size = draw.textsize(label, font)label = label.encode('utf-8')print(label)if top - label_size[1] >= 0:text_origin = np.array([left, top - label_size[1]])else:text_origin = np.array([left, top + 1])for i in range(thickness):draw.rectangle([left + i, top + i, right - i, bottom - i],outline=self.colors[int(c)])draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)],fill=self.colors[int(c)])draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)del drawprint("time:",time.time()-start_time)return image
