SGD with momentum

Usage example

  // Create an SGD optimizer
  std::shared_ptr<SGD> solver(new SGD);
  // Register the model parameters to be optimized
  solver->append(model->parameters());
  // Set momentum and weight decay
  solver->setMomentum(0.9f);
  solver->setWeightDecay(0.0005f);
  // Set the regularization method (default is L2)
  solver->setRegularizationMethod(RegularizationMethod::L2);
  // Set the learning rate
  solver->setLearningRate(0.001);
  // Compute gradients from the loss and update the parameters
  solver->step(loss);
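
For reference, a common formulation of the update that step(loss) applies under momentum and L2 regularization is sketched below; MNN's exact implementation may differ in details such as where the learning rate enters:

  v_{t+1} = \mu v_t + \eta \, (g_t + \lambda w_t)
  w_{t+1} = w_t - v_{t+1}

Here \mu is the momentum, \eta the learning rate, \lambda the weight decay, g_t the gradient, and w_t the parameter being updated.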

ADAM

Usage example

  // Create an ADAM optimizer (note the ADAM pointer type, needed for setMomentum2)
  std::shared_ptr<ADAM> solver(new ADAM);
  // Register the model parameters to be optimized
  solver->append(model->parameters());
  // Set ADAM's two momentum coefficients and the weight decay
  solver->setMomentum(0.9f);
  solver->setMomentum2(0.99f);
  solver->setWeightDecay(0.0005f);
  // Set the regularization method (default is L2)
  solver->setRegularizationMethod(RegularizationMethod::L2);
  // Set the learning rate
  solver->setLearningRate(0.001);
  // Compute gradients from the loss and update the parameters
  solver->step(loss);
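
The two momentum coefficients correspond to the standard \beta_1 and \beta_2 of Adam (Kingma & Ba). As a sketch of the usual update rule:

  m_t = \beta_1 m_{t-1} + (1 - \beta_1) \, g_t
  v_t = \beta_2 v_{t-1} + (1 - \beta_2) \, g_t^2
  w_t = w_{t-1} - \eta \, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)

where \hat{m}_t and \hat{v}_t are the bias-corrected moments. Whether MNN applies bias correction and its choice of \epsilon are implementation details worth checking against the source.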

Loss

Currently supported loss functions (you can also implement your own):

  VARP _CrossEntropy(Express::VARP predicts, Express::VARP oneHotTargets);
  VARP _KLDivergence(Express::VARP predicts, Express::VARP oneHotTargets);
  VARP _MSE(Express::VARP predicts, Express::VARP oneHotTargets);
  VARP _MAE(Express::VARP predicts, Express::VARP oneHotTargets);
  VARP _Hinge(Express::VARP predicts, Express::VARP oneHotTargets);
  VARP _DistillLoss(Express::VARP studentLogits, Express::VARP teacherLogits, Express::VARP oneHotTargets,
                    const float temperature, const float alpha);
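
A minimal sketch of wiring a loss into a training step; the model->forward(input) call and the class count of 10 are hypothetical, while _OneHot, _Cast, and _Scalar are MNN Express helpers:

  using namespace MNN::Express;
  // Hypothetical forward pass producing per-class probabilities
  VARP predicts = model->forward(input);
  // Turn integer labels into one-hot targets (assuming 10 classes)
  VARP oneHot = _OneHot(_Cast<int32_t>(labels), _Scalar<int>(10),
                        _Scalar<float>(1.0f), _Scalar<float>(0.0f));
  // Compute the loss and take one optimization step
  VARP loss = _CrossEntropy(predicts, oneHot);
  solver->step(loss);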