精读yolov3源码时发现了一个写法,研究了下学到了新的知识点。

PS:这个代码写的真的蛮好的,有很多小细节,学到的东西也多

utils.py源码内容

  1. # iou: torch.tensor(), size(num_of_labels(an image), )
  2. # iou_thres: 一个小数
  3. # t: torch.tensor(), size(num_of_labels(an image), 6)
  4. j = iou > iou_thres
  5. t = t[j]

第一次见到tensor1[tensor2]这种结构,做了些测试

tensor2为torch.uint8时

  1. targets = torch.tensor([ [1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14] ])
  2. j = torch.tensor([0, 1, 0, 0, 1, 0, 1],dtype=torch.uint8)
  3. t = targets[j]
  4. print(t)
  5. '''
  6. 输出:
  7. tensor([[ 3, 4],
  8. [ 9, 10],
  9. [13, 14]])
  10. '''
  11. targets = torch.tensor([ [1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14] ])
  12. j = torch.tensor([0, 1, 0, 0, 1, -1, 1],dtype=torch.uint8)
  13. t = targets[j]
  14. print(t)
  15. '''
  16. 输出:
  17. tensor([[ 3, 4],
  18. [ 9, 10],
  19. [11, 12],
  20. [13, 14]])
  21. '''

结论1
当tensor2为uint8类型时,tensor1[tensor2]的结果为tensor2不为0元素位置对应的tensor1元素

tensor2不为torch.uint8时

进一步拓展
当tensor2不为uint8类型时结果会怎么样呢?

  1. targets = torch.tensor([ [1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14] ])
  2. j = torch.tensor([0, 1, 0, 0, 1, 0, 1])
  3. t = targets[j]
  4. print(t)
  5. print(j.dtype)
  6. '''
  7. 输出:
  8. tensor([[1, 2],
  9. [3, 4],
  10. [1, 2],
  11. [1, 2],
  12. [3, 4],
  13. [1, 2],
  14. [3, 4]])
  15. torch.int64
  16. '''

结论2
此时tensor2的元素表示的是位置
torch.tensor不指定type时为int64类型