# 模型验证 (model verification)
import torch
import torchvision
from PIL import Image
from torch import nn
# Path of the image to classify.
image_path = "image/dog3.jpg"
# Open inside a with-block so the underlying file handle is closed once the
# pixels have been converted. convert("RGB") returns a new, fully loaded image,
# so it stays usable after the file is closed. Converting to RGB also guarantees
# 3 channels (e.g. for RGBA / palette / grayscale inputs).
with Image.open(image_path) as img:
    img = img.convert("RGB")
# Resize to the 32x32 input the network expects (three MaxPool2d(2) layers take
# 32 -> 4, matching Linear(64*4*4, ...)). Then convert to a float tensor of shape
# (3, 32, 32); ToTensor scales uint8 pixels to [0, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
img = transform(img)
class Lcy(nn.Module):
    """Small CNN that maps a batch of 3x32x32 images to 10 class logits.

    NOTE(review): the checkpoint is loaded with torch.load as a whole pickled
    model, and unpickling resolves the model class by name. That is presumably
    why this class is redefined here. Keep its layers identical to the
    training-time definition, or the saved weights will not match.
    """

    def __init__(self):
        super().__init__()
        # Shape comments assume an (N, 3, 32, 32) input.
        # Conv2d args are (in_channels, out_channels, kernel_size, stride, padding);
        # kernel 5 with padding 2 and stride 1 keeps the spatial size unchanged.
        # NOTE(review): there are no non-linear activations between layers.
        # Adding any would change the outputs of the already-trained weights,
        # so the architecture is left as-is.
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),   # -> (N, 32, 32, 32)
            nn.MaxPool2d(2),             # -> (N, 32, 16, 16)
            nn.Conv2d(32, 32, 5, 1, 2),  # -> (N, 32, 16, 16)
            nn.MaxPool2d(2),             # -> (N, 32, 8, 8)
            nn.Conv2d(32, 64, 5, 1, 2),  # -> (N, 64, 8, 8)
            nn.MaxPool2d(2),             # -> (N, 64, 4, 4)
            nn.Flatten(),                # -> (N, 1024)
            nn.Linear(64 * 4 * 4, 64),   # -> (N, 64)
            nn.Linear(64, 10),           # -> (N, 10) raw class logits
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return unnormalized logits of shape (N, 10) for an (N, 3, 32, 32) batch."""
        return self.model(input)
# Checkpoint holding the whole trained model object (it is called directly and
# has .eval() invoked below, so it is a module, not a bare state_dict).
# An alternative checkpoint used earlier: "lcy.pth".
model_path = "model_dir/l_29_gpu.pth"
# Run on the GPU when available, otherwise on the CPU. map_location lets a
# checkpoint whose tensors were saved on the GPU (as the "_gpu" name suggests)
# load on a CPU-only machine instead of failing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    # SECURITY: this unpickles arbitrary Python objects - only load trusted files.
    # torch>=2.6 defaults to weights_only=True, which refuses whole-model pickles.
    model = torch.load(model_path, map_location=device, weights_only=False)
except TypeError:
    # torch<1.13 has no weights_only parameter.
    model = torch.load(model_path, map_location=device)
model = model.to(device)
# Add the batch dimension, (3, 32, 32) -> (1, 3, 32, 32), and move to the model's device.
img = img.unsqueeze(0).to(device)
model.eval()
with torch.no_grad():
    output = model(img)
# Predicted class index for each image in the batch.
print(output.argmax(1))
# Human-readable class names, indexed by the network's output position.
# The order (airplane, automobile, bird, cat, deer, dog, frog, horse, ship,
# truck) matches the CIFAR-10 label order - presumably the training dataset;
# confirm it matches the labels used at training time.
args = ['飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车']
# Index the name list directly instead of testing each of the ten values in an
# if/elif chain. Iterating over the batch also works for more than one image,
# where comparing a multi-element tensor in an `if` would raise.
for pred_idx in output.argmax(1).tolist():
    print("这张图片是{}".format(args[pred_idx]))
# Screenshot of a sample run: ![15.jpg](/uploads/projects/mrliucy@rg7go6/49fa6de8417177347927a206b190f775.jpeg)