Training on the GPU
Method 1
- Move the network to the GPU
l = Lcy()
if torch.cuda.is_available():
    l.cuda()
- Move the loss function to the GPU
loss_function = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    loss_function.cuda()
- Move the data to the GPU
for data in dataloader:
    img, target = data
    if torch.cuda.is_available():
        img = img.cuda()
        target = target.cuda()
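The three steps of Method 1 can be put together as a minimal sketch. Everything not in the notes is an assumption made for illustration: `TinyNet` is a hypothetical stand-in for `Lcy`, the data is random, and the SGD optimizer and its learning rate are arbitrary. The script falls back to the CPU when no GPU is present.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


# Hypothetical stand-in for the Lcy network from the notes.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 3)

    def forward(self, x):
        return self.fc(x)


use_cuda = torch.cuda.is_available()

# Steps 1 and 2: move the network and the loss function.
model = TinyNet()
loss_function = nn.CrossEntropyLoss()
if use_cuda:
    model.cuda()
    loss_function.cuda()

# Random data standing in for a real dataset (assumption).
features = torch.randn(16, 8)
labels = torch.randint(0, 3, (16,))
dataloader = DataLoader(TensorDataset(features, labels), batch_size=4)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for img, target in dataloader:
    # Step 3: move every batch; tensors must be reassigned.
    if use_cuda:
        img = img.cuda()
        target = target.cuda()
    output = model(img)
    loss = loss_function(output, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

The model, the loss function, and each batch all have to end up on the same device; forgetting one of them typically raises a device-mismatch error at the first forward pass.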
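Method 2
- Choose the device (pick one of the following)
device = torch.device("cpu")     # CPU
device = torch.device("cuda")    # default GPU
device = torch.device("cuda:0")  # first GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # GPU if available, otherwise CPU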
- Move the network to the device
l = Lcy()
l = l.to(device)
- Move the loss function to the device
loss_function = nn.CrossEntropyLoss()
loss_function = loss_function.to(device)
- Move the data to the device
for data in dataloader:
    img, target = data
    img = img.to(device)
    target = target.to(device)
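One detail of Method 2 is worth a small sketch: for an `nn.Module`, `.to(device)` moves the parameters in place and returns the same object, whereas for a tensor it returns a result that must be assigned back. The `nn.Linear` layer and the random batch below are hypothetical placeholders, not the network or data from the notes.

```python
import torch
from torch import nn

# Use the GPU if available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the Lcy network.
layer = nn.Linear(4, 2)
moved_layer = layer.to(device)
# nn.Module.to returns the module itself, so reassignment is optional.
same_module = moved_layer is layer

# Tensor.to does not modify the tensor in place, so the result
# has to be assigned back to the variable.
batch = torch.randn(3, 4)
batch = batch.to(device)
target = torch.tensor([0, 1, 0]).to(device)

loss_function = nn.CrossEntropyLoss().to(device)
loss = loss_function(layer(batch), target)
```

Because the device is chosen once at the top, the same script runs unchanged on machines with or without a GPU, which is the main practical advantage over sprinkling `torch.cuda.is_available()` checks as in Method 1.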