代码示例：按 8:2 的比例将 ImageFolder 数据集随机切分为训练集和测试集，并分别构建 DataLoader。
# Split an ImageFolder dataset into train/test subsets and wrap each in a DataLoader.

# Placeholder settings -- adjust for your data and hardware.
DATA_ROOT = 'path'   # placeholder: replace with the dataset root directory
BATCH_SIZE = 32      # the original left batch_size empty (a SyntaxError); 32 is a placeholder
TRAIN_RATIO = 0.8    # fraction of samples assigned to the training split
SPLIT_SEED = 42      # fixed seed so the train/test partition is identical on every run

# NOTE(review): `transforms` must be a callable transform object (e.g. the result of
# `torchvision.transforms.Compose([...])`), not the `torchvision.transforms` module
# itself -- ImageFolder applies it to each loaded image. Confirm where it is defined.
# NOTE(review): the test split shares this transform with the training split; if it
# contains random augmentation, evaluation will be noisy -- consider a separate
# eval-time transform.
full_data = torchvision.datasets.ImageFolder(DATA_ROOT, transform=transforms)

train_size = int(TRAIN_RATIO * len(full_data))
# Give the remainder to the test split so the two lengths always sum to
# len(full_data), which random_split requires.
test_size = len(full_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    full_data,
    [train_size, test_size],
    # Without an explicit generator the split depends on the global RNG state, so
    # separate runs (e.g. train now, evaluate a checkpoint later) could see
    # different partitions and leak test images into training.
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Shuffle only the training data; keep evaluation order fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0
)