离散的对应索引 y = torch.tensor([0,2])y_hat = torch.tensor([[0.1,0.2,0.5],[0.2,0.5,0.3]])y_hat[[0,1],y] # 第一个参数 取所有行 ,第二个参数取所有列Out[10]: tensor([0.1000, 0.3000])y_hat[[0,1],[1,1]]Out[11]: tensor([0.2000, 0.5000])