解码:
1、求出热力图中每个特征点的所有类别中的最大概率;
2、求出每一列在步骤1中得出的特征点的类别最大概率的最大概率(这一列所有特征点类别最大概率的最大概率)对应的特征点索引(在这一列中的索引);
3、通过步骤2的最大概率特征点列索引,将2D-CTC的热力图转换为预测的最有可能路径的1D-CTC,后续按照1D-CTC的解码进行解码即可;
# classify: (N, C, H, W)
# mask: (N, 1, H, W)
heatmap = classify * mask # (N, C, H, W)
paths = heatmap.max(1, keepdim=True)[0].argmax(2, keepdim=True) # (N, 1, 1, W)
C = classify.size(1)
paths = paths.repeat(1, C, 1, 1) # (N, C, 1, W)
selected_probabilities = heatmap.gather(2, paths) # (N, C, W)
pred = selected_probabilities.argmax(1).squeeze(1) # (N, W)
LOSS:
编译:
python setup.py build
报错:
error: identifier “AT_CHECK” is undefined
解决方法:
将报错行中的AT_CHECK替换为TORCH_CHECK即可.