离散的对应索引

    1. y = torch.tensor([0,2])
    2. y_hat = torch.tensor([[0.1,0.2,0.5],[0.2,0.5,0.3]])
    3. y_hat[[0,1],y] # 第一个参数 取所有行 ,第二个参数取所有列
    4. Out[10]: tensor([0.1000, 0.3000])
    5. y_hat[[0,1],[1,1]]
    6. Out[11]: tensor([0.2000, 0.5000])