为什么要自己写个垃圾库

在CIFAR平台上,以ResNet为基础网络,我做了很多的实验。包括BasicBlock、Bottleneck、宽网络、dropblock、学习率和L2正则化的影响、学习率调控策略、mixup等等等等……
官方的pytorch库,没有在具体使用时提供高层封装,因此每次要做个什么实验就只能魔改resnet.py,把库的py文件改的一塌糊涂,而且在main.py里要写非常非常多的东西,包括用了什么网络、传什么参数、是否要使用tensorboard、学习率调控策略等等。
有的时候跑完了都不记得自己是用的什么config跑的,之后想复现了跑不出原来的结果,折腾很久才发现原来的config写错了(logname写的宽度是64但实际上宽度是96)

于是我就想,能不能自己写个简单的库,把我所学习到的方法都整合进去,使用的时候只需要编写config就能进行训练,完全将内部的库代码封装并做好版本管理。每次实验完的时候,以logname命名记录当前训练的config参数,后期想复现直接复制一份config就可以。

这个库包含了什么

更新到v1.0版本
*目前该库只用于分类网络,不支持目标检测和图像分割
_
每个模块的具体config,请参见config目录下的文件,其中config/config.py给出了一个完整的config的构成及每个key对应的部件。

模型

支持:ResNet、Res2Net、Mnist
待实现:ResNext、DenseNet、MobileNet、ShuffleNet
支持的激活函数:ReLU、LeakyReLU、ReLU6、Sigmoid、Tanh、hard-Swish、Swish、Mish、ELU

其中ResNet与Res2Net支持CIFAR形33232输入,也支持ImageNet形3128128~3320320输入,根据输入大小的不同,模型会自动构建层以适配输入(譬如输入3232,自动将ResNet的conv1层从k7s2改为k3s1并去除maxpool)
其中Mnist只接受1
28*28的Mnist输入。

优化器

支持:SGD、Adam
完全调用了pytorch的optimizer,只是提供了封装。因为偷懒只写了这两种。

学习率调控

支持:多阶段(砍几刀法)、余弦衰减、余弦周期、余弦退火、指数衰减、loss不降自动砍学习率

训练状态监视

可以指定是否使用TensorBoard writer进行记录,可以指定train与test的记录密度(一个epoch记几次),记录test结果会触发一次test_set的前向推理。

checkpoint

可以指定是否在训练过程中间隔一定epoch数进行checkpoint,可以指定checkpoint的存储内容(如model、optimizer的动量、scheduler的信息等)。
可以指定当test_set的准确度达到一定数值以上时存储当前model。

训练技巧

支持:DropBlock、Mixup、SmoothLabel
待实现:SEBlock
**
DropBlock其实github上面有源码,但是这份源码写的比较垃圾而且和论文中所描述的相距甚远,于是自己重写了一份。可以指定drop_rate,支持线性drop_rate schedule,支持自定义drop block_size,支持自定义是否在通道间共享drop mask,支持自定义DropBlock需要作用于哪些stage。
Mixup其实实现起来非常方便,但是能带来非常巨大的提升。
SmoothLabel支持自定义eps。

使用说明

在顶层目录构建main.py:

  1. import torchvision
  2. from torchvision import datasets, transforms
  3. from src.runner import *
  4. cfg = {...}
  5. transform_train = transforms.Compose([...])
  6. transform_test = transforms.Compose([...])
  7. train_dataset, test_dataset = ...
  8. if __name__ == '__main__':
  9. runner = Runner(cfg, train_dataset, test_dataset)
  10. runner.run() # 进行训练
  11. runner.run(build_from_checkpoint=True, checkpoint_file='...', use_latest=False)
  12. # 从checkpoint继续训练
  13. # runner.train_benchmark() # 进行训练速度测试
  14. # runner.test_benchmark() # 待实现,因为现在没板子,测推理速度没意义

踩得坑

dict怎么写文件

以下代码会爆炸,因为f.write的对象必须是str类型,而我们传了个dict进去。

  1. cfg = {...}
  2. f = open(path, 'w')
  3. f.write(cfg)
  4. f.close()

百度了一下以后,接触到了神奇的传说中的json文件,郑老师之前也说了label和config很多都通过json写,毕竟刚接触到,之后碰到了可以参考。

  1. import json
  2. cfg = {...}
  3. f = open(path, 'w')
  4. f.write(json.dumps(cfg, indent=4, ensure_ascii=False))
  5. f.close()

ResNet的downsample也要加BN

不然直接掉1个点。