参考来源:
简书:torch.max() 使用讲解

在分类问题中,通常需要使用 **max()** 函数对 **softmax** 函数的输出值进行操作,求出预测值索引,然后与标签进行比对,计算准确率。下面讲解一下 **torch.max()** 函数的输入及输出值都是什么,便于我们理解该函数。

1. torch.max(input, dim) 函数

output = torch.max(input, dim)

输入

  • **input****softmax** 函数输出的一个 **tensor**
  • **dim****max** 函数索引的维度 **0/1****0** 是每列的最大值,**1** 是每行的最大值。

输出

  • 函数会返回两个** tensor**,第一个 **tensor** 是每 行/列 的最大值;第二个 **tensor** 是每 行/列 最大值的索引。

在多分类任务中我们并不需要知道各类别的预测概率,所以返回值的第一个 **tensor** 对分类任务没有帮助,而第二个 **tensor** 包含了预测最大概率的索引,所以在实际使用中我们仅获取第二个 **tensor** 即可。

下面通过一个实例可以更容易理解这个函数的用法。

  1. import torch
  2. a = torch.tensor([[1,5,62,54], [2,6,2,6], [2,65,2,6]])
  3. print(a)

输出:

  1. tensor([[ 1, 5, 62, 54],
  2. [ 2, 6, 2, 6],
  3. [ 2, 65, 2, 6]])

索引每行的最大值:

  1. torch.max(a, 1)

输出:

  1. torch.return_types.max(
  2. values=tensor([62, 6, 65]),
  3. indices=tensor([2, 3, 1]))

在计算准确率时第一个 **tensor values** 是不需要的,所以我们只需提取第二个 **tensor** ,并将 **tensor** 格式的数据转换成 **array** 格式。

  1. torch.max(a, 1)[1].numpy()

输出:

  1. array([2, 3, 1], dtype=int64)

这样,我们就可以与标签值进行比对,计算模型预测准确率。
注:在有的地方我们会看到 **torch.max(a, 1).data.numpy()** 的写法,这是因为在早期的 pytorch 的版本中, variable 变量和 tenosr 是不一样的数据格式,variable 可以进行反向传播,tensor 不可以,需要将 variable 转变成 tensor 再转变成 numpy 。现在的版本已经将 variabletenosr 合并,所以只用 **torch.max(a,1).numpy()** 就可以了。

2. 准确率的计算

  1. pred_y = torch.max(predict, 1)[1].numpy()
  2. label_y = torch.max(label, 1)[1].data.numpy()
  3. accuracy = (pred_y == label_y).sum() / len(label_y)

**predict**softmax 函数输出
**label**:样本标签,这里假设它是 one-hot 编码