# Training function
def train(i, train_iter, test_iter, net, loss, optimizer, device, num_epochs):
    """Train ``net`` for one fold and record the epoch with the best test accuracy.

    For every epoch the model is optimised over all batches of ``train_iter``
    and then scored on ``test_iter`` with ``d2l.evaluate_accuracy``. After the
    last epoch, the first epoch that reached the highest test accuracy is
    selected; its metrics are appended to ``./shallow/results.txt``, printed,
    and returned.

    Args:
        i: Zero-based fold index. It is reported as ``i + 1``; when it is 0 a
            header line is written to the results file first.
        train_iter: Iterable yielding ``(X, y)`` batches; both are moved to
            ``device`` with ``.to(device)``.
        test_iter: Evaluation data, passed unchanged to
            ``d2l.evaluate_accuracy``.
        net: Model exposing ``.to(device)`` and callable on a batch ``X``.
        loss: Callable ``loss(y_hat, y)`` returning a value that supports
            ``.backward()`` and ``.item()``.
        optimizer: Object exposing ``zero_grad()`` and ``step()``.
        device: Target device handed to ``.to(...)``.
        num_epochs: Number of passes over ``train_iter``; must be at least 1.

    Returns:
        A tuple ``(train_loss, train_acc, test_acc)`` taken from the selected
        epoch: mean per-batch training loss, training accuracy over all seen
        samples, and the value returned by ``d2l.evaluate_accuracy``.

    Raises:
        ValueError: If ``num_epochs`` is below 1 or ``train_iter`` yields no
            batches, since no averages could be computed.
    """
    # Local import so the block does not rely on a file-level ``import os``.
    import os

    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1, got %r" % (num_epochs,))

    net = net.to(device)
    print("training on ", device)
    start = time.time()

    # Per-epoch histories; the best epoch is picked from these afterwards.
    train_loss_hist = []
    train_acc_hist = []
    test_acc_hist = []

    for epoch in range(num_epochs):
        batch_count = 0
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            batch_count += 1

        if batch_count == 0 or n == 0:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise ValueError("train_iter yielded no training samples")

        # One epoch of optimisation is complete; score it on the test data.
        # NOTE(review): presumably evaluate_accuracy returns a float accuracy
        # and handles eval/train mode switching itself - confirm against the
        # d2l version in use.
        test_acc = d2l.evaluate_accuracy(test_iter, net)
        epoch_loss = train_l_sum / batch_count
        epoch_acc = train_acc_sum / n
        train_loss_hist.append(epoch_loss)
        train_acc_hist.append(epoch_acc)
        test_acc_hist.append(test_acc)
        print('fold %d epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (i + 1, epoch + 1, epoch_loss, epoch_acc, test_acc))

    # First epoch that achieved the highest test accuracy.
    best = test_acc_hist.index(max(test_acc_hist))
    best_loss = train_loss_hist[best]
    best_train_acc = train_acc_hist[best]
    best_test_acc = test_acc_hist[best]

    results_path = "./shallow/results.txt"
    # Create the output directory up front so a finished run is never lost
    # because the directory was missing.
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    with open(results_path, "a") as f:
        if i == 0:
            # Header row; the original emitted a stray, unformatted "%d" here.
            f.write("fold train_loss train_acc test_acc")
        f.write("\nfold%d:%s ;%s ;%s"
                % (i + 1, best_loss, best_train_acc, best_test_acc))

    print('fold %d, train_loss_min %.4f, train acc max%.4f, test acc max %.4f, time %.1f sec'
          % (i + 1, best_loss, best_train_acc, best_test_acc, time.time() - start))
    return best_loss, best_train_acc, best_test_acc