数组部分屏蔽

  1. # image_pred_ shape is 6300 * 85
  2. cls_mask = image_pred_*(image_pred_[:,-1] == cls).float().unsqueeze(1)

上式代码直接实现了,对二维数组中imagepred[:,-1]不为cls的行进行屏蔽(变为0)。

  1. class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze()
  2. image_pred_class = image_pred_[class_mask_ind].view(-1,7)

先对索引进行筛选,然后利用索引筛选原数组。