image.png

解码:

1、求出热力图中每个特征点的所有类别中的最大概率;
2、求出每一列在步骤1中得出的特征点的类别最大概率的最大概率(这一列所有特征点类别最大概率的最大概率)对应的特征点索引(在这一列中的索引);
3、通过步骤2的最大概率特征点列索引,将2D-CTC的热力图转换为预测的最有可能路径的1D-CTC,后续按照1D-CTC的解码进行解码即可;

  1. # classify: (N, C, H, W)
  2. # mask: (N, 1, H, W)
  3. heatmap = classify * mask # (N, C, H, W)
  4. paths = heatmap.max(1, keepdim=True)[0].argmax(2, keepdim=True) # (N, 1, 1, W)
  5. C = classify.size(1)
  6. paths = paths.repeat(1, C, 1, 1) # (N, C, 1, W)
  7. selected_probabilities = heatmap.gather(2, paths) # (N, C, W)
  8. pred = selected_probabilities.argmax(1).squeeze(1) # (N, W)

LOSS:

编译:
python setup.py build
报错:
error: identifier “AT_CHECK” is undefined
解决方法:
将报错行中的AT_CHECK替换为TORCH_CHECK即可.