模型验证

  • 示例代码:
  import torch
  import torchvision
  from PIL import Image
  from torch import nn

  image_path = "image/dog3.jpg"

  # Open the file in a context manager so its handle is released once the pixel
  # data has been copied out. convert("RGB") always yields a 3-channel image
  # (e.g. it drops a PNG alpha channel or expands grayscale), matching the
  # model's 3-channel input.
  with Image.open(image_path) as raw_img:
      img = raw_img.convert("RGB")

  # Resize to exactly 32x32 (aspect ratio is NOT preserved) and convert to a
  # float tensor of shape (3, 32, 32) with values scaled to [0, 1].
  # NOTE(review): this must match the preprocessing used at training time --
  # confirm that no Normalize step was applied when the model was trained.
  transform = torchvision.transforms.Compose([
      torchvision.transforms.Resize((32, 32)),
      torchvision.transforms.ToTensor(),
  ])
  img = transform(img)
  class Lcy(nn.Module):
      """Small CNN mapping 3-channel 32x32 images to 10 class scores.

      Layer stack (its order fixes the ``state_dict`` keys ``model.0`` ...
      ``model.8``, so it must stay identical to the training-time definition):

      * three 5x5 convolutions with stride 1 and padding 2 (spatial size kept),
        each followed by 2x2 max-pooling: 32 -> 16 -> 8 -> 4;
      * flatten to 64 * 4 * 4 = 1024 features;
      * two linear layers 1024 -> 64 -> 10.

      NOTE(review): there is no activation between the two Linear layers, so
      they compose to a single affine map. Kept as-is because saved weights
      were trained for exactly this architecture.
      """

      def __init__(self):
          super().__init__()
          same_pad_conv = dict(kernel_size=5, stride=1, padding=2)
          self.model = nn.Sequential(
              nn.Conv2d(3, 32, **same_pad_conv),
              nn.MaxPool2d(2),
              nn.Conv2d(32, 32, **same_pad_conv),
              nn.MaxPool2d(2),
              nn.Conv2d(32, 64, **same_pad_conv),
              nn.MaxPool2d(2),
              nn.Flatten(),
              nn.Linear(64 * 4 * 4, 64),
              nn.Linear(64, 10),
          )

      def forward(self, input):
          """Return raw (un-normalized) scores of shape (batch, 10).

          ``input`` must be (batch, 3, 32, 32); other spatial sizes do not
          match the 64 * 4 * 4 input width of the first Linear layer.
          """
          return self.model(input)
  # Run on the GPU when one is available, otherwise fall back to the CPU
  # (a hard-coded .cuda() crashes on machines without CUDA).
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  # torch.load here unpickles a whole nn.Module (not just a state_dict), so the
  # Lcy class must be defined in this script before this line runs.
  # map_location remaps tensors that were saved from a GPU run onto `device`.
  # NOTE(review): this is pickle-based -- only load checkpoint files you trust.
  # On PyTorch >= 2.6 torch.load defaults to weights_only=True, which refuses to
  # unpickle full model objects; pass weights_only=False there (trusted files only).
  model = torch.load("model_dir/l_29_gpu.pth", map_location=device)
  model = model.to(device)

  # Add the leading batch dimension the model expects: (3, 32, 32) -> (1, 3, 32, 32).
  img = torch.reshape(img, (1, 3, 32, 32)).to(device)

  model.eval()
  with torch.no_grad():
      output = model(img)
  print(output.argmax(1))

  # Label names indexed by predicted class id. Their order (airplane, automobile,
  # bird, cat, deer, dog, frog, horse, ship, truck) appears to follow CIFAR-10's
  # class order -- presumably the training dataset; confirm.
  classes = ['飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车']
  # Index the label list directly instead of a 10-branch if/elif chain that
  # recomputed argmax for every comparison. .item() works because batch size is 1.
  predicted = output.argmax(1).item()
  print("这张图片是{}".format(classes[predicted]))

15.jpg