此示例显示了如何使用转移学习来对SqueezeNet(一种预训练的卷积神经网络)进行再训练,以对一组新图像进行分类。尝试以下示例,了解开始使用MATLAB®进行深度学习有多么简单。
转移学习通常用于深度学习应用程序。您可以采用经过预训练的网络,并将其用作学习新任务的起点。与使用从头开始随机初始化的权重训练网络相比,使用迁移学习对网络进行微调通常会更快,更容易。您可以使用少量训练图像将学习到的功能快速转移到新任务中。
图1. 使用预训练模型的过程
提取数据
在工作空间中,提取出MerchData数据集的数据。这个数据包含了75张货物商品的照片,这75个照片属于5个类。(cap, cube, playing cards, screwdriver, torch)。
unzip("MerchData.zip");
载入训练好的网络
打开深度网络设计器
deepNetworkDesigner
从预训练好的列表中选取SqueezeNet并且打开
图1. 深度学习designer
深度学习网络设计器展示了整个网络的结构
图2. 网络的结构
探索网络图。要使用鼠标放大,请使用Ctrl +滚轮。要平移,请使用箭头键,或按住滚轮并拖动鼠标。选择一个图层以查看其属性。取消选择所有层以在“属性”窗格中查看网络摘要。
导入数据
要将数据加载到Deep Network Designer中,请在“数据”选项卡上,单击“导入数据” >“导入图像数据”。“导入图像数据”对话框打开。
在数据源列表中,选择文件夹。单击浏览,然后选择提取的MerchData文件夹。
将数据分为70%训练数据和30%验证数据。
指定要在训练图像上执行的增强操作。数据增强有助于防止网络过度拟合和记忆训练图像的确切细节。对于此示例,在x轴上应用随机反射,在[-90,90]度范围内进行随机旋转,并在[1,2]范围内进行随机缩放。
图3. 导入数据
单击导入将数据导入到深度网络设计器中。
改变网络参数进行迁移学习
要训练SqueezeNet对新图像进行分类,请替换网络的最后一个2-D卷积层和最后一个分类层。在SqueezeNet中,这些层分别具有名称'conv10'
和 'ClassificationLayer_predictions'
。
在设计器窗格上,将新的拖动convolution2dLayer
到画布上。要匹配原始卷积层,请设置FilterSize
为1,1
。编辑NumFilters
为新数据中的类数,在本示例中为5
。
改变学习率,这样的学习是在新的层比在转移层更快的设置WeightLearnRateFactor
和BiasLearnRateFactor
对10
。
删除最后一个二维卷积层,然后连接新层。
图3. 网络参数
更换输出层。滚动到“层库”的末尾,然后将新的对象classificationLayer
拖到画布上。删除原始输出层,并在其位置连接新层。
图4. 更换输出层
训练网络
要选择训练选项,请选择训练选项卡,然后单击训练选项。将初始学习速率设置为较小的值会减慢在传输的层中的学习速度。在上一步中,您增加了2D卷积层的学习速率因子,以加快在新的最终层中的学习速度。学习速率设置的这种组合导致仅在新层中进行快速学习,而在其他层中进行较慢的学习。
在这个例子中,设置InitialLearnRate来0.0001
,ValidationFrequency到5
,MaxEpochs到8
。由于有55个观测值,请将MiniBatchSize设置为11
均匀划分训练数据,并确保在每个时期使用整个训练集。
图5. 训练网络选项
要使用指定的训练选项训练网络,请单击“关闭”,然后单击“训练”。深度网络设计器使您可以可视化并监视培训进度。然后,您可以根据需要编辑训练选项并重新训练网络。
图6. 训练网络
导出结果生成MATLAB代码
要从培训中导出结果,请在“训练”选项卡上,选择“导出” >“导出训练好的网络和结果”。深度网络设计器将训练后的网络导出为变量trainedNetwork_1
,将训练信息导出为变量trainInfoStruct_1
。
您还可以生成MATLAB代码,从而重新创建网络和使用的培训选项。在训练选项卡上,选择 导出>生成训练代码。检查MATLAB代码,以了解如何以编程方式准备要训练的数据,创建网络体系结构和训练网络。
分类新图像
加载新图像以使用训练有素的网络进行分类。
I = imread("MerchDataTest.jpg");
调整测试图像的大小使得能够匹配网络的输入
I = imresize(I, [227 227]);
使用训练好的网络对测试集进行分类
[YPred,probs] = classify(trainedNetwork_1,I);
imshow(I)
label = YPred;
title(string(label) + ", " + num2str(100*max(probs),3) + "%");
图7. 预测结果