Method 1 (I haven't used this one):
Method 2 (this is the one I used, and it worked):

Code example:

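    import torch
    import torchvision
    from torchvision import transforms

    # Split the dataset 80/20 into training and test subsets
    data_transform = transforms.ToTensor()  # assumed transform; replace with your own pipeline
    full_data = torchvision.datasets.ImageFolder('path', transform=data_transform)
    train_size = int(0.8 * len(full_data))   # 80% of the samples go to training
    test_size = len(full_data) - train_size  # the remainder goes to testing
    train_dataset, test_dataset = torch.utils.data.random_split(full_data, [train_size, test_size])
    batch_size = 32  # assumed value; the original left batch_size blank
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)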
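To make the splitting step concrete without needing an image folder, here is a pure-Python sketch of the idea behind a random split: compute the sizes the same way as above (truncate the 80% share, give the rest to the test set), shuffle all indices once, then cut the shuffled list into consecutive chunks. This is an illustration written for this note, not PyTorch's own code; the helper names `split_sizes` and `random_split_indices` and the dataset size of 103 are hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def split_sizes(n: int, train_frac: float = 0.8) -> Tuple[int, int]:
    """Same arithmetic as the example: int() truncates the training share,
    and the test set takes whatever is left, so the two always sum to n."""
    train_size = int(train_frac * n)
    return train_size, n - train_size


def random_split_indices(n: int, lengths: Sequence[int], seed: int = 0) -> List[List[int]]:
    """Illustrative random split: shuffle all indices once, then slice the
    shuffled list into consecutive chunks of the requested lengths."""
    if sum(lengths) != n:
        raise ValueError("lengths must sum to the dataset size")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    chunks, start = [], 0
    for length in lengths:
        chunks.append(indices[start:start + length])
        start += length
    return chunks


if __name__ == "__main__":
    n = 103  # hypothetical dataset size
    train_size, test_size = split_sizes(n)
    train_idx, test_idx = random_split_indices(n, [train_size, test_size], seed=42)
    print(train_size, test_size)                # 82 21
    print(len(set(train_idx) & set(test_idx)))  # 0, so no sample lands in both subsets
```

Computing `test_size` as the remainder, rather than as `int(0.2 * n)`, guarantees the lengths add up to the dataset size, which `random_split` requires. If you need the same split on every run, `torch.utils.data.random_split` also accepts a `generator` argument, e.g. `generator=torch.Generator().manual_seed(42)`.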