SGD with momentum
Usage example:
// Constructs an SGD optimizer.
std::shared_ptr<SGD> solver(new SGD);
// Sets the params to be optimized in the model.
solver->append(model->parameters());
// Sets momentum and weight decay.
solver->setMomentum(0.9f);
solver->setWeightDecay(0.0005f);
// Sets the regularization method, defaults to L2 norm.
solver->setRegularizationMethod(RegularizationMethod::L2);
// Sets the learning rate.
solver->setLearningRate(0.001);
// Calculates the gradient based on the loss and updates the params.
solver->step(loss);
ADAM
Usage example:
// Constructs an ADAM optimizer (ADAM derives from SGD, so it can be held through an SGD pointer).
std::shared_ptr<SGD> solver(new ADAM);
// Sets the params to be optimized in the model.
solver->append(model->parameters());
// Sets the two ADAM momentum coefficients and the weight decay.
solver->setMomentum(0.9f);
solver->setMomentum2(0.99f);
solver->setWeightDecay(0.0005f);
// Sets the regularization method, defaults to L2 norm.
solver->setRegularizationMethod(RegularizationMethod::L2);
// Sets the learning rate.
solver->setLearningRate(0.001);
// Calculates the gradient based on the loss and updates the params.
solver->step(loss);
Loss
The following loss functions are currently supported. You can also define your own.
VARP _CrossEntropy(Express::VARP predicts, Express::VARP oneHotTargets);
VARP _KLDivergence(Express::VARP predicts, Express::VARP oneHotTargets);
VARP _MSE(Express::VARP predicts, Express::VARP oneHotTargets);
VARP _MAE(Express::VARP predicts, Express::VARP oneHotTargets);
VARP _Hinge(Express::VARP predicts, Express::VARP oneHotTargets);
VARP _DistillLoss(Express::VARP studentLogits, Express::VARP teacherLogits, Express::VARP oneHotTargets,
const float temperature, const float alpha);