# 完整的模型训练 (complete model training)
# Complete training script: trains the Lcy CNN on CIFAR-10 with SGD and logs
# training / test loss to TensorBoard (directory "logs").
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import Lcy


def main():
    """Train ``Lcy`` on CIFAR-10 for a fixed number of epochs.

    Downloads CIFAR-10 into ``download_data`` (if not already present), runs
    SGD over the training split, and after every epoch evaluates the summed
    cross-entropy loss over the test split. Scalars are written with a
    ``SummaryWriter`` to ``logs``:

    * ``train_loss`` every 100 optimizer steps, x-axis = optimizer step count;
    * ``test_loss`` once per epoch, x-axis = evaluation round (0, 1, 2, ...).
    """
    # Prepare the datasets.
    train_data = torchvision.datasets.CIFAR10(
        root="download_data",
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
    test_data = torchvision.datasets.CIFAR10(
        root="download_data",
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )

    data_size = len(train_data)
    test_data_size = len(test_data)
    print(data_size)
    print(test_data_size)
    # Formatted string output of the training-set size.
    print("数据集长度:{}".format(data_size))

    # Shuffle the training set so SGD does not see the batches in a fixed
    # order every epoch; evaluation order does not matter.
    data_loader = DataLoader(train_data, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=64)

    # Create the network, loss function and optimizer.
    net = Lcy()
    loss_function = nn.CrossEntropyLoss()
    # 1e-2 == 1 x 10^(-2) == 0.01
    learning_rate = 1e-2
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)

    total_train_step = 0  # optimizer steps taken so far
    total_test_step = 0   # evaluation rounds completed so far
    epoch = 10            # number of training epochs

    writer = SummaryWriter("logs")
    try:
        for i in range(epoch):
            print("-------第{}轮训练-------".format(i + 1))

            # Training phase. train()/eval() are no-ops for layers without
            # mode-dependent behaviour, but keep the script correct if
            # Dropout/BatchNorm is ever added to the model.
            net.train()
            for img, targets in data_loader:
                output = net(img)
                loss = loss_function(output, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_train_step += 1
                if total_train_step % 100 == 0:
                    # .item() yields the plain number rather than the tensor
                    # repr (which includes grad_fn) in both print and log.
                    step_loss = loss.item()
                    print("训练次数{},loss:{}".format(total_train_step, step_loss))
                    writer.add_scalar("train_loss", step_loss, total_train_step)

            # Evaluation phase: sum the per-batch loss over the whole test set.
            net.eval()
            total_test_loss = 0
            with torch.no_grad():
                for img, targets in test_loader:
                    output = net(img)
                    loss = loss_function(output, targets)
                    total_test_loss += loss.item()
            print("整体测试上的损失:{}".format(total_test_loss))
            writer.add_scalar("test_loss", total_test_loss, total_test_step)
            total_test_step += 1
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()


if __name__ == "__main__":
    main()
import torch
from torch import nn


class Lcy(nn.Module):
    """Small CNN mapping a batch of 3x32x32 images to 10 output logits.

    Three stages of Conv2d(kernel 5, stride 1, padding 2), which keeps the
    spatial size, followed by MaxPool2d(2), which halves it, reduce 32x32 to
    4x4 with 64 channels. The result is flattened to 64*4*4 = 1024 features
    and passed through two Linear layers (1024 -> 64 -> 10). The input must
    therefore be 32x32; other sizes break the first Linear layer.

    NOTE(review): there is no activation (e.g. ReLU) anywhere in the stack, so
    the two Linear layers collapse to a single affine map. Adding one would
    shift the ``nn.Sequential`` indices and thus the state_dict keys of
    existing checkpoints. Left unchanged for compatibility; confirm intent.
    """

    def __init__(self):
        super().__init__()
        # Layer order must stay fixed: state_dict keys (model.0, model.2,
        # model.4, model.7, model.8) are derived from these indices.
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),   # (N, 3, 32, 32)  -> (N, 32, 32, 32)
            nn.MaxPool2d(2),             # -> (N, 32, 16, 16)
            nn.Conv2d(32, 32, 5, 1, 2),  # -> (N, 32, 16, 16)
            nn.MaxPool2d(2),             # -> (N, 32, 8, 8)
            nn.Conv2d(32, 64, 5, 1, 2),  # -> (N, 64, 8, 8)
            nn.MaxPool2d(2),             # -> (N, 64, 4, 4)
            nn.Flatten(),                # -> (N, 1024)
            nn.Linear(64 * 4 * 4, 64),   # -> (N, 64)
            nn.Linear(64, 10),           # -> (N, 10)
        )

    def forward(self, input):
        """Return logits of shape (N, 10) for ``input`` of shape (N, 3, 32, 32).

        The parameter keeps the name ``input`` (shadowing the builtin inside
        this method) so existing keyword callers keep working.
        """
        return self.model(input)


if __name__ == '__main__':
    # Smoke test: a batch of 64 all-ones images should produce shape (64, 10).
    net = Lcy()
    dummy_batch = torch.ones((64, 3, 32, 32))
    output = net(dummy_batch)
    print(output.shape)