为什么要自己写个垃圾库
在CIFAR平台上,以ResNet为基础网络,我做了很多的实验。包括BasicBlock、Bottleneck、宽网络、dropblock、学习率和L2正则化的影响、学习率调控策略、mixup等等等等……
官方的pytorch库,没有在具体使用时提供高层封装,因此每次要做个什么实验就只能魔改resnet.py,把库的py文件改的一塌糊涂,而且在main.py里要写非常非常多的东西,包括用了什么网络、传什么参数、是否要使用tensorboard、学习率调控策略等等。
有的时候跑完了都不记得自己是用的什么config跑的,之后想复现了跑不出原来的结果,折腾很久才发现原来的config写错了(logname写的宽度是64但实际上宽度是96)
于是我就想,能不能自己写个简单的库,把我所学习到的方法都整合进去,使用的时候只需要编写config就能进行训练,完全将内部的库代码封装并做好版本管理。每次实验完的时候,以logname命名记录当前训练的config参数,后期想复现直接复制一份config就可以。
这个库包含了什么
更新到v1.0版本
*目前该库只用于分类网络,不支持目标检测和图像分割
_
每个模块的具体config,请参见config目录下的文件,其中config/config.py给出了一个完整的config的构成及每个key对应的部件。
模型
支持:ResNet、Res2Net、Mnist
待实现:ResNext、DenseNet、MobileNet、ShuffleNet
支持的激活函数:ReLU、LeakyReLU、ReLU6、Sigmoid、Tanh、hard-Swish、Swish、Mish、ELU
其中ResNet与Res2Net支持CIFAR形33232输入,也支持ImageNet形3128128~3320320输入,根据输入大小的不同,模型会自动构建层以适配输入(譬如输入3232,自动将ResNet的conv1层从k7s2改为k3s1并去除maxpool)
其中Mnist只接受128*28的Mnist输入。
优化器
支持:SGD、Adam
完全调用了pytorch的optimizer,只是提供了封装。因为偷懒只写了这两种。
学习率调控
支持:多阶段(砍几刀法)、余弦衰减、余弦周期、余弦退火、指数衰减、loss不降自动砍学习率
训练状态监视
可以指定是否使用TensorBoard writer进行记录,可以指定train与test的记录密度(一个epoch记几次),记录test结果会触发一次test_set的前向推理。
checkpoint
可以指定是否在训练过程中间隔一定epoch数进行checkpoint,可以指定checkpoint的存储内容(如model、optimizer的动量、scheduler的信息等)。
可以指定当test_set的准确度达到一定数值以上时存储当前model。
训练技巧
支持:DropBlock、Mixup、SmoothLabel
待实现:SEBlock
**
DropBlock其实github上面有源码,但是这份源码写的比较垃圾而且和论文中所描述的相距甚远,于是自己重写了一份。可以指定drop_rate,支持线性drop_rate schedule,支持自定义drop block_size,支持自定义是否在通道间共享drop mask,支持自定义DropBlock需要作用于哪些stage。
Mixup其实实现起来非常方便,但是能带来非常巨大的提升。
SmoothLabel支持自定义eps。
使用说明
在顶层目录构建main.py:
import torchvision
from torchvision import datasets, transforms
from src.runner import *
cfg = {...}
transform_train = transforms.Compose([...])
transform_test = transforms.Compose([...])
train_dataset, test_dataset = ...
if __name__ == '__main__':
runner = Runner(cfg, train_dataset, test_dataset)
runner.run() # 进行训练
runner.run(build_from_checkpoint=True, checkpoint_file='...', use_latest=False)
# 从checkpoint继续训练
# runner.train_benchmark() # 进行训练速度测试
# runner.test_benchmark() # 待实现,因为现在没板子,测推理速度没意义
踩得坑
dict怎么写文件
以下代码会爆炸,因为f.write的对象必须是str类型,而我们传了个dict进去。
cfg = {...}
f = open(path, 'w')
f.write(cfg)
f.close()
百度了一下以后,接触到了神奇的传说中的json文件,郑老师之前也说了label和config很多都通过json写,毕竟刚接触到,之后碰到了可以参考。
import json
cfg = {...}
f = open(path, 'w')
f.write(json.dumps(cfg, indent=4, ensure_ascii=False))
f.close()
ResNet的downsample也要加BN
不然直接掉1个点。