在学习 PyTorch 时，编写并运行如下代码会报错：

    1. from __future__ import print_function, division
    2. import torch
    3. import torch.nn as nn
    4. import torch.optim as optim
    5. from torch.optim import lr_scheduler
    6. import numpy as np
    7. import torchvision
    8. from torchvision import datasets, models, transforms
    9. import matplotlib.pyplot as plt
    10. import time
    11. import os
    12. import copy
    13. plt.ion() # interactive mode
    14. """加载数据"""
    15. # 对训练集和验证集数据进行裁切和归一化
    16. data_transforms = {
    17. 'train': transforms.Compose([
    18. transforms.Resize((224, 224)),
    19. transforms.ToTensor(),
    20. transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    21. ]),
    22. 'val': transforms.Compose([
    23. transforms.Resize((224, 224)),
    24. transforms.ToTensor(),
    25. transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    26. ]),
    27. }
    28. data_dir = 'D:\\AAAAASunspots'
    29. image_datasets = {x: datasets.ImageFolder(os.path.join(os.path.join(data_dir, x), 'continuum'),
    30. data_transforms[x])
    31. for x in ['train', 'val']}
    32. dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32,
    33. shuffle=True, num_workers=4)
    34. for x in ['train', 'val']}
    35. dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    36. class_names = image_datasets['train'].classes
    37. device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    38. """可视化部分图像数据"""
    39. def imshow(inp, title=None):
    40. """Imshow for Tensor."""
    41. inp = inp.numpy().transpose((1, 2, 0))
    42. mean = np.array([0.5, 0.5, 0.5])
    43. std = np.array([0.5, 0.5, 0.5])
    44. inp = std * inp + mean
    45. inp = np.clip(inp, 0, 1) # 将元素大小限定在0和1之间
    46. plt.imshow(inp)
    47. if title is not None:
    48. plt.title(title)
    49. plt.pause(0.001) # pause a bit so that plots are updated
    50. # 获取一批训练数据
    51. inputs, classes = next(iter(dataloaders['train']))
    52. # 批量制作网格
    53. out = torchvision.utils.make_grid(inputs)
    54. imshow(out, title=[class_names[x] for x in classes])

    image.png
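    错误分析：DataLoader 设置 num_workers=4 时会启动多个子进程来加载数据；在 Windows 上，子进程启动时会重新导入主模块，因此创建子进程的代码必须放在 `if __name__ == '__main__':` 保护之下才能运行。
    因此，可以将上述代码放到 main 函数中，并在 `if __name__ == '__main__':` 下调用；或者将 num_workers 改为 0，使用单进程加载数据。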
    参考：RuntimeError: An attempt has been made to start a new process before the current process has finished its bootstrapping phase.