Pytorch

1、安装

  • CPU
  1. conda create --name pyhorch python=3.8
  2. activate pyhorch
  3. pip install torch torchvision torchaudio
  4. pip install jupyter ## win 跳出选择是否的框,必须选择‘否’
  5. pip install matplotlib
  6. # 安装主题
  7. pip install jupyterthemes
  8. # 查看主题
  9. jt -l
  10. # 主题切换
  11. jt -t chesterish
  12. # 启动
  13. jupyter notebook --ip=127.0.0.1 --port=8000
  • GPU

https://blog.csdn.net/Dahat/article/details/117366876
  1. # 返回当前设备索引
  2. torch.cuda.current_device()
  3. # 返回GPU的数量
  4. torch.cuda.device_count()
  5. # 返回gpu名字,设备索引默认从0开始
  6. torch.cuda.get_device_name(0)
  7. # cuda是否可用
  8. torch.cuda.is_available()

Test
  1. import torch
  2. from torch.utils.data import Dataset
  3. from torchvision import datasets
  4. from torchvision.transforms import ToTensor
  5. import matplotlib.pyplot as plt
  6. training_data = datasets.FashionMNIST(
  7. root="data",
  8. train=True,
  9. download=True,
  10. transform=ToTensor()
  11. )
  12. test_data = datasets.FashionMNIST(
  13. root="data",
  14. train=False,
  15. download=True,
  16. transform=ToTensor()
  17. )
  18. labels_map = {
  19. 0: "T-Shirt",
  20. 1: "Trouser",
  21. 2: "Pullover",
  22. 3: "Dress",
  23. 4: "Coat",
  24. 5: "Sandal",
  25. 6: "Shirt",
  26. 7: "Sneaker",
  27. 8: "Bag",
  28. 9: "Ankle Boot",
  29. }
  30. figure = plt.figure(figsize=(8, 8))
  31. cols, rows = 3, 3
  32. for i in range(1, cols * rows + 1):
  33. sample_idx = torch.randint(len(training_data), size=(1,)).item()
  34. img, label = training_data[sample_idx]
  35. figure.add_subplot(rows, cols, i)
  36. plt.title(labels_map[label])
  37. plt.axis("off")
  38. plt.imshow(img.squeeze(), cmap="gray")
  39. plt.show()

错误

  1. '''
  2. Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.
  3. '''
  4. import os
  5. os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
  6. '''
  7. Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.
  8. IOError: image file is truncated (X bytes not processed)
  9. '''
  10. from PIL import ImageFile
  11. ImageFile.LOAD_TRUNCATED_IMAGES = True
  12. '''
  13. 数据集较小时(小于2W)建议num_works不用管默认就行,因为用了反而比没用慢。
  14. 当数据集较大时建议采用,num_works一般设置为(CPU线程数+-1)为最佳,可以用以下代码找出最佳num_works(注意windows用户如果要使用多核多线程必须把训练放在
  15. if __name__ == '__main__':下才不会报错)
  16. '''
  17. # windows 环境下
  18. num_works==0