利用GPU训练
方法一：调用 `.cuda()`
- 将网络模型（实例）移到GPU
# Method 1: build the model, then move its parameters/buffers to the GPU
# only when CUDA is available. nn.Module.cuda() works in place, so the
# return value does not need to be reassigned here.
l = Lcy()
if torch.cuda.is_available():
    l.cuda()
- 将损失函数移到GPU
# Method 1: create the loss module and move it to the GPU only when CUDA is
# available (nn.Module.cuda() moves the module in place).
loss_function = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    loss_function.cuda()
- 将每批数据（输入与标签）移到GPU
# Method 1: move every batch to the GPU. Unlike nn.Module.cuda(),
# Tensor.cuda() returns a copy instead of modifying the tensor in place,
# so the results must be reassigned to img / target.
for data in dataloader:
    img, target = data
    if torch.cuda.is_available():
        img = img.cuda()
        target = target.cuda()
方法二：调用 `.to(device)`（更通用，可在CPU与GPU之间切换）
- 设置设备（常用最后一种写法：有GPU时用GPU，否则回退到CPU）
# Method 2: pick the target device once. The constructor lives in the torch
# namespace (a bare `device(...)` is undefined, and reusing the name after
# rebinding it would try to call a device object). Alternatives:
#   torch.device("cpu")     # always use the CPU
#   torch.device("cuda")    # the current CUDA device
#   torch.device("cuda:0")  # explicitly the first CUDA device
# The form below falls back to the CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- 将网络模型（实例）移到指定设备
# Method 2: instantiate the model and place it on the selected device in a
# single expression (nn.Module.to returns the module itself).
l = Lcy().to(device)
- 将损失函数移到指定设备
# Method 2: create the loss module directly on the selected device
# (nn.Module.to returns the module itself).
loss_function = nn.CrossEntropyLoss().to(device)
- 将每批数据（输入与标签）移到指定设备
# Method 2: move every batch to `device`. Tensor.to() returns a tensor on the
# target device rather than moving the original in place, so the results
# must be reassigned to img / target.
for data in dataloader:
    img, target = data
    img = img.to(device)
    target = target.to(device)