数组部分屏蔽
# image_pred_ shape is 6300 * 85
cls_mask = image_pred_*(image_pred_[:,-1] == cls).float().unsqueeze(1)
上式代码直接实现了,对二维数组中imagepred[:,-1]不为cls的行进行屏蔽(变为0)。
class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze()
image_pred_class = image_pred_[class_mask_ind].view(-1,7)
先对索引进行筛选,然后利用索引筛选原数组。