一、torchvision中数据集的使用

1. torchvision基本构成

  • torchvision.datasets: 一些加载数据的函数及常用的数据集接口
  • torchvision.models:包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等
  • torchvision.transforms:常用的图片变换,例如裁剪、旋转等
  • torchvision.utils: 其他的一些有用的方法

    2. CIFAR10数据集的使用

  1. 导包:import torchvision
  2. 使用transform转换数据

    1. dataset_transform = torchvision.transforms.Compose([
    2. torchvision.transforms.ToTensor()
    3. ])
  3. 设置训练集和测试集

    1. # root参数是数据集存放位置或者下载后存放的位置
    2. # train参数表明是否是训练集,true为训练集,false为测试集
    3. # transform参数为对数据进行的变换
    4. # download参数表示是否要从网络上下载
    5. train_set = torchvision.datasets.CIFAR10(root="./dataset",
    6. train=True,
    7. transform=dataset_transform,
    8. download=True)
    9. # 设置测试集
    10. test_set = torchvision.datasets.CIFAR10(root="./dataset",
    11. train=False,
    12. transform=dataset_transform,
    13. download=True)

tips:

  • windows中查看当前路径:chdir
  • 如果自己下载数据集则需将数据集放到root目录下

  1. 提取图像和标签

image.png

二、 DataLoader的使用

  1. 导包:from torch.utils.data import DataLoader
  2. 通过dataset构建加载器 ```python test_data = torchvision.datasets.CIFAR10(“./CIFAR10”,
    1. train=False,
    2. transform=torchvision.transforms.ToTensor())

dataset表示数据集

batch_size=4对数据读取的作用,表示一次性读取数据集中的4张图片,集合在一起进行返回

shuffle为打乱数据

“”” num_workers为DataLoader一次性创建多少工作进程 num_workers=0表示只有主进程去加载batch数据,这个可能会是一个瓶颈。 num_workers = 1表示只有一个worker进程用来加载batch数据,而主进程是不参与数据加载的。 num_workers>0 表示只有指定数量的worker进程去加载数据,主进程不参与。增加num_works也同时会增加cpu内存的消耗。所以num_workers的值依赖于 batch size和机器性能。 一般开始是将num_workers设置为等于计算机上的CPU数量 最好的办法是缓慢增加num_workers,直到训练速度不再提高,就停止增加num_workers的值。 “””

drop_last为最后未完成的batch来说,True为丢弃

test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

  1. - Batch_size的使用:
  2. ```python
  3. for data in test_loader:
  4. imgs, targets = data
  5. print(imgs.shape)
  6. print(targets)

image.png