SGD with momentum
Usage Example:
```cpp
// Constructs an SGD optimizer.
std::shared_ptr<SGD> solver(new SGD);
// Sets the params to be optimized in the model.
solver->append(model->parameters());
// Sets momentum and weight decay.
solver->setMomentum(0.9f);
solver->setWeightDecay(0.0005f);
// Sets the regularization method, defaults to L2 norm.
solver->setRegularizationMethod(RegularizationMethod::L2);
// Sets the learning rate.
solver->setLearningRate(0.001);
// Calculates the gradient based on the loss and updates the params.
solver->step(loss);
```
ADAM
Usage Example:
```cpp
// Constructs an ADAM optimizer.
std::shared_ptr<ADAM> solver(new ADAM);
// Sets the params to be optimized in the model.
solver->append(model->parameters());
// Sets the ADAM momentums and weight decay.
solver->setMomentum(0.9f);
solver->setMomentum2(0.99f);
solver->setWeightDecay(0.0005f);
// Sets the regularization method, defaults to L2 norm.
solver->setRegularizationMethod(RegularizationMethod::L2);
// Sets the learning rate.
solver->setLearningRate(0.001);
// Calculates the gradient based on the loss and updates the params.
solver->step(loss);
```
Loss
The following loss functions are currently supported; you can also define your own.
```cpp
VARP _CrossEntropy(Express::VARP predicts, Express::VARP oneHotTargets);
VARP _KLDivergence(Express::VARP predicts, Express::VARP oneHotTargets);
VARP _MSE(Express::VARP predicts, Express::VARP oneHotTargets);
VARP _MAE(Express::VARP predicts, Express::VARP oneHotTargets);
VARP _Hinge(Express::VARP predicts, Express::VARP oneHotTargets);
VARP _DistillLoss(Express::VARP studentLogits, Express::VARP teacherLogits, Express::VARP oneHotTargets,
                  const float temperature, const float alpha);
```
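For reference, below is a minimal sketch of how a loss plugs into a training step. It assumes that `model`, `input`, and `oneHotTargets` are set up elsewhere and that `solver` is one of the optimizers from the examples above; these names are illustrative, not part of the API listed here.

```cpp
using namespace MNN::Express;

// Forward pass; the model is assumed to return its output as a VARP.
VARP predicts = model->forward(input);
// Built-in loss: cross entropy between predictions and one-hot labels.
VARP loss = _CrossEntropy(predicts, oneHotTargets);
// A custom loss is just a VARP composed from Express ops;
// this one reproduces _MSE by hand.
VARP customLoss = _ReduceMean(_Square(_Subtract(predicts, oneHotTargets)));
// Backpropagates from the chosen loss and updates the appended params.
solver->step(loss);
```

Note that `_CrossEntropy` operates on predicted probabilities, so the sketch assumes the model ends in a softmax (or equivalent) output layer.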
