Base

_end表示在某个之后运行,*:training_setp、validation_step、training_epoch …

  1. def training_step(self, batch, batch_idx):
  2. ...
  3. return {'loss': loss, 'pred': pred}
  4. def training_step_end(self, batch_parts):
  5. '''
  6. 当gpus=0 or 1时,这里的batch_parts即为traing_step的返回值
  7. 当gpus>1时,这里的batch_parts为list,list中每个为training_step返回值,list[i]为i号gpu的返回值
  8. '''
  9. gpu_0_prediction = batch_parts[0]['pred']
  10. gpu_1_prediction = batch_parts[1]['pred']
  11. # do something with both outputs
  12. return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2
  13. def training_epoch_end(self, training_step_outputs):
  14. '''
  15. 当gpu=0 or 1时,training_step_outputs为list,长度为steps的数量
  16. '''
  17. for out in training_step_outputs:
  18. # do something with preds

Train

基本就是重写training_step,记得至少返回一个loss。如果需要其他操作,比如在step,epoch等完成后的操作,则根据需要添加返回值即可。

  • self.current_epoch可以获取当前的epoch进行一些更精细的操作

    Val

  • 实现了val,就会在训练之前有sainty check确保val可用。

  • 如果需要,可以val设置频率,比如是每多少个training_epoch进行一次,或者每多少个training_step进行一次。 ```python

    每个epoch一次

    trainer = Trainer(check_val_every_n_epoch=1)

每个epoch的 25% step调用val一次,比如1个epoch有100个step

trainer = Trainer(val_check_interval=0.25)

每个epoch中的多少个batch一次

trainer = Trainer(val_check_interval=100) # 每训练100个batch校验一次


- 与Base类似,如果需要对val的结果进行统计,可以在vallidation_epoch_end中对val的返回值进行收集,分析等操作
<a name="Rbvvw"></a>
## Optimizer

- 需要注意的是如果使用schedule,step和pytorch lightning的step不一致,step_size指的是epoch
```python
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.params['lr'], weight_decay= self.params['weight_decay'])
        schedule = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
        return [optimizer], [schedule]