# 完整的模型训练
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import *
# --- Datasets ---------------------------------------------------------------
# CIFAR-10 train and test splits, kept under "download_data" (fetched on first
# use because download=True). Every image goes through ToTensor() on access.
train_data = torchvision.datasets.CIFAR10(
    root="download_data",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
test_data = torchvision.datasets.CIFAR10(
    root="download_data",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

# Report how many samples each split holds.
data_size, test_data_size = len(train_data), len(test_data)
print(data_size)
print(test_data_size)
# The training-set size once more, this time via str.format.
print("数据集长度:{}".format(data_size))

# --- Batching ---------------------------------------------------------------
# Both loaders yield mini-batches of 64; every other DataLoader option is left
# at its default.
data_loader = DataLoader(dataset=train_data, batch_size=64)
test_loader = DataLoader(dataset=test_data, batch_size=64)

# --- Model, loss, optimizer -------------------------------------------------
# Lcy is presumably supplied by `from model import *` at the top of the file.
l = Lcy()

# Classification loss.
loss_function = nn.CrossEntropyLoss()

# Plain SGD. 0.01 is the same number as 1e-2 (1 x 10^-2); either spelling works.
learning_rate = 0.01
optimizer = torch.optim.SGD(l.parameters(), lr=learning_rate)

# --- Bookkeeping ------------------------------------------------------------
# Counter for training iterations.
total_train_step = 0
# Counter for test passes.
total_test_step = 0
# Number of passes over the training set.
epoch = 10
# TensorBoard event files are written to the "logs" directory.
writer = SummaryWriter(log_dir="logs")
# Main loop: one training pass over `data_loader` followed by one evaluation
# pass over `test_loader`, repeated `epoch` times.
for i in range(epoch):
    print("-------第{}轮训练-------".format(i+1))

    # ---- Training phase ----
    # Put the network in training mode (matters for layers such as dropout or
    # batch-norm; harmless otherwise).
    l.train()
    for img, targets in data_loader:
        output = l(img)
        loss = loss_function(output, targets)

        # Standard optimisation step: clear stale gradients, back-propagate,
        # then update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count training iterations with the training counter. The previous
        # code incremented `total_test_step` here, leaving `total_train_step`
        # permanently at 0.
        total_train_step += 1
        if total_train_step % 100 == 0:
            # loss.item() yields a plain Python float, so the console shows a
            # number instead of a tensor repr and no autograd graph is held.
            print("训练次数{},loss:{}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # ---- Evaluation phase ----
    # Evaluation mode plus no_grad(): no gradient bookkeeping while scoring
    # the test set.
    l.eval()
    total_test_loss = 0
    with torch.no_grad():
        for img, targets in test_loader:
            output = l(img)
            loss = loss_function(output, targets)
            # Sum of the per-batch loss values over the whole test loader.
            total_test_loss += loss.item()
    print("整体测试上的损失:{}".format(total_test_loss))

    # One point per evaluation pass, indexed by the number of passes done so
    # far (0, 1, 2, ...) rather than by the training-iteration count.
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    total_test_step += 1

# Flush pending events and release the TensorBoard writer.
writer.close()
import torch
from torch import nn
class Lcy(nn.Module):
    """Small convolutional classifier producing 10 output scores per sample.

    The network is three 5x5 convolutions (stride 1, padding 2), each followed
    by 2x2 max-pooling, then a flatten and two linear layers. The first linear
    layer expects 64 * 4 * 4 features, which is what a 3x32x32 input yields
    after the three pooling stages (32 -> 16 -> 8 -> 4 with 64 channels).
    """

    def __init__(self):
        super().__init__()
        # Layers are listed in execution order; their position in the list
        # becomes their index inside ``self.model``.
        layers = [
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(in_features=64 * 4 * 4, out_features=64),
            nn.Linear(in_features=64, out_features=10),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Feed ``input`` through the sequential stack and return the result."""
        return self.model(input)
if __name__ == '__main__':
    # Smoke test when run as a script: push an all-ones batch of 64 samples
    # shaped 3x32x32 through a fresh network and print the output shape.
    net = Lcy()
    dummy_batch = torch.ones((64, 3, 32, 32))
    print(net(dummy_batch).shape)