1 Obtain a trainable model
First, we need a trainable model. MNN provides two ways to obtain one.
The first way is to use the MNNConvert tool to convert a model trained in another framework, such as TensorFlow or PyTorch, into a trainable model in MNN format. Typical scenarios are: 1. fine-tuning a pre-trained model with MNN; 2. quantization-aware training with MNN. During conversion, pass the --forTraining option to MNNConvert so that operators used in training, such as BatchNorm and Dropout, are retained.
The second way is to build the model from scratch and train it with MNN. This skips the model conversion step, and the model can also be easily converted into a model for quantization-aware training. The MNN_ROOT/tools/train/source/models/ directory provides Lenet, MobilenetV1, MobilenetV2 and other models built with the MNN training framework.
1.1 Convert other framework models to MNN trainable models
Take quantization-aware training of MobilenetV2 as an example. First, download the official MobilenetV2 model provided by TensorFlow, build MNNConvert, and run the following command to convert the model:
./MNNConvert --modelFile mobilenet_v2_1.0_224_frozen.pb --MNNModel mobilenet_v2_tfpb_train.mnn --framework TF --bizCode AliNNTest --forTraining
Note that the model used in the preceding command contains BatchNorm operators that have not been fused. Passing the --forTraining option during conversion keeps these BatchNorm operators in the converted model.
If your model does not contain operators such as BatchNorm and Dropout, which are optimized away when a model is converted to an MNN inference model, you do not need to convert it again: the MNN inference model can be used for training as well.
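To see why BatchNorm is lost without --forTraining, it helps to look at what fusing it means. At inference time, a BatchNorm that follows a convolution can be folded into that convolution's weights and bias; afterwards its mean, variance, scale and shift no longer exist as separate, trainable values. The sketch below shows the standard folding arithmetic for one convolution. It illustrates the general technique, not MNNConvert's actual code, and the names `foldBatchNorm` and `BatchNormParams` are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical container for per-output-channel BatchNorm values.
struct BatchNormParams {
    std::vector<float> gamma, beta, mean, var;
    float eps = 1e-5f;
};

// Folds BatchNorm into the preceding convolution in place.
// `weights` holds one contiguous slice per output channel (an [outC, inC, kH, kW]
// tensor), and `bias` has one entry per output channel.
void foldBatchNorm(std::vector<float>& weights, std::vector<float>& bias,
                   const BatchNormParams& bn) {
    const std::size_t outChannels = bias.size();
    assert(outChannels > 0 && weights.size() % outChannels == 0);
    assert(bn.gamma.size() == outChannels && bn.beta.size() == outChannels &&
           bn.mean.size() == outChannels && bn.var.size() == outChannels);
    const std::size_t perChannel = weights.size() / outChannels;
    for (std::size_t c = 0; c < outChannels; ++c) {
        // y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
        const float scale = bn.gamma[c] / std::sqrt(bn.var[c] + bn.eps);
        for (std::size_t k = 0; k < perChannel; ++k) {
            weights[c * perChannel + k] *= scale;
        }
        bias[c] = (bias[c] - bn.mean[c]) * scale + bn.beta[c];
    }
}
```

Once folded, only `weights` and `bias` remain, so training can no longer update the BatchNorm statistics separately.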
Next, let's follow the example in MNN_ROOT/tools/train/source/demo/mobilenetV2Train.cpp, which reads the converted model and turns it into an MNN training model. The key code is as follows:
// mobilenetV2Train.cpp
// Read the converted MNN model
auto varMap = Variable::loadMap(argv[1]);
// Specify bit width of the quantization
int bits = 8;
// Get inputs and outputs
auto inputOutputs = Variable::getInputAndOutput(varMap);
auto inputs = Variable::mapToSequence(inputOutputs.first);
auto outputs = Variable::mapToSequence(inputOutputs.second);
// Convert the MNN model into the trainable model.
// (Extract the Convolution, BatchNorm, Dropout from the inference model, and then convert to trainable modules)
std::shared_ptr<Module> model(PipelineModule::extract(inputs, outputs, true));
// Convert the trainable model into a quantization-aware training model. If you don't want to do
// quantization-aware training, skip this step.
((PipelineModule*)model.get())->toTrainQuant(bits);
// Train the model.
MobilenetV2Utils::train(model, 1001, 1, trainImagesFolder, trainImagesTxt, testImagesFolder, testImagesTxt);
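`toTrainQuant(bits)` turns the trainable model into a quantization-aware training model. The usual idea behind quantization-aware training is fake quantization: in the forward pass, values are rounded onto the grid of a `bits`-wide signed integer and mapped back to float, so the network learns to tolerate the rounding error. The sketch below shows symmetric per-tensor fake quantization with a max-abs scale. It is a generic illustration, not MNN's implementation, whose choice of scales (for example per-channel or moving-average ranges) may differ; `fakeQuantize` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Symmetric per-tensor fake quantization: scale values onto signed integers of
// the given bit width, round, clamp, and scale back to float.
std::vector<float> fakeQuantize(const std::vector<float>& x, int bits) {
    assert(bits >= 2 && bits <= 16);
    const float qmax = static_cast<float>((1 << (bits - 1)) - 1);  // 127 for 8 bits
    float maxAbs = 0.0f;
    for (float v : x) maxAbs = std::max(maxAbs, std::fabs(v));
    std::vector<float> out(x.size(), 0.0f);
    if (maxAbs == 0.0f) return out;  // all zeros: nothing to quantize
    const float scale = maxAbs / qmax;
    for (std::size_t i = 0; i < x.size(); ++i) {
        float q = std::round(x[i] / scale);    // nearest integer level
        q = std::min(std::max(q, -qmax), qmax); // keep within [-qmax, qmax]
        out[i] = q * scale;                     // back to float
    }
    return out;
}
```

With `bits = 8`, as in the example above, every value is snapped to one of 255 levels between -maxAbs and +maxAbs.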
1.2 Use MNN to construct models from scratch
Taking Lenet as an example, let's look at how to build a model from scratch with MNN. MNN provides a wide range of operators, which this example does not cover in detail. Note that the output of the pooling layer is in NC4HW4 format and must be converted to NCHW format before it enters the fully connected layer.
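The conversion matters because the fully connected layer expects the flattened features in plain NCHW order, while NC4HW4 packs channels in groups of four; flattening it directly would interleave channels and include padding lanes. The sketch below does the conversion by hand, assuming NC4HW4 is laid out as [N][ceil(C/4)][H][W][4], with the unused lanes of the last group treated as padding. In MNN code `_Convert(x, NCHW)` does this for you; `nc4hw4ToNchw` is a hypothetical helper for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Converts a tensor stored as NC4HW4 ([N][ceil(C/4)][H][W][4]) into NCHW.
std::vector<float> nc4hw4ToNchw(const std::vector<float>& src,
                                int n, int c, int h, int w) {
    const int c4 = (c + 3) / 4;  // number of 4-channel groups
    assert(src.size() == static_cast<std::size_t>(n) * c4 * h * w * 4);
    std::vector<float> dst(static_cast<std::size_t>(n) * c * h * w);
    for (int b = 0; b < n; ++b) {
        for (int ch = 0; ch < c; ++ch) {
            for (int y = 0; y < h; ++y) {
                for (int x = 0; x < w; ++x) {
                    // Channel ch lives in group ch / 4, lane ch % 4.
                    const std::size_t from =
                        (((static_cast<std::size_t>(b) * c4 + ch / 4) * h + y) * w + x) * 4 + ch % 4;
                    const std::size_t to =
                        ((static_cast<std::size_t>(b) * c + ch) * h + y) * w + x;
                    dst[to] = src[from];
                }
            }
        }
    }
    return dst;
}
```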
class MNN_PUBLIC Lenet : public Module {
public:
Lenet();
virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override;
std::shared_ptr<Module> conv1;
std::shared_ptr<Module> conv2;
std::shared_ptr<Module> ip1;
std::shared_ptr<Module> ip2;
std::shared_ptr<Module> dropout;
};
// Initialization
Lenet::Lenet() {
NN::ConvOption convOption;
convOption.kernelSize = {5, 5};
convOption.channel = {1, 20};
conv1.reset(NN::Conv(convOption));
convOption.reset();
convOption.kernelSize = {5, 5};
convOption.channel = {20, 50};
conv2.reset(NN::Conv(convOption));
ip1.reset(NN::Linear(800, 500));
ip2.reset(NN::Linear(500, 10));
dropout.reset(NN::Dropout(0.5));
// Register the submodules so that their parameters are updated during backpropagation.
registerModel({conv1, conv2, ip1, ip2, dropout});
}
// Forward pass.
std::vector<Express::VARP> Lenet::onForward(const std::vector<Express::VARP>& inputs) {
using namespace Express;
VARP x = inputs[0];
x = conv1->forward(x);
x = _MaxPool(x, {2, 2}, {2, 2});
x = conv2->forward(x);
x = _MaxPool(x, {2, 2}, {2, 2});
// The output of the pooling layer is NC4HW4 which needs to be converted into
// NCHW before computation in the FC layer
x = _Convert(x, NCHW);
x = _Reshape(x, {0, -1});
x = ip1->forward(x);
x = _Relu(x);
x = dropout->forward(x);
x = ip2->forward(x);
x = _Softmax(x, 1);
return {x};
}
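The input size of `ip1`, 800, follows from the layer shapes. Assuming 28x28 MNIST input, 5x5 convolutions with stride 1 and no padding, and 2x2 max pooling with stride 2, the spatial size goes 28, 24, 12, 8, 4, and conv2's 50 output channels give 50 x 4 x 4 = 800 features after `_Reshape(x, {0, -1})`. The helpers below (hypothetical names) check this arithmetic at compile time:

```cpp
#include <cassert>

// Output size of a convolution or pooling window along one spatial axis,
// assuming no padding.
constexpr int outSize(int in, int kernel, int stride) {
    return (in - kernel) / stride + 1;
}

// Number of features entering ip1 for a square input with side length `side`.
constexpr int lenetFlattenSize(int side) {
    int s = outSize(side, 5, 1);  // conv1: 5x5, stride 1
    s = outSize(s, 2, 2);         // max pool 2x2, stride 2
    s = outSize(s, 5, 1);         // conv2: 5x5, stride 1
    s = outSize(s, 2, 2);         // max pool 2x2, stride 2
    return 50 * s * s;            // conv2 has 50 output channels
}

static_assert(lenetFlattenSize(28) == 800, "matches NN::Linear(800, 500)");
```

If the input resolution or padding changes, the first argument of `NN::Linear` has to change accordingly.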
2 Implement the data set interface
See Load Training Data for a detailed description.
3 Train and save the model
Take MNIST model training as an example. The code is in MNN_ROOT/tools/train/source/demo/MnistUtils.cpp:
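In the training loop of MnistUtils.cpp, each batch's labels are one-hot encoded with `_OneHot` and compared with the model's softmax output by `_CrossEntropy`. As a plain C++ sketch of that arithmetic, assuming `_CrossEntropy` computes the batch mean of -sum(t * log(p)) over probabilities (consistent with Lenet ending in `_Softmax`), the two steps look like this; `oneHot` and `crossEntropy` are hypothetical helpers, not MNN functions.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// One-hot encodes integer labels: 1.0f at the label index, 0.0f elsewhere.
std::vector<std::vector<float>> oneHot(const std::vector<int>& labels, int numClasses) {
    std::vector<std::vector<float>> out;
    out.reserve(labels.size());
    for (int label : labels) {
        assert(label >= 0 && label < numClasses);
        std::vector<float> row(numClasses, 0.0f);
        row[label] = 1.0f;
        out.push_back(std::move(row));
    }
    return out;
}

// Mean cross-entropy over the batch: -(1/N) * sum_i sum_k t[i][k] * log(p[i][k]),
// where p holds softmax probabilities.
float crossEntropy(const std::vector<std::vector<float>>& probs,
                   const std::vector<std::vector<float>>& targets) {
    assert(!probs.empty() && probs.size() == targets.size());
    double total = 0.0;
    for (std::size_t i = 0; i < probs.size(); ++i) {
        assert(probs[i].size() == targets[i].size());
        for (std::size_t k = 0; k < probs[i].size(); ++k) {
            // Skip zero targets so that 0 * log(0) never produces NaN.
            if (targets[i][k] != 0.0f) {
                total -= targets[i][k] * std::log(probs[i][k]);
            }
        }
    }
    return static_cast<float>(total / probs.size());
}
```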
// MnistUtils.cpp
// ...
void MnistUtils::train(std::shared_ptr<Module> model, std::string root) {
{
// Load the snapshot: the saved parameters are loaded
// into the existing model structure
auto para = Variable::load("mnist.snapshot.mnn");
model->loadParameters(para);
}
// Configure training framework params.
auto exe = Executor::getGlobalExecutor();
BackendConfig config;
// Use CPU, 4 threads.
exe->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 4);
// SGD optimizer.
std::shared_ptr<SGD> sgd(new SGD);
sgd->append(model->parameters());
// SGD params.
sgd->setMomentum(0.9f);
sgd->setWeightDecay(0.0005f);
// Create the dataset and DataLoader
auto dataset = MnistDataset::create(root, MnistDataset::Mode::TRAIN);
// Stack transform: stack n samples of shape [1, 28, 28] into one batch of shape [n, 1, 28, 28]
const size_t batchSize = 64;
const size_t numWorkers = 0;
bool shuffle = true;
auto dataLoader = std::shared_ptr<DataLoader>(dataset.createLoader(batchSize, true, shuffle, numWorkers));
size_t iterations = dataLoader->iterNumber();
auto testDataset = MnistDataset::create(root, MnistDataset::Mode::TEST);
const size_t testBatchSize = 20;
const size_t testNumWorkers = 0;
shuffle = false;
auto testDataLoader = std::shared_ptr<DataLoader>(testDataset.createLoader(testBatchSize, true, shuffle, testNumWorkers));
size_t testIterations = testDataLoader->iterNumber();
// Begin training
for (int epoch = 0; epoch < 50; ++epoch) {
model->clearCache();
exe->gc(Executor::FULL);
exe->resetProfile();
{
AUTOTIME;
dataLoader->reset();
// Set isTraining flag to be true during training phase.
model->setIsTraining(true);
Timer _100Time;
int lastIndex = 0;
int moveBatchSize = 0;
for (int i = 0; i < iterations; i++) {
// AUTOTIME;
// Obtain the training data and label for a batch.
auto trainData = dataLoader->next();
auto example = trainData[0];
auto cast = _Cast<float>(example.first[0]);
example.first[0] = cast * _Const(1.0f / 255.0f);
moveBatchSize += example.first[0]->getInfo()->dim[0];
// Compute One-Hot
auto newTarget = _OneHot(_Cast<int32_t>(example.second[0]), _Scalar<int>(10), _Scalar<float>(1.0f),
_Scalar<float>(0.0f));
// Forward pass
auto predict = model->forward(example.first[0]);
// Calculate loss
auto loss = _CrossEntropy(predict, newTarget);
// Adjust the learning rate
float rate = LrScheduler::inv(0.01, epoch * iterations + i, 0.0001, 0.75);
sgd->setLearningRate(rate);
if (moveBatchSize % (10 * batchSize) == 0 || i == iterations - 1) {
std::cout << "epoch: " << (epoch);
std::cout << " " << moveBatchSize << " / " << dataLoader->size();
std::cout << " loss: " << loss->readMap<float>()[0];
std::cout << " lr: " << rate;
std::cout << " time: " << (float)_100Time.durationInUs() / 1000.0f << " ms / " << (i - lastIndex) << " iter" << std::endl;
std::cout.flush();
_100Time.reset();
lastIndex = i;
}
// Backward pass and parameter updates.
sgd->step(loss);
}
}
// Save the model parameters so that they can be reloaded later.
Variable::save(model->parameters(), "mnist.snapshot.mnn");
{
model->setIsTraining(false);
auto forwardInput = _Input({1, 1, 28, 28}, NC4HW4);
forwardInput->setName("data");
auto predict = model->forward(forwardInput);
predict->setName("prob");
// Optimizes the network structure (optional)
Transformer::turnModelToInfer()->onExecute({predict});
// Saves the model and its structure, which can be used without the Module definitions.
Variable::save({predict}, "temp.mnist.mnn");
}
// Model test.
int correct = 0;
testDataLoader->reset();
// Set training to be false during model test.
model->setIsTraining(false);
int moveBatchSize = 0;
for (int i = 0; i < testIterations; i++) {
auto data = testDataLoader->next();
auto example = data[0];
moveBatchSize += example.first[0]->getInfo()->dim[0];
if ((i + 1) % 100 == 0) {
std::cout << "test: " << moveBatchSize << " / " << testDataLoader->size() << std::endl;
}
auto cast = _Cast<float>(example.first[0]);
example.first[0] = cast * _Const(1.0f / 255.0f);
auto predict = model->forward(example.first[0]);
predict = _ArgMax(predict, 1);
auto accu = _Cast<int32_t>(_Equal(predict, _Cast<int32_t>(example.second[0]))).sum({});
correct += accu->readMap<int32_t>()[0];
}
// Calculate the accuracy.
auto accu = (float)correct / (float)testDataLoader->size();
std::cout << "epoch: " << epoch << " accuracy: " << accu << std::endl;
exe->dumpProfile();
}
}
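The optimizer settings in the loop can be written out explicitly. The sketch below assumes that `LrScheduler::inv(base, step, gamma, power)` follows the common "inv" policy, lr = base * (1 + gamma * step)^(-power), and that the SGD step uses the usual momentum plus L2 weight-decay update; MNN's own implementation may differ in detail. The hyperparameters used above (base learning rate 0.01, gamma 0.0001, power 0.75, momentum 0.9, weight decay 0.0005) are the same as those of the classic Caffe LeNet solver. `invLearningRate`, `SgdState` and `sgdStep` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// "inv" learning-rate policy: lr = baseLr * (1 + gamma * step)^(-power).
float invLearningRate(float baseLr, int step, float gamma, float power) {
    assert(step >= 0);
    return baseLr * std::pow(1.0f + gamma * static_cast<float>(step), -power);
}

// Per-parameter momentum buffer, created lazily on the first step.
struct SgdState {
    std::vector<float> velocity;
};

// One SGD step with momentum and L2 weight decay:
//   v <- momentum * v + lr * (grad + weightDecay * w)
//   w <- w - v
void sgdStep(std::vector<float>& weights, const std::vector<float>& grads,
             SgdState& state, float lr, float momentum, float weightDecay) {
    assert(weights.size() == grads.size());
    if (state.velocity.empty()) state.velocity.assign(weights.size(), 0.0f);
    assert(state.velocity.size() == weights.size());
    for (std::size_t i = 0; i < weights.size(); ++i) {
        float& v = state.velocity[i];
        v = momentum * v + lr * (grads[i] + weightDecay * weights[i]);
        weights[i] -= v;
    }
}
```

For example, `invLearningRate(0.01f, 10000, 0.0001f, 0.75f)` is 0.01 * 2^(-0.75), roughly 0.0059, so the learning rate decays smoothly rather than in steps.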
4 Save and restore the model
There are two ways to save the model:
The first way saves only the model parameters, not the model structure. The parameters must be loaded into a model with the corresponding structure.
Save:
Variable::save(model->parameters(), "mnist.snapshot.mnn");
Restore:
// Load the saved parameters into a model with the same structure
auto para = Variable::load("mnist.snapshot.mnn");
model->loadParameters(para);
The second way saves the model structure and parameters together, which is convenient for inference.
Save:
model->setIsTraining(false);
auto forwardInput = _Input({1, 1, 28, 28}, NC4HW4);
forwardInput->setName("data");
auto predict = model->forward(forwardInput);
predict->setName("prob");
// Save from the output node; this stores both the structure and the parameters
Variable::save({predict}, "temp.mnist.mnn");
Restore (Inference):
auto varMap = Variable::loadMap("temp.mnist.mnn");
// The input node name is the same as the one defined during Save, i.e. 'data'.
// The dimensions are also the same as those defined during Save, i.e. [1, 1, 28, 28].
float* inputPtr = varMap["data"]->writeMap<float>();
// Fill inputPtr with the input data
// The output node name is the same as the one defined during Save, i.e. 'prob'.
const float* outputPtr = varMap["prob"]->readMap<float>();
// Use data in outputPtr
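As one way to use the data in `outputPtr`: assuming `prob` holds the 10 softmax probabilities of a single image (as produced by the Lenet model above), the predicted digit is the index of the largest value, which is the same reduction the test loop performs with `_ArgMax`. `predictedClass` is a hypothetical helper.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Returns the index of the largest of `count` probabilities.
// std::max_element returns the first maximum, so ties go to the lower index.
int predictedClass(const float* probs, std::size_t count) {
    assert(probs != nullptr && count > 0);
    return static_cast<int>(std::max_element(probs, probs + count) - probs);
}
```

Usage: `int digit = predictedClass(outputPtr, 10);`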