下面的 k 折交叉验证函数 `k_fold` 内部已经包含训练过程：每一折都会调用 `train` 函数进行训练和测试。

    # k-fold cross-validation section.
    def k_fold(k, image_dir, num_epochs, device, batch_size, optimizer, loss, net):
        """Run k-fold cross-validation, training ``net`` once per fold.

        For each fold ``i`` in ``range(k)``, ``get_k_fold_data(k, i, image_dir)``
        is called (presumably it writes fold i's train/test split into the two
        list files read below -- TODO confirm), the split is loaded through
        ``MyDataset`` / ``DataLoader``, and ``train`` is run on it.

        Args:
            k: Number of folds; must be at least 1.
            image_dir: Passed through to ``get_k_fold_data``.
            num_epochs: Passed through to ``train``.
            device: Passed through to ``train``.
            batch_size: Batch size used for both the train and test loaders.
            optimizer: Optimizer passed to ``train`` (assumed to be a
                ``torch.optim.Optimizer``); its state is restored to the state
                it had on entry before every fold.
            loss: Loss function passed to ``train``.
            net: Model passed to ``train`` (assumed to be a ``torch.nn.Module``);
                its weights are restored to the values they had on entry before
                every fold.

        Returns:
            Tuple ``(mean_loss_min, mean_train_acc_max, mean_test_acc_max)``:
            the three per-fold values returned by ``train``, each averaged over
            the k folds.

        Raises:
            ValueError: If ``k < 1`` (the averages would divide by zero).
        """
        import copy

        if k < 1:
            raise ValueError("k must be at least 1, got %r" % (k,))

        train_k = './/shallow//train_k.txt'
        test_k = './/shallow//test_k.txt'

        # Snapshot the starting weights and optimizer state so that every fold
        # trains from the same initial point. Without this, fold i+1 would
        # continue from the model trained in fold i, whose training split
        # (assuming get_k_fold_data holds out only fold i) contained fold i+1's
        # test data -- leaking test data and inflating test accuracy.
        # deepcopy is needed because state_dict() returns references to the
        # live tensors, which training would otherwise overwrite.
        init_net_state = copy.deepcopy(net.state_dict())
        init_opt_state = copy.deepcopy(optimizer.state_dict())

        fold_loss_min = []
        fold_train_acc_max = []
        fold_test_acc_max = []
        for i in range(k):
            net.load_state_dict(init_net_state)
            optimizer.load_state_dict(copy.deepcopy(init_opt_state))

            get_k_fold_data(k, i, image_dir)
            train_data = MyDataset(is_train=True, root=train_k)
            test_data = MyDataset(is_train=False, root=test_k)
            # NOTE(review): num_workers is not a parameter of this function; it
            # is read from a name defined outside it -- confirm it exists.
            train_loader = torch.utils.data.DataLoader(
                dataset=train_data, batch_size=batch_size, shuffle=True,
                num_workers=num_workers)
            # Evaluation does not need shuffling; a fixed order keeps the
            # reported test metrics deterministic.
            test_loader = torch.utils.data.DataLoader(
                dataset=test_data, batch_size=batch_size, shuffle=False,
                num_workers=num_workers)

            # train is expected to return (min loss, max train accuracy,
            # max test accuracy) for this fold, judging by how it is unpacked.
            loss_min, train_acc_max, test_acc_max = train(
                i, train_loader, test_loader, net, loss, optimizer, device,
                num_epochs)
            fold_loss_min.append(loss_min)
            fold_train_acc_max.append(train_acc_max)
            fold_test_acc_max.append(test_acc_max)

        return (sum(fold_loss_min) / k,
                sum(fold_train_acc_max) / k,
                sum(fold_test_acc_max) / k)