一、torchvision中数据集的使用
1. torchvision基本构成
- torchvision.datasets: 一些加载数据的函数及常用的数据集接口
- torchvision.models:包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等
- torchvision.transforms:常用的图片变换,例如裁剪、旋转等
- torchvision.utils: 其他的一些有用的方法
2. CIFAR10数据集的使用
- 导包:
import torchvision 使用transform转换数据
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
设置训练集和测试集
# root参数是数据集存放位置或者下载后存放的位置# train参数表明是否是训练集,true为训练集,false为测试集# transform参数为对数据进行的变换# download参数表示是否要从网络上下载train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True)# 设置测试集test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)
tips:
- windows中查看当前路径:
chdir - 如果自己下载数据集则需将数据集放到root目录下
- 提取图像和标签
二、 DataLoader的使用
- 导包:
from torch.utils.data import DataLoader - 通过dataset构建加载器
```python
test_data = torchvision.datasets.CIFAR10(“./CIFAR10”,
train=False,transform=torchvision.transforms.ToTensor())
dataset表示数据集
batch_size=4对数据读取的作用,表示一次性读取数据集中的4张图片,集合在一起进行返回
shuffle为打乱数据
“”” num_workers为DataLoader一次性创建多少工作进程 num_workers=0表示只有主进程去加载batch数据,这个可能会是一个瓶颈。 num_workers = 1表示只有一个worker进程用来加载batch数据,而主进程是不参与数据加载的。 num_workers>0 表示只有指定数量的worker进程去加载数据,主进程不参与。增加num_works也同时会增加cpu内存的消耗。所以num_workers的值依赖于 batch size和机器性能。 一般开始是将num_workers设置为等于计算机上的CPU数量 最好的办法是缓慢增加num_workers,直到训练速度不再提高,就停止增加num_workers的值。 “””
drop_last为最后未完成的batch来说,True为丢弃
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
- Batch_size的使用:```pythonfor data in test_loader:imgs, targets = dataprint(imgs.shape)print(targets)

