模型验证
# Model validation: classify a single image with a trained ``Lcy`` network.
#
# Loads ``image/dog3.jpg``, resizes it to 32x32, runs it through the model
# pickled at ``model_dir/l_29_gpu.pth`` and prints the predicted class.
import inspect

import torch
import torchvision
from PIL import Image
from torch import nn

image_path = "image/dog3.jpg"
# Force 3 channels: e.g. PNGs may carry an alpha channel and grayscale images
# have 1, while the first conv layer below expects exactly 3 input channels.
# ``convert`` loads the pixels, so the file can be closed right afterwards.
with Image.open(image_path) as raw_img:
    img = raw_img.convert("RGB")

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
img = transform(img)  # float tensor of shape (3, 32, 32), values in [0, 1]


class Lcy(nn.Module):
    """CNN with three conv + max-pool stages followed by two linear layers.

    The checkpoint below is a whole pickled model object, and unpickling
    resolves the class by module and name. This definition must therefore
    match the one used at training time. It is presumably needed because the
    checkpoint was saved from a script that defined ``Lcy`` in ``__main__``
    (assumption - confirm against the training script).
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),   # padding 2 keeps 32x32
            nn.MaxPool2d(2),             # -> 16x16
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),             # -> 8x8
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),             # -> 4x4, hence 64 * 4 * 4 features
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10),           # one logit per class
        )

    def forward(self, input):
        """Return the class logits for a batch of images."""
        return self.model(input)


# Use the GPU when one is present, otherwise fall back to the CPU instead of
# failing on CPU-only machines (the checkpoint was saved from GPU tensors).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_full_model(path, map_location):
    """Unpickle a whole model object stored at *path*, mapped onto *map_location*.

    torch >= 2.6 defaults ``weights_only=True``, which refuses arbitrary
    pickled objects such as a full ``nn.Module``. Opt out explicitly, but only
    when the installed torch actually has that parameter; older versions
    would forward the unknown keyword to ``pickle`` and fail.
    SECURITY: unpickling can execute arbitrary code - only load trusted files.
    """
    load_kwargs = {"map_location": map_location}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(path, **load_kwargs)


model = _load_full_model("model_dir/l_29_gpu.pth", device)
model = model.to(device)
model.eval()

# Add the batch dimension the network expects: (3, 32, 32) -> (1, 3, 32, 32).
img = torch.reshape(img, (1, 3, 32, 32)).to(device)
with torch.no_grad():
    output = model(img)

pred = output.argmax(1)  # index of the highest-scoring logit, shape (1,)
print(pred)

# Display labels indexed by class id. The order presumably follows CIFAR-10
# (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck) and
# must match the label order used during training.
args = ['飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车']
# The last layer emits exactly 10 logits, so the index is always in range.
print("这张图片是{}".format(args[pred.item()]))
