Base
_end表示在某个之后运行,*:training_setp、validation_step、training_epoch …
def training_step(self, batch, batch_idx):
...
return {'loss': loss, 'pred': pred}
def training_step_end(self, batch_parts):
'''
当gpus=0 or 1时,这里的batch_parts即为traing_step的返回值
当gpus>1时,这里的batch_parts为list,list中每个为training_step返回值,list[i]为i号gpu的返回值
'''
gpu_0_prediction = batch_parts[0]['pred']
gpu_1_prediction = batch_parts[1]['pred']
# do something with both outputs
return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2
def training_epoch_end(self, training_step_outputs):
'''
当gpu=0 or 1时,training_step_outputs为list,长度为steps的数量
'''
for out in training_step_outputs:
# do something with preds
Train
基本就是重写training_step,记得至少返回一个loss。如果需要其他操作,比如在step,epoch等完成后的操作,则根据需要添加返回值即可。
self.current_epoch可以获取当前的epoch进行一些更精细的操作
Val
实现了val,就会在训练之前有sainty check确保val可用。
- 如果需要,可以val设置频率,比如是每多少个training_epoch进行一次,或者每多少个training_step进行一次。
```python
每个epoch一次
trainer = Trainer(check_val_every_n_epoch=1)
每个epoch的 25% step调用val一次,比如1个epoch有100个step
trainer = Trainer(val_check_interval=0.25)
每个epoch中的多少个batch一次
trainer = Trainer(val_check_interval=100) # 每训练100个batch校验一次
- 与Base类似,如果需要对val的结果进行统计,可以在vallidation_epoch_end中对val的返回值进行收集,分析等操作
<a name="Rbvvw"></a>
## Optimizer
- 需要注意的是如果使用schedule,step和pytorch lightning的step不一致,step_size指的是epoch
```python
def configure_optimizers(self):
optimizer = Adam(self.parameters(), lr=self.params['lr'], weight_decay= self.params['weight_decay'])
schedule = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
return [optimizer], [schedule]