Finetune basics
Deep learning models can be regarded as feature extractors; for example, a Convolutional Neural Network (CNN) can be regarded as a visual feature extractor. However, such a feature extractor needs extensive training to avoid overfitting and to generalize well. If we build a model from scratch and train it directly on our own small dataset, it will easily overfit. Instead, we can take a model pretrained on a large dataset for a similar task and finetune it on our own small dataset, which saves a lot of training time and yields better generalization.
Usage scenario
For example, for image classification tasks we can take a model trained on the ImageNet dataset, such as MobileNetV2, keep its feature-extraction part, replace the final classification layer (ImageNet has 1000 categories, while our own dataset may have only 10), and then train only the replaced classification layer. This works because the feature-extraction part of MobileNetV2 has already been fully trained, and the features it extracts are general enough to transfer to other images. The same idea applies to many other tasks; in NLP, for example, a BERT model pretrained on a large corpus can be finetuned on one's own corpus.
MNN Finetune example
The following example shows how to finetune with MNN, using MobileNetV2 on a small 4-category dataset of your own. The relevant code is in MNN_ROOT/tools/train/source/demo/mobilenetV2Train.cpp and MNN_ROOT/tools/train/source/demo/mobilenetV2Utils.cpp. Since only the small replaced classification layer is trained, you can choose a larger learning rate, such as 0.001, to speed up learning.
Note that this demo requires the MNN model of MobileNetV2.
// mobilenetV2Train.cpp
class MobilenetV2TransferModule : public Module {
public:
    MobilenetV2TransferModule(const char* fileName) {
        // Read the original MobileNetV2 model.
        auto varMap = Variable::loadMap(fileName);
        // MobileNetV2 input node.
        auto input = Variable::getInputAndOutput(varMap).first.begin()->second;
        // MobileNetV2 node before the classification layer, i.e. AveragePooling.
        auto lastVar = varMap["MobilenetV2/Logits/AvgPool"];
        // Initialize a 4-category FC layer, represented as a Convolution layer.
        NN::ConvOption option;
        option.channel = {1280, 4};
        mLastConv = std::shared_ptr<Module>(NN::Conv(option));
        // Initialize the internal feature extractor; it does not need training.
        mFix.reset(PipelineModule::extract({input}, {lastVar}, false));
        // Register only the 4-category FC layer, so only this layer is updated during training.
        registerModel({mLastConv});
    }

    virtual std::vector<VARP> onForward(const std::vector<VARP>& inputs) override {
        // Run the input image through the MobileNetV2 feature extractor.
        auto pool = mFix->forward(inputs[0]);
        // Feed the extracted features into the 4-category FC layer.
        auto result = _Softmax(_Reshape(_Convert(mLastConv->forward(pool), NCHW), {0, -1}));
        return {result};
    }

    // The MobileNetV2 feature extractor: input node up to the last AveragePooling layer.
    std::shared_ptr<Module> mFix;
    // The new 4-category FC layer.
    std::shared_ptr<Module> mLastConv;
};

class MobilenetV2Transfer : public DemoUnit {
public:
    virtual int run(int argc, const char* argv[]) override {
        if (argc < 6) {
            std::cout << "usage: ./runTrainDemo.out MobilenetV2Transfer /path/to/mobilenetV2Model path/to/train/images/ path/to/train/image/txt path/to/test/images/ path/to/test/image/txt"
                      << std::endl;
            return 0;
        }

        std::string trainImagesFolder = argv[2];
        std::string trainImagesTxt = argv[3];
        std::string testImagesFolder = argv[4];
        std::string testImagesTxt = argv[5];

        // Read the model and replace the last classification layer.
        std::shared_ptr<Module> model(new MobilenetV2TransferModule(argv[1]));
        // Begin training.
        MobilenetV2Utils::train(model, 4, 0, trainImagesFolder, trainImagesTxt, testImagesFolder, testImagesTxt);
        return 0;
    }
};