1. PP-OCR系统简介与总览
前两章主要介绍了DBNet文字检测算法以及CRNN文字识别算法。然而对于我们实际场景中的一张图像,想要单独基于文字检测或者识别模型,是无法同时获取文字位置与文字内容的,因此,我们将文字检测算法以及文字识别算法进行串联,构建了PP-OCR文字检测与识别系统。在实际使用过程中,检测出的文字方向可能不是我们期望的方向,最终导致文字识别错误,因此我们在PP-OCR系统中也引入了方向分类器。
本章主要介绍PP-OCR文字检测与识别系统以及该系统中涉及到的优化策略。通过本节课的学习,您可以获得:
- PaddleOCR策略调优技巧
- 文本检测、识别、方向分类器模型的优化技巧和优化方法
PP-OCR系统共经历了2次优化,下面对PP-OCR系统和这2次优化进行简单介绍。
1.1 PP-OCR系统与优化策略简介
PP-OCR中,对于一张图像,如果希望提取其中的文字信息,需要完成以下几个步骤:
- 使用文本检测的方法,获取文本区域多边形信息(PP-OCR中文本检测使用的是DBNet,因此获取的是四点信息)。
- 对上述文本多边形区域进行裁剪与透视变换校正,将文本区域转化成矩形框,再使用方向分类器对方向进行校正。
- 基于包含文字区域的矩形框进行文本识别,得到最终识别结果。
上面便完成了对于一张图像的文本检测与识别过程。
PP-OCR的系统框图如下所示。
PP-OCR系统框图
文本检测基于后处理方案比较简单的DBNet,文字区域校正主要使用几何变换以及方向分类器,文本识别使用了基于融合了卷积特征与序列特征的CRNN模型,使用CTC loss解决预测结果与标签不一致的问题。
PP-OCR从骨干网络、学习率策略、数据增广、模型裁剪量化等方面,共使用了19个策略,对模型进行优化瘦身,最终打造了面向服务器端的PP-OCR server系统以及面向移动端的PP-OCR mobile系统。
1.2 PP-OCRv2系统与优化策略简介
相比于PP-OCR, PP-OCRv2 在骨干网络、数据增广、损失函数这三个方面进行进一步优化,解决端侧预测效率较差、背景复杂以及相似字符的误识等问题,同时引入了知识蒸馏训练策略,进一步提升模型精度。具体地:
- 检测模型优化: (1) 采用 CML 协同互学习知识蒸馏策略;(2) CopyPaste 数据增广策略;
- 识别模型优化: (1) PP-LCNet 轻量级骨干网络;(2) U-DML 改进知识蒸馏策略; (3) Enhanced CTC loss 损失函数改进。
从效果上看,主要有三个方面提升:
- 在模型效果上,相对于 PP-OCR mobile 版本提升超7%;
- 在速度上,相对于 PP-OCR server 版本提升超过220%;
- 在模型大小上,11.6M 的总大小,服务器端和移动端都可以轻松部署。
PP-OCRv2 模型与之前 PP-OCR 系列模型的精度、预测耗时、模型大小对比图如下所示。
PP-OCRv2与PP-OCR的速度、精度、模型大小对比
PP-OCRv2的系统框图如下所示。
PP-OCRv2系统框图
本章将对上述PP-OCR以及PP-OCRv2系统优化策略进行详细的解读。
2. PP-OCR 优化策略
PP-OCR系统包括文本检测器、方向分类器以及文本识别器。本节针对这三个方向的模型优化策略进行详细介绍。
2.1 文本检测
PP-OCR中的文本检测基于DBNet (Differentiable Binarization)模型,它基于分割方案,后处理简单。DBNet的具体模型结构如下图。
DBNet框图
DBNet通过骨干网络(backbone)提取特征,使用DBFPN的结构(neck)对各阶段的特征进行融合,得到融合后的特征。融合后的特征经过卷积等操作(head)进行解码,生成概率图和阈值图,二者融合后计算得到一个近似的二值图。计算损失函数时,对这三个特征图均计算损失函数,这里把二值化的监督也也加入训练过程,从而让模型学习到更准确的边界。
DBNet中使用了6种优化策略用于提升模型精度与速度,包括骨干网络、特征金字塔网络、头部结构、学习率策略、模型裁剪等策略。在验证集上,不同模块的消融实验结论如下所示。
下面进行详细说明。
2.1.1 轻量级骨干网络
骨干网络的大小对文本检测器的模型大小有重要影响。因此,在构建超轻量检测模型时,应选择轻量的骨干网络。随着图像分类技术的发展,MobileNetV1、MobileNetV2、MobileNetV3和ShuffleNetV2系列常用作轻量骨干网络。每个系列都有不同的模型大小和性能表现。PaddeClas提供了20多种轻量级骨干网络。他们在ARM上的精度-速度
曲线如下图所示。
在预测时间相同的情况下,MobileNetV3系列可以实现更高的精度。作者在设计的时候为了覆盖尽可能多的场景,使用scale这个参数来调整特征图通道数,标准为1x,如果是0.5x,则表示该网络中部分特征图通道数为1x对应网络的0.5倍。为了进一步平衡准确率和效率,在V3的尺寸选择上,我们采用了MobileNetV3_large 0.5x的结构。
下面打印出DBNet中MobileNetV3各个阶段的特征图尺寸。
import os
import sys
# 下载代码
os.chdir("/home/aistudio/")
!git clone https://gitee.com/paddlepaddle/PaddleOCR.git
# 切换工作目录
os.chdir("/home/aistudio/PaddleOCR/")
!pip install -U pip
!pip install -r requirements.txt
fatal: destination path 'PaddleOCR' already exists and is not an empty directory.
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting pip
[?25l Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a4/6d/6463d49a933f547439d6b5b98b46af8742cc03ae83543e4d7688c2420f8b/pip-21.3.1-py3-none-any.whl (1.7MB)
[K |████████████████████████████████| 1.7MB 16.9MB/s eta 0:00:01
[?25hInstalling collected packages: pip
Found existing installation: pip 19.2.3
Uninstalling pip-19.2.3:
Successfully uninstalled pip-19.2.3
Successfully installed pip-21.3.1
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting shapely
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ae/20/33ce377bd24d122a4d54e22ae2c445b9b1be8240edb50040b40add950cd9/Shapely-1.8.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)
|████████████████████████████████| 1.1 MB 10.4 MB/s
[?25hCollecting scikit-image
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/9a/44/8f8c7f9c9de7fde70587a656d7df7d056e6f05192a74491f7bc074a724d0/scikit_image-0.19.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (13.3 MB)
|████████████████████████████████| 13.3 MB 12.8 MB/s
[?25hCollecting imgaug==0.4.0
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/66/b1/af3142c4a85cba6da9f4ebb5ff4e21e2616309552caca5e8acefe9840622/imgaug-0.4.0-py2.py3-none-any.whl (948 kB)
|████████████████████████████████| 948 kB 12.5 MB/s
[?25hCollecting pyclipper
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c5/fa/2c294127e4f88967149a68ad5b3e43636e94e3721109572f8f17ab15b772/pyclipper-1.3.0.post2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (603 kB)
|████████████████████████████████| 603 kB 7.0 MB/s
[?25hCollecting lmdb
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/2e/dd/ada2fd91cd7832979069c556607903f274470c3d3d2274e0a848908272e8/lmdb-1.2.1-cp37-cp37m-manylinux2010_x86_64.whl (299 kB)
|████████████████████████████████| 299 kB 14.7 MB/s
[?25hRequirement already satisfied: tqdm in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 6)) (4.27.0)
Requirement already satisfied: numpy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 7)) (1.20.3)
Requirement already satisfied: visualdl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 8)) (2.2.0)
Collecting python-Levenshtein
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/2a/dc/97f2b63ef0fa1fd78dcb7195aca577804f6b2b51e712516cc0e902a9a201/python-Levenshtein-0.12.2.tar.gz (50 kB)
|████████████████████████████████| 50 kB 4.6 MB/s
[?25h Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting opencv-contrib-python==4.4.0.46
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/08/51/1e0a206dd5c70fea91084e6f43979dc13e8eb175760cc7a105083ec3eb68/opencv_contrib_python-4.4.0.46-cp37-cp37m-manylinux2014_x86_64.whl (55.7 MB)
|████████████████████████████████| 55.7 MB 60.7 MB/s
[?25hRequirement already satisfied: cython in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 11)) (0.29)
Collecting lxml
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/7b/01/16a9b80c8ce4339294bb944f08e157dbfcfbb09ba9031bde4ddf7e3e5499/lxml-4.7.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (6.4 MB)
|████████████████████████████████| 6.4 MB 61.9 MB/s
[?25hCollecting premailer
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/b1/07/4e8d94f94c7d41ca5ddf8a9695ad87b888104e2fd41a35546c1dc9ca74ac/premailer-3.10.0-py2.py3-none-any.whl (19 kB)
Requirement already satisfied: openpyxl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 14)) (3.0.5)
Collecting fasttext==0.9.1
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/10/61/2e01f1397ec533756c1d893c22d9d5ed3fce3a6e4af1976e0d86bb13ea97/fasttext-0.9.1.tar.gz (57 kB)
|████████████████████████████████| 57 kB 21.6 MB/s
[?25h Preparing metadata (setup.py) ... [?25ldone
[?25hRequirement already satisfied: imageio in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (2.6.1)
Requirement already satisfied: matplotlib in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (2.2.3)
Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (1.15.0)
Requirement already satisfied: scipy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (1.6.3)
Requirement already satisfied: Pillow in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (7.1.2)
Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (4.1.1.26)
Collecting pybind11>=2.2
Using cached https://pypi.tuna.tsinghua.edu.cn/packages/a8/3b/fc246e1d4c7547a7a07df830128e93c6215e9b93dcb118b2a47a70726153/pybind11-2.8.1-py2.py3-none-any.whl (208 kB)
Requirement already satisfied: setuptools>=0.7.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from fasttext==0.9.1->-r requirements.txt (line 15)) (56.2.0)
Collecting tifffile>=2019.7.26
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/d8/38/85ae5ed77598ca90558c17a2f79ddaba33173b31cf8d8f545d34d9134f0d/tifffile-2021.11.2-py3-none-any.whl (178 kB)
|████████████████████████████████| 178 kB 83.4 MB/s
[?25hRequirement already satisfied: packaging>=20.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->-r requirements.txt (line 2)) (20.9)
Requirement already satisfied: networkx>=2.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->-r requirements.txt (line 2)) (2.4)
Collecting PyWavelets>=1.1.1
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a1/9c/564511b6e1c4e1d835ed2d146670436036960d09339a8fa2921fe42dad08/PyWavelets-1.2.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (6.1 MB)
|████████████████████████████████| 6.1 MB 2.0 MB/s
[?25hRequirement already satisfied: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.1.5)
Requirement already satisfied: flask>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.1.1)
Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.0.0)
Requirement already satisfied: protobuf>=3.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (3.14.0)
Requirement already satisfied: bce-python-sdk in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (0.8.53)
Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.21.0)
Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (2.22.0)
Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (3.8.2)
Requirement already satisfied: shellcheck-py in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (0.7.1.1)
Collecting cssselect
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/3b/d4/3b5c17f00cce85b9a1e6f91096e1cc8e8ede2e1be8e96b87ce1ed09e92c5/cssselect-1.1.0-py2.py3-none-any.whl (16 kB)
Requirement already satisfied: cachetools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 13)) (4.0.0)
Collecting cssutils
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/24/c4/9db28fe567612896d360ab28ad02ee8ae107d0e92a22db39affd3fba6212/cssutils-2.3.0-py3-none-any.whl (404 kB)
|████████████████████████████████| 404 kB 8.3 MB/s
[?25hRequirement already satisfied: jdcal in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->-r requirements.txt (line 14)) (1.4.1)
Requirement already satisfied: et-xmlfile in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->-r requirements.txt (line 14)) (1.0.1)
Requirement already satisfied: importlib-metadata in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (0.23)
Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (0.6.1)
Requirement already satisfied: pyflakes<2.3.0,>=2.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (2.2.0)
Requirement already satisfied: pycodestyle<2.7.0,>=2.6.0a1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (2.6.0)
Requirement already satisfied: click>=5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (7.0)
Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (1.1.0)
Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (2.11.0)
Requirement already satisfied: Werkzeug>=0.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (0.16.0)
Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->-r requirements.txt (line 8)) (2.8.0)
Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->-r requirements.txt (line 8)) (2019.3)
Requirement already satisfied: decorator>=4.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from networkx>=2.2->scikit-image->-r requirements.txt (line 2)) (4.4.2)
Requirement already satisfied: pyparsing>=2.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from packaging>=20.0->scikit-image->-r requirements.txt (line 2)) (2.4.2)
Requirement already satisfied: pycryptodome>=3.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->-r requirements.txt (line 8)) (3.9.9)
Requirement already satisfied: future>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->-r requirements.txt (line 8)) (0.18.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->imgaug==0.4.0->-r requirements.txt (line 3)) (1.1.0)
Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->imgaug==0.4.0->-r requirements.txt (line 3)) (0.10.0)
Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->imgaug==0.4.0->-r requirements.txt (line 3)) (2.8.0)
Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (5.1.2)
Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (0.10.0)
Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (16.7.9)
Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.3.4)
Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.4.10)
Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (2.0.1)
Requirement already satisfied: aspy.yaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.3.0)
Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (2.8)
Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (3.0.4)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (2019.9.11)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (1.25.6)
Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (1.1.1)
Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata->flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (3.6.0)
Building wheels for collected packages: fasttext, python-Levenshtein
Building wheel for fasttext (setup.py) ... [?25ldone
[?25h Created wheel for fasttext: filename=fasttext-0.9.1-cp37-cp37m-linux_x86_64.whl size=2586036 sha256=e4556948fc3c908fe6fae649e62d5617abb5d050d223e30ad395bcb6db4e96fe
Stored in directory: /home/aistudio/.cache/pip/wheels/a1/cb/b3/a25a8ce16c1a4ff102c1e40d6eaa4dfc9d5695b92d57331b36
Building wheel for python-Levenshtein (setup.py) ... [?25ldone
[?25h Created wheel for python-Levenshtein: filename=python_Levenshtein-0.12.2-cp37-cp37m-linux_x86_64.whl size=171683 sha256=eee8be1973b2e1717e541ec1f91155bca78d08512cf9d0f269064dd7b65c00bb
Stored in directory: /home/aistudio/.cache/pip/wheels/38/b9/a4/3729726160fb103833de468adb5ce019b58543ae41d0b0e446
Successfully built fasttext python-Levenshtein
Installing collected packages: tifffile, PyWavelets, shapely, scikit-image, pybind11, lxml, cssutils, cssselect, python-Levenshtein, pyclipper, premailer, opencv-contrib-python, lmdb, imgaug, fasttext
Successfully installed PyWavelets-1.2.0 cssselect-1.1.0 cssutils-2.3.0 fasttext-0.9.1 imgaug-0.4.0 lmdb-1.2.1 lxml-4.7.1 opencv-contrib-python-4.4.0.46 premailer-3.10.0 pybind11-2.8.1 pyclipper-1.3.0.post2 python-Levenshtein-0.12.2 scikit-image-0.19.1 shapely-1.8.0 tifffile-2021.11.2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 具体代码实现位于:
# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/modeling/backbones/det_mobilenet_v3.py
import numpy as np
import paddle
# 设置随机输入
inputs = np.random.rand(1, 3, 640, 640).astype(np.float32)
x = paddle.to_tensor(inputs)
# 导入MobileNetV3库
from ppocr.modeling.backbones.det_mobilenet_v3 import MobileNetV3
# 模型定义
backbone_mv3 = MobileNetV3(scale=0.5, model_name='large')
# 模型forward
bk_out = backbone_mv3(x)
# 模型中间层打印
for i, stage_out in enumerate(bk_out):
print("the shape of ",i,'stage: ',stage_out.shape)
1
2
3
4
the shape of 0 stage: [1, 16, 160, 160]
the shape of 1 stage: [1, 24, 80, 80]
the shape of 2 stage: [1, 56, 40, 40]
the shape of 3 stage: [1, 480, 20, 20]
2.1.2 轻量级特征金字塔网络DBFPN结构
文本检测器的特征融合(neck)部分DBFPN与目标检测任务中的FPN结构类似,融合不同尺度的特征图,以提升不同尺度的文本区域检测效果。
为了方便合并不同通道的特征图,这里使用1×1
的卷积将特征图减少到相同数量的通道。
概率图和阈值图是由卷积融合的特征图生成的,卷积也与inner_channels相关联。因此,inner_channels对模型尺寸有很大的影响。当inner_channels由256减小到96时,模型尺寸由7M减小到4.1M,速度提升48%,但精度只是略有下降。
下面打印DBFPN的结构以及对于骨干网络特征图的融合结果。
1
2
3
4
5
6
7
8
9
10
11
12
# 具体代码实现位于:
# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/modeling/necks/db_fpn.py
from ppocr.modeling.necks.db_fpn import DBFPN
neck_bdfpn = DBFPN(in_channels=[16, 24, 56, 480], out_channels=96)
# 打印 DBFPN结构
print(neck_bdfpn)
# 先对原始的通道数降到96,再降到24,最后4个feature map进行concat
fpn_out = neck_bdfpn(bk_out)
print('the shape of output of DBFPN: ', fpn_out.shape)
1
2
3
4
5
6
7
8
9
10
11
DBFPN(
(in2_conv): Conv2D(16, 96, kernel_size=[1, 1], data_format=NCHW)
(in3_conv): Conv2D(24, 96, kernel_size=[1, 1], data_format=NCHW)
(in4_conv): Conv2D(56, 96, kernel_size=[1, 1], data_format=NCHW)
(in5_conv): Conv2D(480, 96, kernel_size=[1, 1], data_format=NCHW)
(p5_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
(p4_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
(p3_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
(p2_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
)
the shape of output of DBFPN: [1, 96, 160, 160]
2.1.3 骨干网络中SE模块分析
SE是squeeze-and-excitation
的缩写(Hu, Shen, and Sun 2018)。如图所示
SE块显式地建模通道之间的相互依赖关系,并自适应地重新校准通道特征响应。在网络中使用SE块可以明显提高视觉任务的准确性,因此MobileNetV3的搜索空间包含了SE模块,最终MobileNetV3中也包含很多个SE模块。然而,当输入分辨率较大时,例如640×640
,使用SE模块较难估计通道的特征响应,精度提高有限,但SE模块的时间成本非常高。在DBNet中,我们将SE模块从骨干网络中移除,模型大小从4.1M
降到2.6M
,但精度没有影响。
PaddleOCR中可以通过设置disable_se=True
来移除骨干网络中的SE模块,使用方法如下所示。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 具体代码实现位于:
# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/modeling/backbones/det_mobilenet_v3.py
x = paddle.rand([1, 3, 640, 640])
from ppocr.modeling.backbones.det_mobilenet_v3 import MobileNetV3
# 定义模型
backbone_mv3 = MobileNetV3(scale=0.5, model_name='large', disable_se=True)
# 模型forward
bk_out = backbone_mv3(x)
# 输出
for i, stage_out in enumerate(bk_out):
print("the shape of ",i,'stage: ',stage_out.shape)
1
2
3
4
the shape of 0 stage: [1, 16, 160, 160]
the shape of 1 stage: [1, 24, 80, 80]
the shape of 2 stage: [1, 56, 40, 40]
the shape of 3 stage: [1, 480, 20, 20]
2.1.4 学习率策略优化
- Cosine 学习率下降策略
梯度下降算法需要我们设置一个值,用来控制权重更新幅度,我们将其称之为学习率。它是控制模型学习速度的超参数。学习率越小,loss的变化越慢。虽然使用较低的学习速率可以确保不会错过任何局部极小值,但这也意味着模型收敛速度较慢。
因此,在训练前期,权重处于随机初始化状态,我们可以设置一个相对较大的学习速率以加快收敛速度。在训练后期,权重接近最优值,使用相对较小的学习率可以防止模型在收敛的过程中发生震荡。
Cosine学习率策略也就应运而生,Cosine学习率策略指的是学习率在训练的过程中,按照余弦的曲线变化。在整个训练过程中,Cosine学习率衰减策略使得在网络在训练初期保持了较大的学习速率,在后期学习率会逐渐衰减至0,其收敛速度相对较慢,但最终收敛精度较好。下图比较了两种不同的学习率衰减策略piecewise decay
和cosine decay
。
- 学习率预热策略
模型刚开始训练时,模型权重是随机初始化的,此时若选择一个较大的学习率,可能造成模型训练不稳定的问题,因此学习率预热的概念被提出,用于解决模型训练初期不收敛的问题。
学习率预热指的是将学习率从一个很小的值开始,逐步增加到初始较大的学习率。它可以保证模型在训练初期的稳定性。使用学习率预热策略有助于提高图像分类任务的准确性。在DBNet中,实验表明该策略也是有效的。学习率预热策略与Cosine学习率结合时,学习率的变化趋势如下代码演示。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# 具体代码实现位于
# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/optimizer/__init__.py
# 导入学习率优化器构建的函数
from ppocr.optimizer import build_lr_scheduler
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# 咱们也可以看看warmup_epoch为2时的效果
lr_config = {'name': 'Cosine', 'learning_rate': 0.1, 'warmup_epoch': 0}
epochs = 20 # config['Global']['epoch_num']
iters_epoch = 100 # len(train_dataloader)
lr_scheduler=build_lr_scheduler(lr_config, epochs, iters_epoch)
iters = 0
lr = []
for epoch in range(epochs):
for _ in range(iters_epoch):
lr_scheduler.step() # 对应 https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/tools/program.py#L262
iters += 1
lr.append(lr_scheduler.get_lr())
x = np.arange(iters,dtype=np.int64)
y = np.array(lr,dtype=np.float64)
plt.figure(figsize=(15, 6))
plt.plot(x,y,color='red',label='lr')
plt.title(u'Cosine lr scheduler with Warmup')
plt.xlabel(u'iters')
plt.ylabel(u'lr')
plt.legend()
plt.show()
1
2
3
4
5
6
7
8
9
10
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import Sized
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
return list(data) if isinstance(data, collections.MappingView) else data
2.1.5 模型裁剪策略-FPGM
深度学习模型中一般有比较多的参数冗余,我们可以使用一些方法,去除模型中比较冗余的地方,从而提升模型推理效率。
模型裁剪指的是通过去除网络中冗余的通道(channel)、滤波器(filter)、神经元(neuron)等,来得到一个更轻量的网络,同时尽可能保证模型精度。
相比于裁剪通道或者特征图的方法,裁剪滤波器的方法可以得到更加规则的模型,因此减少内存消耗,加速模型推理过程。
之前的裁剪滤波器的方法大多基于范数进行裁剪,即,认为范数较小的滤波器重要程度较小,但是这种方法要求存在的滤波器的最小范数应该趋近于0,否则我们难以去除。
针对上面的问题,基于几何中心点的裁剪算法(Filter Pruning via Geometric Median, FPGM)被提出。FPGM将卷积层中的每个滤波器都作为欧几里德空间中的一个点,它引入了几何中位数这样一个概念,即与所有采样点距离之和最小的点。如果一个滤波器的接近这个几何中位数,那我们可以认为这个滤波器的信息和其他滤波器重合,可以去掉。
FPGM与基于范数的裁剪算法的对比如下图所示。
在PP-OCR中,我们使用FPGM对检测模型进行剪枝,最终DBNet的模型精度只有轻微下降,但是模型大小减小46%,预测速度加速19%。
关于FPGM模型裁剪实现的更多细节可以参考PaddleSlim。
注意:
- 模型裁剪需要重新训练模型,可以参考PaddleOCR剪枝教程。
- 裁剪代码是根据DBNet进行适配,如果您需要对自己的模型进行剪枝,需要重新分析模型结构、参数的敏感度,我们通常情况下只建议裁剪相对敏感度低的参数,而跳过敏感度高的参数。
- 每个卷积层的剪枝率对于裁剪后模型的性能也很重要,用完全相同的裁剪率去进行模型裁剪通常会导致显着的性能下降。
- 模型裁剪不是一蹴而就的,需要进行反复的实验,才能得到符合要求的模型。
2.1.6 文本检测配置说明
下面给出DBNet的训练配置简要说明,完整的配置文件可以参考:ch_det_mv3_db_v2.0.yml。
Architecture: # 模型结构定义
model_type: det
algorithm: DB
Transform:
Backbone:
name: MobileNetV3 # 配置骨干网络
scale: 0.5
model_name: large
disable_se: True # 去除SE模块
Neck:
name: DBFPN # 配置DBFPN
out_channels: 96 # 配置 inner_channels
Head:
name: DBHead
k: 50
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine # 配置cosine学习率下降策略
learning_rate: 0.001 # 初始学习率
warmup_epoch: 2 # 配置学习率预热策略
regularizer:
name: 'L2' # 配置L2正则
factor: 0 # 正则项的权重
2.1.7 PP-OCR 检测优化总结
上面给大家介绍了PP-OCR中文字检测算法的优化策略,这里再给大家回顾一下不同优化策略对应的消融实验与结论。
通过轻量级骨干网络、轻量级neck结构、SE模块的分析和去除、学习率调整及优化、模型裁剪等策略,DBNet的模型大小从7M减少至1.5M。通过学习率策略优化等训练策略优化,DBNet的模型精度提升超过1%。
PP-OCR中,超轻量DBNet检测效果如下所示:
下面展示快速使用文字检测模型的预测效果。具体的预测推理代码,我们在第五章会进行详细说明。
1
2
3
4
5
6
7
8
9
10
11
!mkdir inference
!cd inference && wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar -O ch_PP-OCRv2_det_infer.tar && tar -xf ch_PP-OCRv2_det_infer.tar
!python tools/infer/predict_det.py --image_dir="./doc/imgs/00111002.jpg" --det_model_dir="./inference/ch_PP-OCRv2_det_infer" --use_gpu=False
from PIL import Image
img_det = Image.open('./inference_results/det_res_00111002.jpg')
plt.figure(figsize=(14, 10)) # 图像窗口大小
plt.imshow(img_det)
plt.axis('on')
plt.title('Detection')
plt.show()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
mkdir: cannot create directory ‘inference’: File exists
--2021-12-24 18:59:56-- https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar
Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.195, 182.61.200.229, 2409:8c04:1001:1002:0:ff:b001:368a
Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.195|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3190272 (3.0M) [application/x-tar]
Saving to: ‘ch_PP-OCRv2_det_infer.tar’
ch_PP-OCRv2_det_inf 100%[===================>] 3.04M --.-KB/s in 0.1s
2021-12-24 18:59:57 (26.3 MB/s) - ‘ch_PP-OCRv2_det_infer.tar’ saved [3190272/3190272]
[2021/12/24 19:00:01] root INFO: 00111002.jpg [[[78, 641], [408, 638], [408, 659], [78, 662]], [[76, 614], [214, 614], [214, 635], [76, 635]], [[103, 554], [150, 554], [150, 576], [103, 576]], [[74, 531], [349, 531], [349, 551], [74, 551]], [[75, 503], [310, 499], [311, 523], [75, 527]], [[162, 462], [320, 462], [320, 495], [162, 495]], [[326, 432], [415, 432], [415, 453], [326, 453]], [[306, 409], [429, 407], [430, 428], [306, 430]], [[74, 411], [212, 406], [213, 426], [75, 431]], [[74, 384], [219, 382], [219, 403], [74, 405]], [[309, 381], [429, 381], [429, 402], [309, 402]], [[74, 362], [201, 359], [201, 380], [75, 383]], [[304, 358], [426, 358], [426, 378], [304, 378]], [[70, 336], [242, 332], [242, 356], [71, 359]], [[72, 312], [206, 307], [206, 328], [73, 333]], [[304, 308], [419, 308], [419, 329], [304, 329]], [[114, 271], [249, 271], [249, 302], [114, 302]], [[363, 270], [383, 270], [383, 297], [363, 297]], [[68, 248], [246, 246], [246, 269], [69, 271]], [[65, 218], [188, 218], [188, 242], [65, 242]], [[337, 215], [384, 215], [384, 241], [337, 241]], [[67, 196], [248, 196], [248, 216], [67, 216]], [[296, 196], [424, 190], [425, 211], [296, 217]], [[65, 167], [245, 167], [245, 188], [65, 188]], [[67, 138], [290, 138], [290, 159], [67, 159]], [[68, 112], [411, 112], [411, 132], [68, 132]], [[278, 86], [417, 86], [417, 107], [278, 107]], [[167, 60], [412, 61], [412, 74], [167, 73]], [[165, 17], [412, 16], [412, 51], [165, 52]], [[7, 6], [61, 6], [61, 24], [7, 24]]]
[2021/12/24 19:00:01] root INFO: The predict time of ./doc/imgs/00111002.jpg: 2.0092978477478027
[2021/12/24 19:00:01] root INFO: The visualized image saved in ./inference_results/det_res_00111002.jpg
2.2 方向分类器
方向分类器的任务是用于分类出文本检测出的文本实例的方向,将文本旋转到0度之后,再送入后续的文本识别器中。PP-OCR中,我们考虑了0度和180度2个方向。下面详细介绍针对方向分类器的速度、精度优化策略。
2.2.1 轻量级骨干网络
与文本检测器相同,我们仍然采用MobileNetV3作为方向分类器的骨干网络。因为方向分类的任务相对简单,我们使用MobileNetV3 small 0.35x来平衡模型精度与预测效率。实验表明,即使当使用更大的骨干时,精度不会有进一步的提升。
2.2.2 数据增强
数据增强指的是对图像变换,送入网络进行训练,它可以提升网络的泛化性能。常用的数据增强包括旋转、透视失真变换、运动模糊变换和高斯噪声变换等,PP-OCR中,我们统称这些数据增强方法为BDA(Base Data Augmentation)。结果表明,BDA可以明显提升方向分类器的精度。
下面展示一些BDA数据增广方法的效果
除了BDA外,我们还加入了一些更高阶的数据增强操作来提高分类的效果,例如 AutoAugment (Cubuk et al. 2019), RandAugment (Cubuk et al. 2020), CutOut (DeVries and Taylor 2017), RandErasing (Zhong et al. 2020), HideAndSeek (Singh and Lee 2017), GridMask (Chen 2020), Mixup (Zhang et al. 2017) 和 Cutmix (Yun et al. 2019)。
这些数据增广大体分为3个类别:
(1)图像变换类:AutoAugment、RandAugment
(2)图像裁剪类:CutOut、RandErasing、HideAndSeek、GridMask
(3)图像混叠类:Mixup、Cutmix
下面给出不同高阶数据增广的可视化对比结果。
dErasing 外,大多数方法都不适用于方向分类器。下图也给出了在不同数据增强策略下,模型精度的变化。
最终,我们在训练时结合BDA和RandAugment,作为方向分类器的数据增强策略。
- RandAugment代码演示
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# 参考代码:
# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/data/imaug/__init__.py
import random
from PIL import Image
from ppocr.data.imaug import DecodeImage, RandAugment, transform
np.random.seed(1)
random.seed(1)
img = Image.open('./doc/imgs_words/ch/word_4.jpg')
# 绘制原图
plt.figure("Image1") # 图像窗口名称
plt.imshow(img)
plt.axis('on') # 关掉坐标轴为 off
plt.title('Before RandAugment') # 图像题目
plt.show()
data = {'image':None}
with open('./doc/imgs_words/ch/word_4.jpg', 'rb') as f:
img = f.read()
data['image'] = img
# 定义变换算子
ops_list = [DecodeImage(), RandAugment()]
# 数据变换
data = transform(data,ops_list)
img_auged = data['image']
# 显示
img_auged = Image.fromarray(img_auged, 'RGB')
plt.figure("Image") # 图像窗口名称
plt.imshow(img_auged)
plt.axis('on') # 关掉坐标轴为 off
plt.title('After RandAugment') # 图像标题
plt.show()
下面展示快速使用方向分类器模型的预测效果。具体的预测推理代码,我们在第五章会进行详细说明。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 参考代码:
# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/tools/infer/predict_cls.py
!cd inference && wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar -O ch_ppocr_mobile_v2.0_cls_infer.tar && tar -xf ch_ppocr_mobile_v2.0_cls_infer.tar
# 方向分类器分类
!python tools/infer/predict_cls.py --image_dir="./doc/imgs_words/ch/word_1.jpg" --cls_model_dir="./inference/ch_ppocr_mobile_v2.0_cls_infer" --use_gpu=False
# 读入图像
import cv2
img = cv2.imread("./doc/imgs_words/ch/word_1.jpg")
plt.imshow(img[:,:,::-1])
plt.show()
# 旋转180度
img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
cv2.imwrite("./test.png", img)
# 对旋转后图像使用方向分类器进行分类
!python tools/infer/predict_cls.py --image_dir="./test.png" --cls_model_dir="./inference/ch_ppocr_mobile_v2.0_cls_infer" --use_gpu=False
plt.imshow(img[:,:,::-1])
plt.show()
1
2
3
4
5
6
7
8
9
10
11
12
--2021-12-24 19:00:05-- https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.195, 182.61.200.229, 2409:8c04:1001:1002:0:ff:b001:368a
Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.195|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1454080 (1.4M) [application/x-tar]
Saving to: ‘ch_ppocr_mobile_v2.0_cls_infer.tar’
ch_ppocr_mobile_v2. 100%[===================>] 1.39M --.-KB/s in 0.1s
2021-12-24 19:00:05 (12.4 MB/s) - ‘ch_ppocr_mobile_v2.0_cls_infer.tar’ saved [1454080/1454080]
[2021/12/24 19:00:08] root INFO: Predicts of ./doc/imgs_words/ch/word_1.jpg:['0', 0.9998784]
[2021/12/24 19:00:11] root INFO: Predicts of ./test.png:['180', 0.9999759]
2.2.3 输入分辨率优化
一般来说,当图像的输入分辨率提高时,精度也会提高。由于方向分类器的骨干网络参数量很小,即使提高了分辨率也不会导致推理时间的明显增加。我们将方向分类器的输入图像尺度从3x32x100
增加到3x48x192
,方向分类器的精度从92.1%
提升至94.0%
,但是预测耗时仅仅从3.19ms
提升至3.21ms
。
2.2.4 模型量化策略-PACT
模型量化是一种将浮点计算转成低比特定点计算的技术,可以使神经网络模型具有更低的延迟、更小的体积以及更低的计算功耗。
模型量化主要分为离线量化和在线量化。其中,离线量化是指一种利用KL散度等方法来确定量化参数的定点量化方法,量化后不需要再次训练;在线量化是指在训练过程中确定量化参数,相比离线量化模式,它的精度损失更小。
PACT(PArameterized Clipping acTivation)是一种新的在线量化方法,可以提前从激活层中去除一些极端值。在去除极端值后,模型可以学习更合适的量化参数。普通PACT方法的激活值的预处理是基于RELU函数的,公式如下:
%3D0.5(%7Cx%7C-%7Cx-%5Calpha%7C%2B%5Calpha)%3D%5Cleft%5C%7B%5Cbegin%7Barray%7D%7Bcc%7D%0A0%20%26%20x%20%5Cin(-%5Cinfty%2C%200)%20%5C%5C%0Ax%20%26%20x%20%5Cin%5B0%2C%20%5Calpha)%20%5C%5C%0A%5Calpha%20%26%20x%20%5Cin%5B%5Calpha%2C%2B%5Cinfty)%0A%5Cend%7Barray%7D%5Cright.%0A#card=math&code=y%3DP%20A%20C%20T%28x%29%3D0.5%28%7Cx%7C-%7Cx-%5Calpha%7C%2B%5Calpha%29%3D%5Cleft%5C%7B%5Cbegin%7Barray%7D%7Bcc%7D%0A0%20%26%20x%20%5Cin%28-%5Cinfty%2C%200%29%20%5C%5C%0Ax%20%26%20x%20%5Cin%5B0%2C%20%5Calpha%29%20%5C%5C%0A%5Calpha%20%26%20x%20%5Cin%5B%5Calpha%2C%2B%5Cinfty%29%0A%5Cend%7Barray%7D%5Cright.%0A&id=FRURJ)
所有大于特定阈值的激活值都会被重置为一个常数。然而,MobileNetV3中的激活函数不仅是ReLU,还包括hardswish。因此使用普通的PACT量化会导致更高的精度损失。因此,为了减少量化损失,我们将激活函数的公式修改为:
%3D%5Cleft%5C%7B%5Cbegin%7Barray%7D%7Brl%7D%0A-%5Calpha%20%26%20x%20%5Cin(-%5Cinfty%2C-%5Calpha)%20%5C%5C%0Ax%20%26%20x%20%5Cin%5B-%5Calpha%2C%20%5Calpha)%20%5C%5C%0A%5Calpha%20%26%20x%20%5Cin%5B%5Calpha%2C%2B%5Cinfty)%0A%5Cend%7Barray%7D%5Cright.%0A#card=math&code=y%3DP%20A%20C%20T%28x%29%3D%5Cleft%5C%7B%5Cbegin%7Barray%7D%7Brl%7D%0A-%5Calpha%20%26%20x%20%5Cin%28-%5Cinfty%2C-%5Calpha%29%20%5C%5C%0Ax%20%26%20x%20%5Cin%5B-%5Calpha%2C%20%5Calpha%29%20%5C%5C%0A%5Calpha%20%26%20x%20%5Cin%5B%5Calpha%2C%2B%5Cinfty%29%0A%5Cend%7Barray%7D%5Cright.%0A&id=CjKnq)
PaddleOCR中提供了适用于PP-OCR套件的量化脚本。具体链接可以参考PaddleOCR模型量化教程。
2.2.5 方向分类器配置说明
训练方向分类器时,配置文件中的部分关键字段和说明如下所示。完整配置文件可以参考cls_mv3.yml。
Architecture:
model_type: cls
algorithm: CLS
Transform:
Backbone:
name: MobileNetV3 # 配置分类模型为MobileNetV3
scale: 0.35
model_name: small
Neck:
Head:
name: ClsHead
class_dim: 2
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/cls
label_file_list:
- ./train_data/cls/train.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- ClsLabelEncode: # Class handling label
- RecAug:
use_tia: False # 配置BDA数据增强,不使用TIA数据增强
- RandAugment: # 配置随机增强数据增强方法
- ClsResizeImg:
image_shape: [3, 48, 192] # 这里将[3, 32, 100]修改为[3, 48, 192],进行输入分辨率优化
- KeepKeys:
keep_keys: ['image', 'label'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 512
drop_last: True
num_workers: 8
2.2.5 方向分类器实验总结
在方向分类器模型优化中,我们使用轻量化骨干网络以及模型量化,最终将模型从0.85M降低到了0.46M,使用组合数据增广、高分辨率等特征,最终将模型精度提升了超过2%。消融实验对比如下所示。
2.3 文本识别
PP-OCR中,文本识别器使用的是CRNN模型。训练的时候使用CTC loss去解决不定长文本的预测问题。
CRNN模型结构如下所示。
CRNN结构图
PP-OCR针对文本识别器,从骨干网络、头部结构优化、数据增强、正则化策略、特征图下采样策略、量化等多个角度进行模型优化,具体消融实验如下所示。
CRNN识别模型消融实验
下面详细介绍文本识别模型的具体优化策略。
2.3.1 轻量级骨干网络和头部结构
- 轻量级骨干网络
在文本识别中,仍然采用了与文本检测相同的MobileNetV3作为backbone。选自MobileNetV3_small_x0.5进一步地平衡精度和效率。如果不要求模型大小的话,可以选择MobileNetV3_small_x1,模型大小仅增加5M,精度明显提高。
不同骨干网络下的识别模型精度对比
- 轻量级头部结构
CRNN中,用于解码的轻量级头(head)是一个全连接层,用于将序列特征解码为普通的预测字符。序列特征的维数对文本识别器的模型大小影响非常大,特别是对于6000多个字符的中文识别场景(序列特征维度若设置为256,则仅仅是head部分的模型大小就为6.7M)。在PP-OCR中,我们针对序列特征的维度展开实验,最终将其设置为48,平衡了精度与效率。部分消融实验结论如下。
不同序列特征维度的精度对比
2.3.2 数据增强
除了前面提到的经常用于文本识别的BDA(基本数据增强),TIA(Luo等人,2020)也是一种有效的文本识别数据增强方法。TIA是一种针对场景文字的数据增强方法,它在图像中设置了多个基准点,然后随机移动点,通过几何变换生成新图像,这样大大提升了数据的多样性以及模型的泛化能力。TIA的基本流程图如图所示:
实验证明,使用TIA数据增广,可以帮助文本识别模型的精度在一个极高的baseline上面进一步提升0.9%。
下面是TIA中三种涉及到的数据增广的可视化效果图。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 参考代码:
# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/data/imaug/text_image_aug/augment.py
import cv2
from ppocr.data.imaug.rec_img_aug import tia_distort, tia_stretch, tia_perspective
img = cv2.imread("./doc/imgs_words/ch/word_1.jpg")
img_out1 = tia_distort(img, 2.5)
img_out2 = tia_stretch(img, 3)
img_out3 = tia_perspective(img)
plt.figure(figsize=(20, 8))
plt.subplot(1,4,1)
plt.imshow(img[:,:,::-1])
plt.subplot(1,4,2)
plt.imshow(img_out1[:,:,::-1])
plt.subplot(1,4,3)
plt.imshow(img_out2[:,:,::-1])
plt.subplot(1,4,4)
plt.imshow(img_out3[:,:,::-1])
plt.show()
2.3.3 学习率策略和正则化
在识别模型训练中,学习率下降策略与文本检测相同,也使用了Cosine+Warmup的学习率策略。
正则化是一种广泛使用的避免过度拟合的方法,一般包含L1正则化和L2正则化。在大多数使用场景中,我们都使用L2正则化。它主要的原理就是计算网络中权重的L2范数,添加到损失函数中。在L2正则化的帮助下,网络的权重趋向于选择一个较小的值,最终整个网络中的参数趋向于0,从而缓解模型的过拟合问题,提高了模型的泛化性能。
我们实验发现,对于文本识别,L2正则化对识别准确率有很大的影响。
CRNN识别模型消融实验
2.3.4 特征图降采样策略
我们在做检测、分割、OCR等下游视觉任务时,骨干网络一般都是使用的图像分类任务中的骨干网络,它的输入分辨率一般设置为224x224,降采样时,一般宽度和高度会同时降采样。
但是对于文本识别任务来说,由于输入图像一般是32x100,长宽比非常不平衡,此时对宽度和高度同时降采样,会导致特征损失严重,因此图像分类任务中的骨干网络应用到文本识别任务中需要进行特征图降采样方面的适配(如果大家自己换骨干网络的话,这里也需要注意一下)。
在PaddleOCR中,CRNN中文文本识别模型设置的输入图像的高度和宽度设置为32和320。原始MobileNetV3来自分类模型,如前文所述,需要调整降采样的步长,适配文本图像输入分辨率。具体地,为了保留更多的水平信息,我们将下采样特征图的步长从 (2,2) 修改为 (2,1) ,第一次下采样除外。最终如下图所示。
降采样步长策略优化可视化
为了保留更多的垂直信息,我们进一步将第二次下采样特征图的步长从 (2,1) 修改为 (1,1)。因此,第二个下采样特征图的步长s2会显著影响整个特征图的分辨率和文本识别器的准确性。在PP-OCR中,s2被设置为(1,1),可以获得更好的性能。同时,由于水平的分辨率增加,CPU的推理时间从11.84ms
增加到 12.96ms
。
下面给出了stride优化前后的特征图尺度对比。虽然最终输出特征图尺度相同,但是stride从(2,1)修改为(1,1)之后,特征信息在编码的过程中被保留得更为完整。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 参考代码:
# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/modeling/backbones/rec_mobilenet_v3.py
from ppocr.modeling.backbones.rec_mobilenet_v3 import MobileNetV3
mv3_ori = MobileNetV3(model_name="small", scale=0.5, small_stride=[2,2,2,2])
mv3_new = MobileNetV3(model_name="small", scale=0.5, small_stride=[1,2,2,2])
x = paddle.rand([1, 3, 32, 320])
y_ori = mv3_ori(x)
y_new = mv3_new(x)
print(y_ori.shape)
print(y_new.shape)
1
2
[1, 288, 1, 80]
[1, 288, 1, 80]
2.3.5 PACT 在线量化策略
我们采用与方向分类器量化类似的方案来减小文本识别器的模型大小。由于LSTM量化的复杂性,PP-OCR中没有对LSTM进行量化。使用该量化策略之后,模型大小减小67.4%
、预测速度加速8%
、准确率提升1.6%
,量化可以减少模型冗余,增强模型的表达能力。
模型量化消融实验
2.3.6 文字识别预训练模型
使用合适的预训练模型可以加快模型的收敛速度。在真实场景中,用于文本识别的数据通常是有限的。PP-OCR中,我们合成了千万级别的数据,对模型进行训练,之后再基于该模型,在真实数据上微调,最终识别准确率从从65.81%
提升到69%
。
2.3.7 文本识别配置说明
下面给出CRNN的训练配置简要说明,完整的配置文件可以参考:rec_chinese_lite_train_v2.0.yml。
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine # 配置Cosine 学习率下降策略
learning_rate: 0.001
warmup_epoch: 5 # 配置预热学习率
regularizer:
name: 'L2' # 配置L2正则
factor: 0.00001
Architecture:
model_type: rec
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV3 # 配置Backbone
scale: 0.5
model_name: small
small_stride: [1, 2, 2, 2] # 配置下采样的stride
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 48 # 配置最后一层全连接层的维度
Head:
name: CTCHead
fc_decay: 0.00001
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
label_file_list: ["./train_data/train_list.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- RecAug: # 配置数据增强BDA和TIA,TIA默认使用
- CTCLabelEncode: # Class handling label
- RecResizeImg:
image_shape: [3, 32, 320]
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 256
drop_last: True
num_workers: 8
2.3.8 识别优化小结
在模型体积方面,PP-OCR使用轻量级骨干网络、序列维度裁剪、模型量化的策略,将模型大小从4.5M减小至1.6M。在精度方面,使用TIA数据增强、Cosine-warmup学习率策略、L2正则、特征图分辨率改进、预训练模型等优化策略,最终在验证集上提升15.4%
。
PP-OCR中部分识别效果如下所示。
文本识别模型的代码演示如下。
1
2
3
4
5
6
7
# 可视化原图
img = cv2.imread("./doc/imgs_words/ch/word_4.jpg")
plt.imshow(img[..., ::-1])
plt.show()
!cd inference && wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar -O ch_PP-OCRv2_rec_infer.tar && tar -xf ch_PP-OCRv2_rec_infer.tar
!python tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --rec_model_dir="./inference/ch_PP-OCRv2_rec_infer" --use_gpu=False
1
2
3
4
5
6
7
8
9
10
11
--2021-12-24 19:00:26-- https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar
Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.195, 182.61.200.229, 2409:8c04:1001:1002:0:ff:b001:368a
Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.195|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 8875520 (8.5M) [application/x-tar]
Saving to: ‘ch_PP-OCRv2_rec_infer.tar’
ch_PP-OCRv2_rec_inf 100%[===================>] 8.46M 39.0MB/s in 0.2s
2021-12-24 19:00:27 (39.0 MB/s) - ‘ch_PP-OCRv2_rec_infer.tar’ saved [8875520/8875520]
[2021/12/24 19:00:29] root INFO: Predicts of ./doc/imgs_words/ch/word_4.jpg:('实力活力', 0.9409585)
1
2
3
4
5
6
7
# 对 ./doc/imgs_words/ch/word_1.jpg 旋转180度得到
img = cv2.imread("./test.png")
plt.imshow(img[:,:,::-1])
plt.show()
!python tools/infer/predict_rec.py --image_dir="./test.png" --rec_model_dir="./inference/ch_PP-OCRv2_rec_infer" --use_gpu=False
1
2
3
4
5
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3420: RuntimeWarning: Mean of empty slice.
out=out, **kwargs)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/numpy/core/_methods.py:188: RuntimeWarning: invalid value encountered in double_scalars
ret = ret.dtype.type(ret / rcount)
[2021/12/24 19:00:32] root INFO: Predicts of ./test.png:('', nan)
3. PP-OCRv2优化策略解读
第2节的内容主要是对PP-OCR以及它的19个优化策略进行了详细介绍。
相比于PP-OCR, PP-OCRv2 在骨干网络、数据增广、损失函数这三个方面进行进一步优化,解决端侧预测效率较差、背景复杂以及相似字符的误识等问题,同时引入了知识蒸馏训练策略,进一步提升模型精度。具体地:
- 检测模型优化: (1) 采用 CML 协同互学习知识蒸馏策略;(2) CopyPaste 数据增广策略;
- 识别模型优化: (1) PP-LCNet 轻量级骨干网络;(2) U-DML 改进知识蒸馏策略; (3) Enhanced CTC loss 损失函数改进。
本节主要基于文字检测和识别模型的优化过程,去解读PP-OCRv2的优化策略。
3.1 文字检测模型优化详解
文字检测模型优化过程中,采用 CML 协同互学习知识蒸馏以及 CopyPaste 数据增广策略;最终将文字检测模型在大小不变的情况下,Hmean从 0.759 提升至 0.795,具体消融实验如下所示。
PP-OCRv2检测模型消融实验
3.1.1 CML知识蒸馏策略
知识蒸馏的方法在部署中非常常用,通过使用大模型指导小模型学习的方式,在通常情况下可以使得小模型在预测耗时不变的情况下,精度得到进一步的提升,从而进一步提升实际部署的体验。
标准的蒸馏方法是通过一个大模型作为 Teacher 模型来指导 Student 模型提升效果,而后来又发展出 DML 互学习蒸馏方法,即通过两个结构相同的模型互相学习,相比于前者,DML 脱离了对大的 Teacher 模型的依赖,蒸馏训练的流程更加简单,模型产出效率也要更高一些。
PP-OCRv2 文字检测模型中使用的是三个模型之间的 CML (Collaborative Mutual Learning) 协同互蒸馏方法,既包含两个相同结构的 Student 模型之间互学习,同时还引入了较大模型结构的 Teacher 模型。CML与其他蒸馏算法的对比如下所示。
具体地,文本检测任务中,CML的结构框图如下所示。这里的 response maps 指的就是DBNet最后一层的概率图输出 (Probability map) 。在整个训练过程中,总共包含3个损失函数。
- GT loss
- DML loss
- Distill loss
这里的 Teacher 模型的骨干网络为 ResNet18_vd,2 个 Student 模型的骨干网络为 MobileNetV3。
- GT loss
两个 Student 模型中大部分的参数都是从头初始化的,因此它们在训练的过程中需要受到 groundtruth (GT) 信息 的监督。DBNet 训练任务的 pipeline 如下所示。其输出主要包含 3 种 feature map,具体如下所示。
对这 3 种 feature map 使用不同的 loss function 进行监督,具体如下表所示。
Feature map | Loss function | weight |
---|---|---|
Probability map | Binary cross-entropy loss | 1.0 |
Binary map | Dice loss | |
Threshold map | L1 loss |
最终GT loss可以表示为如下所示。
%20%3D%20l%7Bp%7D(S%7Bout%7D%2C%20gt)%20%2B%20%5Calpha%20l%7Bb%7D(S%7Bout%7D%2C%20gt)%20%2B%20%5Cbeta%20l%7Bt%7D(S%7Bout%7D%2C%20gt)%20%0A#card=math&code=Loss%7Bgt%7D%28T%7Bout%7D%2C%20gt%29%20%3D%20l%7Bp%7D%28S%7Bout%7D%2C%20gt%29%20%2B%20%5Calpha%20l%7Bb%7D%28S%7Bout%7D%2C%20gt%29%20%2B%20%5Cbeta%20l%7Bt%7D%28S%7Bout%7D%2C%20gt%29%20%0A&id=IukbW)
- DML loss
对于 2 个完全相同的 Student 模型来说,因为它们的结构完全相同,因此对于相同的输入,应该具有相同的输出,DBNet 最终输出的是概率图 (response maps),因此基于 KL 散度,计算 2 个 Student 模型的 DML loss,具体计算方式如下。
%20%2B%20KL(S2%7Bpout%7D%20%7C%7C%20S1%7Bpout%7D)%7D%7B2%7D%20%0A#card=math&code=Loss%7Bdml%7D%20%3D%20%5Cfrac%7BKL%28S1%7Bpout%7D%20%7C%7C%20S2%7Bpout%7D%29%20%2B%20KL%28S2%7Bpout%7D%20%7C%7C%20S1_%7Bpout%7D%29%7D%7B2%7D%20%0A&id=lj6oC)
其中 KL(·|·)
是 KL 散度的计算公式,最终这种形式的 DML loss 具有对称性。
- Distill loss
CML 中,引入了 Teacher 模型,来同时监督 2 个 Student 模型。PP-OCRv2 中只对特征 Probability map
进行蒸馏的监督。具体地,对于其中一个 Student 模型,计算方法如下所示, lp(·) 和 lb(·) 分别表示 Binary cross-entropy loss 和 Dice loss。另一个 Student 模型的 loss 计算过程完全相同。
)%20%2B%20l%7Bb%7D(S%7Bout%7D%2C%20f%7Bdila%7D(T%7Bout%7D))%20%0A#card=math&code=Loss%7Bdistill%7D%20%3D%20%5Cgamma%20l%7Bp%7D%28S%7Bout%7D%2C%20f%7Bdila%7D%28T%7Bout%7D%29%29%20%2B%20l%7Bb%7D%28S%7Bout%7D%2C%20f%7Bdila%7D%28T_%7Bout%7D%29%29%20%0A&id=MpBQu)
最终,将上述三个 loss 相加,就得到了用于 CML 训练的损失函数。
检测配置文件为ch_PP-OCRv2_det_cml.yml,蒸馏结构部分的配置和部分解释如下。
Architecture:
name: DistillationModel # 模型名称,这是通用的蒸馏模型表示。
algorithm: Distillation # 算法名称,
Models: # 模型,包含子网络的配置信息
Teacher: # Teacher子网络,包含`pretrained`与`freeze_params`信息以及其他用于构建子网络的参数
freeze_params: true # 是否固定Teacher网络的参数
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy # 预训练模型
return_all_feats: false # 是否返回所有的特征,为True时,会将backbone、neck、head等模块的输出都返回
model_type: det # 模型类别
algorithm: DB # Teacher网络的算法名称
Transform:
Backbone:
name: ResNet
layers: 18
Neck:
name: DBFPN
out_channels: 256
Head:
name: DBHead
k: 50
Student: # Student子网络
freeze_params: false
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
return_all_feats: false
model_type: det
algorithm: DB
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Student2: # Student2子网络
freeze_params: false
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
DistillationModel
类的实现在distillation_model.py文件中,DistillationModel
类的实现与部分讲解如下。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class DistillationModel(nn.Layer):
def __init__(self, config):
"""
the module for OCR distillation.
args:
config (dict): the super parameters for module.
"""
super().__init__()
self.model_list = []
self.model_name_list = []
# 根据Models中的每个字段,抽取出子网络的名称以及对应的配置
for key in config["Models"]:
model_config = config["Models"][key]
freeze_params = False
pretrained = None
if "freeze_params" in model_config:
freeze_params = model_config.pop("freeze_params")
if "pretrained" in model_config:
pretrained = model_config.pop("pretrained")
# 根据每个子网络的配置,基于BaseModel生成子网络
model = BaseModel(model_config)
# 判断是否加载预训练模型
if pretrained is not None:
load_pretrained_params(model, pretrained)
# 判断是否需要固定该子网络的模型参数
if freeze_params:
for param in model.parameters():
param.trainable = False
self.model_list.append(self.add_sublayer(key, model))
self.model_name_list.append(key)
def forward(self, x):
result_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
result_dict[model_name] = self.model_list[idx](x)
return result_dict
使用下面的命令,可以快速完成蒸馏模型的初始化过程。
1
2
3
4
5
6
7
8
# 参考代码
# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/modeling/architectures/__init__.py
from tools.program import load_config
from ppocr.modeling.architectures import build_model
config_path = "./configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml"
config = load_config(config_path)
model = build_model(config['Architecture'])
print(model)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
DistillationModel(
(Teacher): BaseModel(
(backbone): ResNet(
(conv1_1): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(3, 32, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
(conv1_2): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(32, 32, kernel_size=[3, 3], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
(conv1_3): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(32, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
(pool2d_max): MaxPool2D(kernel_size=3, stride=2, padding=1)
(bb_0_0): BasicBlock(
(conv0): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(64, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
(conv1): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(64, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
(short): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(64, 64, kernel_size=[1, 1], data_format=NCHW)
(_batch_norm): BatchNorm()
)
)
(bb_0_1): BasicBlock(
(conv0): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(64, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
(conv1): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(64, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
)
(bb_1_0): BasicBlock(
(conv0): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(64, 128, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
(conv1): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(128, 128, kernel_size=[3, 3], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
(short): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(64, 128, kernel_size=[1, 1], data_format=NCHW)
(_batch_norm): BatchNorm()
)
)
(bb_1_1): BasicBlock(
(conv0): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(128, 128, kernel_size=[3, 3], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
(conv1): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(128, 128, kernel_size=[3, 3], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
)
(bb_2_0): BasicBlock(
(conv0): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(128, 256, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
(conv1): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(256, 256, kernel_size=[3, 3], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
(short): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(128, 256, kernel_size=[1, 1], data_format=NCHW)
(_batch_norm): BatchNorm()
)
)
(bb_2_1): BasicBlock(
(conv0): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(256, 256, kernel_size=[3, 3], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
(conv1): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(256, 256, kernel_size=[3, 3], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
)
(bb_3_0): BasicBlock(
(conv0): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(256, 512, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
(conv1): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(512, 512, kernel_size=[3, 3], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
(short): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(256, 512, kernel_size=[1, 1], data_format=NCHW)
(_batch_norm): BatchNorm()
)
)
(bb_3_1): BasicBlock(
(conv0): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(512, 512, kernel_size=[3, 3], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
(conv1): ConvBNLayer(
(_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)
(_conv): Conv2D(512, 512, kernel_size=[3, 3], padding=1, data_format=NCHW)
(_batch_norm): BatchNorm()
)
)
)
(neck): DBFPN(
(in2_conv): Conv2D(64, 256, kernel_size=[1, 1], data_format=NCHW)
(in3_conv): Conv2D(128, 256, kernel_size=[1, 1], data_format=NCHW)
(in4_conv): Conv2D(256, 256, kernel_size=[1, 1], data_format=NCHW)
(in5_conv): Conv2D(512, 256, kernel_size=[1, 1], data_format=NCHW)
(p5_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
(p4_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
(p3_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
(p2_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
)
(head): DBHead(
(binarize): Head(
(conv1): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
(conv_bn1): BatchNorm()
(conv2): Conv2DTranspose(64, 64, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
(conv_bn2): BatchNorm()
(conv3): Conv2DTranspose(64, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
)
(thresh): Head(
(conv1): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
(conv_bn1): BatchNorm()
(conv2): Conv2DTranspose(64, 64, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
(conv_bn2): BatchNorm()
(conv3): Conv2DTranspose(64, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
)
)
)
(Student): BaseModel(
(backbone): MobileNetV3(
(conv): ConvBNLayer(
(conv): Conv2D(3, 8, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)
(bn): BatchNorm()
)
(stage0): Sequential(
(0): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(8, 8, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(8, 8, kernel_size=[3, 3], padding=1, groups=8, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(8, 8, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(1): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(8, 32, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(32, 32, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=32, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(32, 16, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(2): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(16, 40, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(40, 40, kernel_size=[3, 3], padding=1, groups=40, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(40, 16, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
)
(stage1): Sequential(
(0): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(16, 40, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(40, 40, kernel_size=[5, 5], stride=[2, 2], padding=2, groups=40, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(40, 24, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(1): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(24, 64, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(64, 64, kernel_size=[5, 5], padding=2, groups=64, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(64, 24, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(2): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(24, 64, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(64, 64, kernel_size=[5, 5], padding=2, groups=64, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(64, 24, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
)
(stage2): Sequential(
(0): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(24, 120, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(120, 120, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=120, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(120, 40, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(1): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(40, 104, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(104, 104, kernel_size=[3, 3], padding=1, groups=104, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(104, 40, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(2): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(40, 96, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(96, 96, kernel_size=[3, 3], padding=1, groups=96, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(96, 40, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(3): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(40, 96, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(96, 96, kernel_size=[3, 3], padding=1, groups=96, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(96, 40, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(4): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(40, 240, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(240, 240, kernel_size=[3, 3], padding=1, groups=240, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(240, 56, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(5): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(56, 336, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(336, 336, kernel_size=[3, 3], padding=1, groups=336, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(336, 56, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
)
(stage3): Sequential(
(0): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(56, 336, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(336, 336, kernel_size=[5, 5], stride=[2, 2], padding=2, groups=336, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(336, 80, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(1): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(480, 480, kernel_size=[5, 5], padding=2, groups=480, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(480, 80, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(2): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(480, 480, kernel_size=[5, 5], padding=2, groups=480, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(480, 80, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(3): ConvBNLayer(
(conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
)
(neck): DBFPN(
(in2_conv): Conv2D(16, 96, kernel_size=[1, 1], data_format=NCHW)
(in3_conv): Conv2D(24, 96, kernel_size=[1, 1], data_format=NCHW)
(in4_conv): Conv2D(56, 96, kernel_size=[1, 1], data_format=NCHW)
(in5_conv): Conv2D(480, 96, kernel_size=[1, 1], data_format=NCHW)
(p5_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
(p4_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
(p3_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
(p2_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
)
(head): DBHead(
(binarize): Head(
(conv1): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
(conv_bn1): BatchNorm()
(conv2): Conv2DTranspose(24, 24, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
(conv_bn2): BatchNorm()
(conv3): Conv2DTranspose(24, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
)
(thresh): Head(
(conv1): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
(conv_bn1): BatchNorm()
(conv2): Conv2DTranspose(24, 24, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
(conv_bn2): BatchNorm()
(conv3): Conv2DTranspose(24, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
)
)
)
(Student2): BaseModel(
(backbone): MobileNetV3(
(conv): ConvBNLayer(
(conv): Conv2D(3, 8, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)
(bn): BatchNorm()
)
(stage0): Sequential(
(0): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(8, 8, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(8, 8, kernel_size=[3, 3], padding=1, groups=8, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(8, 8, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(1): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(8, 32, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(32, 32, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=32, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(32, 16, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(2): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(16, 40, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(40, 40, kernel_size=[3, 3], padding=1, groups=40, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(40, 16, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
)
(stage1): Sequential(
(0): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(16, 40, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(40, 40, kernel_size=[5, 5], stride=[2, 2], padding=2, groups=40, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(40, 24, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(1): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(24, 64, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(64, 64, kernel_size=[5, 5], padding=2, groups=64, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(64, 24, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(2): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(24, 64, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(64, 64, kernel_size=[5, 5], padding=2, groups=64, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(64, 24, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
)
(stage2): Sequential(
(0): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(24, 120, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(120, 120, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=120, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(120, 40, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(1): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(40, 104, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(104, 104, kernel_size=[3, 3], padding=1, groups=104, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(104, 40, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(2): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(40, 96, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(96, 96, kernel_size=[3, 3], padding=1, groups=96, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(96, 40, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(3): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(40, 96, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(96, 96, kernel_size=[3, 3], padding=1, groups=96, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(96, 40, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(4): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(40, 240, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(240, 240, kernel_size=[3, 3], padding=1, groups=240, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(240, 56, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(5): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(56, 336, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(336, 336, kernel_size=[3, 3], padding=1, groups=336, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(336, 56, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
)
(stage3): Sequential(
(0): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(56, 336, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(336, 336, kernel_size=[5, 5], stride=[2, 2], padding=2, groups=336, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(336, 80, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(1): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(480, 480, kernel_size=[5, 5], padding=2, groups=480, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(480, 80, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(2): ResidualUnit(
(expand_conv): ConvBNLayer(
(conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
(bottleneck_conv): ConvBNLayer(
(conv): Conv2D(480, 480, kernel_size=[5, 5], padding=2, groups=480, data_format=NCHW)
(bn): BatchNorm()
)
(linear_conv): ConvBNLayer(
(conv): Conv2D(480, 80, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
(3): ConvBNLayer(
(conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)
(bn): BatchNorm()
)
)
)
(neck): DBFPN(
(in2_conv): Conv2D(16, 96, kernel_size=[1, 1], data_format=NCHW)
(in3_conv): Conv2D(24, 96, kernel_size=[1, 1], data_format=NCHW)
(in4_conv): Conv2D(56, 96, kernel_size=[1, 1], data_format=NCHW)
(in5_conv): Conv2D(480, 96, kernel_size=[1, 1], data_format=NCHW)
(p5_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
(p4_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
(p3_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
(p2_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
)
(head): DBHead(
(binarize): Head(
(conv1): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
(conv_bn1): BatchNorm()
(conv2): Conv2DTranspose(24, 24, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
(conv_bn2): BatchNorm()
(conv3): Conv2DTranspose(24, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
)
(thresh): Head(
(conv1): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)
(conv_bn1): BatchNorm()
(conv2): Conv2DTranspose(24, 24, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
(conv_bn2): BatchNorm()
(conv3): Conv2DTranspose(24, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
)
)
)
)
可以通过下面的方式快速体验CML蒸馏的训练过程。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 参考代码
# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/tools/train.py
os.chdir("/home/aistudio/PaddleOCR/")
!mkdir train_data
!wget https://paddleocr.bj.bcebos.com/dataset/det_data_lesson_demo.tar -O det_data_lesson_demo.tar && tar -xf det_data_lesson_demo.tar && rm det_data_lesson_demo.tar
!mkdir pretrain_models && wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar && tar -xf ch_ppocr_server_v2.0_det_train.tar
# !mv ch_ppocr_server_v2.0_det_train pretrain_models/ && rm ch_ppocr_server_v2.0_det_train.tar
# 训练脚本
# 注意:这里只训练了一个epoch,仅用于快速演示,指标会很差
!python tools/train.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml \
-o Global.pretrained_model="" \
Train.dataset.data_dir="./det_data_lesson_demo/" \
Train.dataset.label_file_list=["./det_data_lesson_demo/train.txt"] \
Train.loader.num_workers=0 \
Eval.dataset.data_dir="./det_data_lesson_demo/" \
Eval.dataset.label_file_list=["./det_data_lesson_demo/eval.txt"] \
Eval.loader.num_workers=0 \
Optimizer.lr.learning_rate=0.00025 \
Global.epoch_num=1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
mkdir: cannot create directory ‘train_data’: File exists
--2021-12-24 19:00:42-- https://paddleocr.bj.bcebos.com/dataset/det_data_lesson_demo.tar
Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.229, 182.61.200.195, 2409:8c04:1001:1002:0:ff:b001:368a
Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.229|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 449465021 (429M) [application/x-tar]
Saving to: ‘det_data_lesson_demo.tar’
det_data_lesson_dem 100%[===================>] 428.64M 43.4MB/s in 11s
2021-12-24 19:00:52 (40.8 MB/s) - ‘det_data_lesson_demo.tar’ saved [449465021/449465021]
mkdir: cannot create directory ‘pretrain_models’: File exists
[2021/12/24 19:00:59] root INFO: Architecture :
[2021/12/24 19:00:59] root INFO: Models :
[2021/12/24 19:00:59] root INFO: Student :
[2021/12/24 19:00:59] root INFO: Backbone :
[2021/12/24 19:00:59] root INFO: disable_se : True
[2021/12/24 19:00:59] root INFO: model_name : large
[2021/12/24 19:00:59] root INFO: name : MobileNetV3
[2021/12/24 19:00:59] root INFO: scale : 0.5
[2021/12/24 19:00:59] root INFO: Head :
[2021/12/24 19:00:59] root INFO: k : 50
[2021/12/24 19:00:59] root INFO: name : DBHead
[2021/12/24 19:00:59] root INFO: Neck :
[2021/12/24 19:00:59] root INFO: name : DBFPN
[2021/12/24 19:00:59] root INFO: out_channels : 96
[2021/12/24 19:00:59] root INFO: algorithm : DB
[2021/12/24 19:00:59] root INFO: freeze_params : False
[2021/12/24 19:00:59] root INFO: model_type : det
[2021/12/24 19:00:59] root INFO: return_all_feats : False
[2021/12/24 19:00:59] root INFO: Student2 :
[2021/12/24 19:00:59] root INFO: Backbone :
[2021/12/24 19:00:59] root INFO: disable_se : True
[2021/12/24 19:00:59] root INFO: model_name : large
[2021/12/24 19:00:59] root INFO: name : MobileNetV3
[2021/12/24 19:00:59] root INFO: scale : 0.5
[2021/12/24 19:00:59] root INFO: Head :
[2021/12/24 19:00:59] root INFO: k : 50
[2021/12/24 19:00:59] root INFO: name : DBHead
[2021/12/24 19:00:59] root INFO: Neck :
[2021/12/24 19:00:59] root INFO: name : DBFPN
[2021/12/24 19:00:59] root INFO: out_channels : 96
[2021/12/24 19:00:59] root INFO: Transform : None
[2021/12/24 19:00:59] root INFO: algorithm : DB
[2021/12/24 19:00:59] root INFO: freeze_params : False
[2021/12/24 19:00:59] root INFO: model_type : det
[2021/12/24 19:00:59] root INFO: return_all_feats : False
[2021/12/24 19:00:59] root INFO: Teacher :
[2021/12/24 19:00:59] root INFO: Backbone :
[2021/12/24 19:00:59] root INFO: layers : 18
[2021/12/24 19:00:59] root INFO: name : ResNet
[2021/12/24 19:00:59] root INFO: Head :
[2021/12/24 19:00:59] root INFO: k : 50
[2021/12/24 19:00:59] root INFO: name : DBHead
[2021/12/24 19:00:59] root INFO: Neck :
[2021/12/24 19:00:59] root INFO: name : DBFPN
[2021/12/24 19:00:59] root INFO: out_channels : 256
[2021/12/24 19:00:59] root INFO: Transform : None
[2021/12/24 19:00:59] root INFO: algorithm : DB
[2021/12/24 19:00:59] root INFO: freeze_params : True
[2021/12/24 19:00:59] root INFO: model_type : det
[2021/12/24 19:00:59] root INFO: return_all_feats : False
[2021/12/24 19:00:59] root INFO: algorithm : Distillation
[2021/12/24 19:00:59] root INFO: model_type : det
[2021/12/24 19:00:59] root INFO: name : DistillationModel
[2021/12/24 19:00:59] root INFO: Eval :
[2021/12/24 19:00:59] root INFO: dataset :
[2021/12/24 19:00:59] root INFO: data_dir : ./det_data_lesson_demo/
[2021/12/24 19:00:59] root INFO: label_file_list : ['./det_data_lesson_demo/eval.txt']
[2021/12/24 19:00:59] root INFO: name : SimpleDataSet
[2021/12/24 19:00:59] root INFO: transforms :
[2021/12/24 19:00:59] root INFO: DecodeImage :
[2021/12/24 19:00:59] root INFO: channel_first : False
[2021/12/24 19:00:59] root INFO: img_mode : BGR
[2021/12/24 19:00:59] root INFO: DetLabelEncode : None
[2021/12/24 19:00:59] root INFO: DetResizeForTest : None
[2021/12/24 19:00:59] root INFO: NormalizeImage :
[2021/12/24 19:00:59] root INFO: mean : [0.485, 0.456, 0.406]
[2021/12/24 19:00:59] root INFO: order : hwc
[2021/12/24 19:00:59] root INFO: scale : 1./255.
[2021/12/24 19:00:59] root INFO: std : [0.229, 0.224, 0.225]
[2021/12/24 19:00:59] root INFO: ToCHWImage : None
[2021/12/24 19:00:59] root INFO: KeepKeys :
[2021/12/24 19:00:59] root INFO: keep_keys : ['image', 'shape', 'polys', 'ignore_tags']
[2021/12/24 19:00:59] root INFO: loader :
[2021/12/24 19:00:59] root INFO: batch_size_per_card : 1
[2021/12/24 19:00:59] root INFO: drop_last : False
[2021/12/24 19:00:59] root INFO: num_workers : 0
[2021/12/24 19:00:59] root INFO: shuffle : False
[2021/12/24 19:00:59] root INFO: Global :
[2021/12/24 19:00:59] root INFO: cal_metric_during_train : False
[2021/12/24 19:00:59] root INFO: checkpoints : None
[2021/12/24 19:00:59] root INFO: debug : False
[2021/12/24 19:00:59] root INFO: distributed : False
[2021/12/24 19:00:59] root INFO: epoch_num : 1
[2021/12/24 19:00:59] root INFO: eval_batch_step : [3000, 2000]
[2021/12/24 19:00:59] root INFO: infer_img : doc/imgs_en/img_10.jpg
[2021/12/24 19:00:59] root INFO: log_smooth_window : 20
[2021/12/24 19:00:59] root INFO: pretrained_model : None
[2021/12/24 19:00:59] root INFO: print_batch_step : 2
[2021/12/24 19:00:59] root INFO: save_epoch_step : 1200
[2021/12/24 19:00:59] root INFO: save_inference_dir : None
[2021/12/24 19:00:59] root INFO: save_model_dir : ./output/ch_db_mv3/
[2021/12/24 19:00:59] root INFO: save_res_path : ./output/det_db/predicts_db.txt
[2021/12/24 19:00:59] root INFO: use_gpu : True
[2021/12/24 19:00:59] root INFO: use_visualdl : False
[2021/12/24 19:00:59] root INFO: Loss :
[2021/12/24 19:00:59] root INFO: loss_config_list :
[2021/12/24 19:00:59] root INFO: DistillationDilaDBLoss :
[2021/12/24 19:00:59] root INFO: alpha : 5
[2021/12/24 19:00:59] root INFO: balance_loss : True
[2021/12/24 19:00:59] root INFO: beta : 10
[2021/12/24 19:00:59] root INFO: key : maps
[2021/12/24 19:00:59] root INFO: main_loss_type : DiceLoss
[2021/12/24 19:00:59] root INFO: model_name_pairs : [['Student', 'Teacher'], ['Student2', 'Teacher']]
[2021/12/24 19:00:59] root INFO: ohem_ratio : 3
[2021/12/24 19:00:59] root INFO: weight : 1.0
[2021/12/24 19:00:59] root INFO: DistillationDMLLoss :
[2021/12/24 19:00:59] root INFO: key : maps
[2021/12/24 19:00:59] root INFO: maps_name : thrink_maps
[2021/12/24 19:00:59] root INFO: model_name_pairs : ['Student', 'Student2']
[2021/12/24 19:00:59] root INFO: weight : 1.0
[2021/12/24 19:00:59] root INFO: DistillationDBLoss :
[2021/12/24 19:00:59] root INFO: alpha : 5
[2021/12/24 19:00:59] root INFO: balance_loss : True
[2021/12/24 19:00:59] root INFO: beta : 10
[2021/12/24 19:00:59] root INFO: main_loss_type : DiceLoss
[2021/12/24 19:00:59] root INFO: model_name_list : ['Student', 'Student2']
[2021/12/24 19:00:59] root INFO: ohem_ratio : 3
[2021/12/24 19:00:59] root INFO: weight : 1.0
[2021/12/24 19:00:59] root INFO: name : CombinedLoss
[2021/12/24 19:00:59] root INFO: Metric :
[2021/12/24 19:00:59] root INFO: base_metric_name : DetMetric
[2021/12/24 19:00:59] root INFO: key : Student
[2021/12/24 19:00:59] root INFO: main_indicator : hmean
[2021/12/24 19:00:59] root INFO: name : DistillationMetric
[2021/12/24 19:00:59] root INFO: Optimizer :
[2021/12/24 19:00:59] root INFO: beta1 : 0.9
[2021/12/24 19:00:59] root INFO: beta2 : 0.999
[2021/12/24 19:00:59] root INFO: lr :
[2021/12/24 19:00:59] root INFO: learning_rate : 0.00025
[2021/12/24 19:00:59] root INFO: name : Cosine
[2021/12/24 19:00:59] root INFO: warmup_epoch : 2
[2021/12/24 19:00:59] root INFO: name : Adam
[2021/12/24 19:00:59] root INFO: regularizer :
[2021/12/24 19:00:59] root INFO: factor : 0
[2021/12/24 19:00:59] root INFO: name : L2
[2021/12/24 19:00:59] root INFO: PostProcess :
[2021/12/24 19:00:59] root INFO: box_thresh : 0.6
[2021/12/24 19:00:59] root INFO: max_candidates : 1000
[2021/12/24 19:00:59] root INFO: model_name : ['Student', 'Student2', 'Teacher']
[2021/12/24 19:00:59] root INFO: name : DistillationDBPostProcess
[2021/12/24 19:00:59] root INFO: thresh : 0.3
[2021/12/24 19:00:59] root INFO: unclip_ratio : 1.5
[2021/12/24 19:00:59] root INFO: Train :
[2021/12/24 19:00:59] root INFO: dataset :
[2021/12/24 19:00:59] root INFO: data_dir : ./det_data_lesson_demo/
[2021/12/24 19:00:59] root INFO: label_file_list : ['./det_data_lesson_demo/train.txt']
[2021/12/24 19:00:59] root INFO: name : SimpleDataSet
[2021/12/24 19:00:59] root INFO: ratio_list : [1.0]
[2021/12/24 19:00:59] root INFO: transforms :
[2021/12/24 19:00:59] root INFO: DecodeImage :
[2021/12/24 19:00:59] root INFO: channel_first : False
[2021/12/24 19:00:59] root INFO: img_mode : BGR
[2021/12/24 19:00:59] root INFO: DetLabelEncode : None
[2021/12/24 19:00:59] root INFO: CopyPaste : None
[2021/12/24 19:00:59] root INFO: IaaAugment :
[2021/12/24 19:00:59] root INFO: augmenter_args :
[2021/12/24 19:00:59] root INFO: args :
[2021/12/24 19:00:59] root INFO: p : 0.5
[2021/12/24 19:00:59] root INFO: type : Fliplr
[2021/12/24 19:00:59] root INFO: args :
[2021/12/24 19:00:59] root INFO: rotate : [-10, 10]
[2021/12/24 19:00:59] root INFO: type : Affine
[2021/12/24 19:00:59] root INFO: args :
[2021/12/24 19:00:59] root INFO: size : [0.5, 3]
[2021/12/24 19:00:59] root INFO: type : Resize
[2021/12/24 19:00:59] root INFO: EastRandomCropData :
[2021/12/24 19:00:59] root INFO: keep_ratio : True
[2021/12/24 19:00:59] root INFO: max_tries : 50
[2021/12/24 19:00:59] root INFO: size : [960, 960]
[2021/12/24 19:00:59] root INFO: MakeBorderMap :
[2021/12/24 19:00:59] root INFO: shrink_ratio : 0.4
[2021/12/24 19:00:59] root INFO: thresh_max : 0.7
[2021/12/24 19:00:59] root INFO: thresh_min : 0.3
[2021/12/24 19:00:59] root INFO: MakeShrinkMap :
[2021/12/24 19:00:59] root INFO: min_text_size : 8
[2021/12/24 19:00:59] root INFO: shrink_ratio : 0.4
[2021/12/24 19:00:59] root INFO: NormalizeImage :
[2021/12/24 19:00:59] root INFO: mean : [0.485, 0.456, 0.406]
[2021/12/24 19:00:59] root INFO: order : hwc
[2021/12/24 19:00:59] root INFO: scale : 1./255.
[2021/12/24 19:00:59] root INFO: std : [0.229, 0.224, 0.225]
[2021/12/24 19:00:59] root INFO: ToCHWImage : None
[2021/12/24 19:00:59] root INFO: KeepKeys :
[2021/12/24 19:00:59] root INFO: keep_keys : ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask']
[2021/12/24 19:00:59] root INFO: loader :
[2021/12/24 19:00:59] root INFO: batch_size_per_card : 8
[2021/12/24 19:00:59] root INFO: drop_last : False
[2021/12/24 19:00:59] root INFO: num_workers : 0
[2021/12/24 19:00:59] root INFO: shuffle : True
[2021/12/24 19:00:59] root INFO: profiler_options : None
[2021/12/24 19:00:59] root INFO: train with paddle 2.2.1 and device CUDAPlace(0)
[2021/12/24 19:00:59] root INFO: Initialize indexs of datasets:['./det_data_lesson_demo/train.txt']
[2021/12/24 19:00:59] root INFO: Initialize indexs of datasets:['./det_data_lesson_demo/eval.txt']
W1224 19:00:59.372548 617 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1224 19:00:59.377442 617 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[2021/12/24 19:01:03] root INFO: train from scratch
[2021/12/24 19:01:03] root INFO: train dataloader has 94 iters
[2021/12/24 19:01:03] root INFO: valid dataloader has 250 iters
[2021/12/24 19:01:03] root INFO: During the training process, after the 3000th iteration, an evaluation is run every 2000 iterations
[2021/12/24 19:01:03] root INFO: Initialize indexs of datasets:['./det_data_lesson_demo/train.txt']
[2021/12/24 19:01:15] root INFO: epoch: [1/1], iter: 2, lr: 0.000001, dila_dbloss_Student_Teacher: 1.989539, dila_dbloss_Student2_Teacher: 1.587950, loss: 21.715214, dml_thrink_maps_0: 0.184671, db_Student_loss_shrink_maps: 4.637174, db_Student_loss_threshold_maps: 3.504537, db_Student_loss_binary_maps: 0.932368, db_Student2_loss_shrink_maps: 4.679275, db_Student2_loss_threshold_maps: 2.964181, db_Student2_loss_binary_maps: 0.945670, reader_cost: 4.72146 s, batch_cost: 5.70054 s, samples: 24, ips: 2.10506
[2021/12/24 19:01:19] root INFO: epoch: [1/1], iter: 4, lr: 0.000003, dila_dbloss_Student_Teacher: 1.989539, dila_dbloss_Student2_Teacher: 1.593589, loss: 21.568047, dml_thrink_maps_0: 0.184671, db_Student_loss_shrink_maps: 4.637174, db_Student_loss_threshold_maps: 3.396354, db_Student_loss_binary_maps: 0.932368, db_Student2_loss_shrink_maps: 4.679275, db_Student2_loss_threshold_maps: 2.964181, db_Student2_loss_binary_maps: 0.945670, reader_cost: 1.43537 s, batch_cost: 2.06478 s, samples: 16, ips: 3.87451
[2021/12/24 19:01:24] root INFO: epoch: [1/1], iter: 6, lr: 0.000004, dila_dbloss_Student_Teacher: 1.989539, dila_dbloss_Student2_Teacher: 1.593589, loss: 21.568047, dml_thrink_maps_0: 0.184671, db_Student_loss_shrink_maps: 4.642411, db_Student_loss_threshold_maps: 3.396354, db_Student_loss_binary_maps: 0.932368, db_Student2_loss_shrink_maps: 4.679275, db_Student2_loss_threshold_maps: 2.964181, db_Student2_loss_binary_maps: 0.945670, reader_cost: 2.00128 s, batch_cost: 2.63305 s, samples: 16, ips: 3.03830
[2021/12/24 19:01:29] root INFO: epoch: [1/1], iter: 8, lr: 0.000005, dila_dbloss_Student_Teacher: 1.988440, dila_dbloss_Student2_Teacher: 1.587950, loss: 21.715214, dml_thrink_maps_0: 0.182628, db_Student_loss_shrink_maps: 4.716407, db_Student_loss_threshold_maps: 3.504537, db_Student_loss_binary_maps: 0.947787, db_Student2_loss_shrink_maps: 4.750618, db_Student2_loss_threshold_maps: 3.014087, db_Student2_loss_binary_maps: 0.958087, reader_cost: 2.01780 s, batch_cost: 2.59098 s, samples: 16, ips: 3.08763
^C
3.1.2 数据增广
数据增广是提升模型泛化能力重要的手段之一,CopyPaste 是一种新颖的数据增强技巧,已经在目标检测和实例分割任务中验证了有效性。利用 CopyPaste,可以合成文本实例来平衡训练图像中的正负样本之间的比例。相比而言,传统图像旋转、随机翻转和随机裁剪是无法做到的。
CopyPaste 主要步骤包括:
- 随机选择两幅训练图像;
- 随机尺度抖动缩放;
- 随机水平翻转;
- 随机选择一幅图像中的目标子集;
- 粘贴在另一幅图像中随机的位置。
这样就比较好地提升了样本丰富度,同时也增加了模型对环境的鲁棒性。如下图所示,通过在左下角的图中裁剪出来的文本,随机旋转缩放之后粘贴到左上角的图像中,进一步丰富了该文本在不同背景下的多样性。
如果希望在模型训练中使用CopyPaste
,只需在Train.transforms
配置字段中添加CopyPaste
即可,如下所示。
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- CopyPaste: # 添加CopyPaste
- IaaAugment:
augmenter_args:
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
- EastRandomCropData:
size: [960, 960]
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
CopyPaste
的具体实现可以参考copy_paste.py。
下面基于icdar2015检测数据集,演示CopyPaste的实际运行过程。
1
2
3
4
5
import os
import sys
os.chdir("/home/aistudio/PaddleOCR/")
!unzip -oq /home/aistudio/data/data46088/icdar2015.zip
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# 参考代码:
# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/data/simple_dataset.py
import logging
import random
import numpy as np
from ppocr.data.imaug import create_operators, transform
logger = logging.basicConfig()
# CopyPaste示例的类
class CopyPasteDemo(object):
def __init__(self, ):
self.data_dir = "./icdar2015/text_localization/"
self.label_file_list = "./icdar2015/text_localization/train_icdar2015_label.txt"
self.data_lines = self.get_image_info_list(self.label_file_list)
self.data_idx_order_list = list(range(len(self.data_lines)))
transforms = [
{"DecodeImage": {"img_mode": "BGR", "channel_first": False}},
{"DetLabelEncode": {}},
{"CopyPaste": {"objects_paste_ratio": 1.0}},
]
self.ops = create_operators(transforms)
# 选择一张图像,将其中的内容拷贝到当前图像中
def get_ext_data(self, idx):
ext_data_num = 1
ext_data = []
load_data_ops = self.ops[:2]
next_idx = idx
while len(ext_data) < ext_data_num:
next_idx = (next_idx + 1) % len(self)
file_idx = self.data_idx_order_list[next_idx]
data_line = self.data_lines[file_idx]
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split("\t")
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
if not os.path.exists(img_path):
continue
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
data = transform(data, load_data_ops)
if data is None:
continue
ext_data.append(data)
return ext_data
# 获取图像信息
def get_image_info_list(self, file_list):
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
data_lines.extend(lines)
return data_lines
# 获取DataSet中的一条数据
def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines[file_idx]
try:
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split("\t")
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
data['ext_data'] = self.get_ext_data(idx)
outs = transform(data, self.ops)
except Exception as e:
print(
"When parsing line {}, error happened with msg: {}".format(
data_line, e))
outs = None
if outs is None:
return
return outs
def __len__(self):
return len(self.data_idx_order_list)
copy_paste_demo = CopyPasteDemo()
idx = 1
data1 = copy_paste_demo[idx]
print(data1.keys())
print(data1["img_path"])
print(data1["ext_data"][0]["img_path"])
1
2
3
dict_keys(['img_path', 'label', 'image', 'ext_data', 'polys', 'texts', 'ignore_tags'])
./icdar2015/text_localization/icdar_c4_train_imgs/img_603.jpg
./icdar2015/text_localization/icdar_c4_train_imgs/img_233.jpg
- 下面2张图是在CopyPaste之前的图像。
1
2
3
4
5
6
7
8
9
10
11
import cv2
import matplotlib.pyplot as plt
%matplotlib inline
img1 = cv2.imread(data1["img_path"])
img2 = cv2.imread(data1["ext_data"][0]["img_path"])
plt.figure(figsize=(10,6))
plt.imshow(img1[:,:,::-1])
plt.show()
plt.figure(figsize=(10,6))
plt.imshow(img2[:,:,::-1])
plt.show()
- 将更新后的标注检测框画出来,如下所示,其中红色框是原始标注信息,蓝色框是经过CopyPaste补充的标注框。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import json
infos = copy_paste_demo.data_lines[idx]
infos = json.loads(infos.decode('utf-8').split("\t")[1])
img3 = data1["image"].copy()
plt.figure(figsize=(15,10))
plt.imshow(img3[:,:,::-1])
# 原始标注信息
for info in infos:
xs, ys = zip(*info["points"])
xs = list(xs)
ys = list(ys)
xs.append(xs[0])
ys.append(ys[0])
plt.plot(xs, ys, "r")
# 新增的标注信息
for poly_idx in range(len(infos), len(data1["polys"])):
poly = data1["polys"][poly_idx]
xs, ys = zip(*poly)
xs = list(xs)
ys = list(ys)
xs.append(xs[0])
ys.append(ys[0])
plt.plot(xs, ys, "b")
plt.show()
3.1.3 文字检测优化小结
PP-OCRv2中,对文字检测模型采用使用知识蒸馏方案以及数据增广策略,增加模型的泛化性能。最终文字检测模型在大小不变的情况下,Hmean从 0.759 提升至 0.795,具体消融实验如下所示。
PP-OCRv2检测模型消融实验
PP-OCRv2中检测效果如下所示。
3.2 文本识别模型优化详解
PP-OCRv2文字识别模型优化过程中,采用骨干网络优化、UDML知识蒸馏策略、CTC loss改进等技巧,最终将识别精度从 66.7% 提升至 74.8%,具体消融实验如下所示。
PP-OCRv2识别模型消融实验
3.2.1 PP-LCNet轻量级骨干网络
百度提出了一种基于 MKLDNN 加速策略的轻量级 CPU 网络,即 PP-LCNet,大幅提高了轻量级模型在图像分类任务上的性能,对于计算机视觉的下游任务,如文本识别、目标检测、语义分割等,有很好的表现。这里需要注意的是,PP-LCNet是针对CPU+MKLDNN这个场景进行定制优化,在分类任务上的速度和精度都远远优于其他模型,因此大家如果有这个使用场景的模型需求的话,也推荐大家去使用。
PP-LCNet 论文地址:PP-LCNet: A Lightweight CPU Convolutional Neural Network
PP-LCNet基于MobileNetV1改进得到,其结构图如下所示。
相比于MobileNetV1,PP-LCNet中融合了MobileNetV3结构中激活函数、头部结构、SE模块等策略优化技巧,同时分析了最后阶段卷积层的卷积核大小,最终该模型在保证速度优势的基础上,精度大幅超越MobileNet、GhostNet等轻量级模型。
具体地,PP-LCNet中共涉及到下面4个优化点。
- 除了 SE 模块,网络中所有的 relu 激活函数替换为 h-swish,精度提升1%-2%
- PP-LCNet 第五阶段,DW 的 kernel size 变为5x5,精度提升0.5%-1%
- PP-LCNet 第五阶段的最后两个 DepthSepConv block 添加 SE 模块, 精度提升0.5%-1%
- GAP 后添加 1280 维的 FC 层,增加特征表达能力,精度提升2%-3%
在ImageNet1k数据集上,PP-LCNet相比于其他目前比较常用的轻量级分类模型,Top1-Acc 与预测耗时如下图所示。可以看出,预测耗时和精度都是要更优的。
通过下面这种方式,便可以快速完成PP-LCNet识别模型的定义。
1
2
3
4
5
6
7
8
9
10
# 参考代码
# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/modeling/backbones/rec_mv1_enhance.py
from ppocr.modeling.backbones.rec_mv1_enhance import MobileNetV1Enhance
x = paddle.rand([1, 3, 23, 320])
model = MobileNetV1Enhance(scale=0.5)
y = model(x)
print(y.shape)
[1, 512, 1, 80]
3.2.2 U-DML 知识蒸馏策略
对于标准的 DML 策略,蒸馏的损失函数仅包括最后输出层监督,然而对于 2 个结构完全相同的模型来说,对于完全相同的输入,它们的中间特征输出期望也完全相同,因此在最后输出层监督的监督上,可以进一步添加中间输出的特征图的监督信号,作为损失函数,即 PP-OCRv2 中的 U-DML (Unified-Deep Mutual Learning) 知识蒸馏方法。
U-DML 知识蒸馏的算法流程图如下所示。 Teacher 模型与 Student 模型的网络结构完全相同,初始化参数不同,此外,在新增在标准的 DML 知识蒸馏的基础上,新增引入了对于 Feature Map 的监督机制,新增 Feature Loss。
在训练的过程中,总共包含 3 种 loss: GT loss,DML loss,Feature loss。
- GT loss
文本识别任务使用的模型结构是 CRNN,因此使用 CTC loss 作为 GT loss, GT loss 计算方法如下所示。
%20%2B%20CTC(T%7Bhout%7D%2C%20gt)%20%0A#card=math&code=Loss%7Bctc%7D%20%3D%20CTC%28S%7Bhout%7D%2C%20gt%29%20%2B%20CTC%28T%7Bhout%7D%2C%20gt%29%20%0A&id=fkiUo)
- DML loss
DML loss 计算方法如下,这里 Teacher 模型与 Student 模型互相计算 KL 散度,最终 DML loss具有对称性。
%20%2B%20KL(T%7Bpout%7D%20%7C%7C%20S%7Bpout%7D)%7D%7B2%7D%20%0A#card=math&code=Loss%7Bdml%7D%20%3D%20%5Cfrac%7BKL%28S%7Bpout%7D%20%7C%7C%20T%7Bpout%7D%29%20%2B%20KL%28T%7Bpout%7D%20%7C%7C%20S_%7Bpout%7D%29%7D%7B2%7D%20%0A&id=WYRUo)
- Feature loss
Feature loss 使用的是 L2 loss,具体计算方法如下所示。
%20%0A#card=math&code=Loss%7Bfeat%7D%20%3D%20L2%28S%7Bbout%7D%2C%20T_%7Bbout%7D%29%20%0A&id=ahDQ9)
最终,训练过程中的 loss function 计算方法如下所示。
此外,在训练过程中通过增加迭代次数,在 Head 部分添加 FC 层等 trick,平衡模型的特征编码与解码的能力,进一步提升了模型效果。
配置文件在ch_PP-OCRv2_rec_distillation.yml。
Architecture:
model_type: &model_type "rec" # 模型类别,rec、det等,每个子网络的的模型类别都与
name: DistillationModel # 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构
algorithm: Distillation # 算法名称
Models: # 模型,包含子网络的配置信息
Teacher: # 子网络名称,至少需要包含`pretrained`与`freeze_params`信息,其他的参数为子网络的构造参数
pretrained: # 该子网络是否需要加载预训练模型
freeze_params: false # 是否需要固定参数
return_all_feats: true # 子网络的参数,表示是否需要返回所有的features,如果为False,则只返回最后的输出
model_type: *model_type # 模型类别
algorithm: CRNN # 子网络的算法名称,该子网络剩余参与均为构造参数,与普通的模型训练配置一致
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96 # Head解码过程中穿插一层
fc_decay: 0.00002
Student: # 另外一个子网络,这里给的是DML的蒸馏示例,两个子网络结构相同,均需要学习参数
pretrained: # 下面的组网参数同上
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
当然,这里如果希望添加更多的子网络进行训练,也可以按照Student
与Teacher
的添加方式,在配置文件中添加相应的字段。比如说如果希望有3个模型互相监督,共同训练,那么Architecture
可以写为如下格式。
Architecture:
model_type: &model_type "rec"
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Student2: # 知识蒸馏任务中引入的新的子网络,其他部分与上述配置相同
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
最终该模型训练时,包含3个子网络:Teacher
, Student
, Student2
。
蒸馏模型DistillationModel
类的具体实现代码可以参考distillation_model.py。
最终模型forward
输出为一个字典,key为所有的子网络名称,例如这里为Student
与Teacher
,value为对应子网络的输出,可以为Tensor
(只返回该网络的最后一层)和dict
(也返回了中间的特征信息)。
在识别任务中,为了添加更多损失函数,保证蒸馏方法的可扩展性,将每个子网络的输出保存为dict
,其中包含子模块输出。以该识别模型为例,每个子网络的输出结果均为dict
,key包含backbone_out
,neck_out
, head_out
,value
为对应模块的tensor,最终对于上述配置文件,DistillationModel
的输出格式如下。
1
2
3
4
5
6
7
8
9
10
11
12
{
"Teacher": {
"backbone_out": tensor,
"neck_out": tensor,
"head_out": tensor,
},
"Student": {
"backbone_out": tensor,
"neck_out": tensor,
"head_out": tensor,
}
}
知识蒸馏任务中,损失函数配置如下所示。
Loss:
name: CombinedLoss # 损失函数名称,基于改名称,构建用于损失函数的类
loss_config_list: # 损失函数配置文件列表,为CombinedLoss的必备函数
- DistillationCTCLoss: # 基于蒸馏的CTC损失函数,继承自标准的CTC loss
weight: 1.0 # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
model_name_list: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss
key: head_out # 取子网络输出dict中,该key对应的tensor
- DistillationDMLLoss: # 蒸馏的DML损失函数,继承自标准的DMLLoss
weight: 1.0 # 权重
act: "softmax" # 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None
model_name_pairs: # 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充
- ["Student", "Teacher"]
key: head_out # 取子网络输出dict中,该key对应的tensor
- DistillationDistanceLoss: # 蒸馏的距离损失函数
weight: 1.0 # 权重
mode: "l2" # 距离计算方法,目前支持l1, l2, smooth_l1
model_name_pairs: # 用于计算distance loss的子网络名称对
- ["Student", "Teacher"]
key: backbone_out # 取子网络输出dict中,该key对应的tensor
上述损失函数中,所有的蒸馏损失函数均继承自标准的损失函数类,主要功能为: 对蒸馏模型的输出进行解析,找到用于计算损失的中间节点(tensor),再使用标准的损失函数类去计算。
以上述配置为例,最终蒸馏训练的损失函数包含下面3个部分。
Student
和Teacher
的最终输出(head_out
)与gt的CTC loss,权重为1。在这里因为2个子网络都需要更新参数,因此2者都需要计算与gt的loss。Student
和Teacher
的最终输出(head_out
)之间的DML loss,权重为1。Student
和Teacher
的骨干网络输出(backbone_out
)之间的l2 loss,权重为1。
CombinedLoss
类实现如下。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class CombinedLoss(nn.Layer):
"""
CombinedLoss:
a combionation of loss function
"""
def __init__(self, loss_config_list=None):
super().__init__()
self.loss_func = []
self.loss_weight = []
assert isinstance(loss_config_list, list), (
'operator config should be a list')
for config in loss_config_list:
assert isinstance(config,
dict) and len(config) == 1, "yaml format error"
name = list(config)[0]
param = config[name]
assert "weight" in param, "weight must be in param, but param just contains {}".format(
param.keys())
self.loss_weight.append(param.pop("weight"))
self.loss_func.append(eval(name)(**param))
def forward(self, input, batch, **kargs):
loss_dict = {}
loss_all = 0.
for idx, loss_func in enumerate(self.loss_func):
loss = loss_func(input, batch, **kargs)
if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss}
weight = self.loss_weight[idx]
loss = {key: loss[key] * weight for key in loss}
if "loss" in loss:
loss_all += loss["loss"]
else:
loss_all += paddle.add_n(list(loss.values()))
loss_dict.update(loss)
loss_dict["loss"] = loss_all
return loss_dict
关于CombinedLoss
更加具体的实现可以参考: combined_loss.py。关于DistillationCTCLoss
等蒸馏损失函数更加具体的实现可以参考distillation_loss.py。
对于上面3个模型的蒸馏,Loss字段也需要相应修改,同时考虑3个子网络之间的损失,如下所示。
Loss:
name: CombinedLoss # 损失函数名称,基于改名称,构建用于损失函数的类
loss_config_list: # 损失函数配置文件列表,为CombinedLoss的必备函数
- DistillationCTCLoss: # 基于蒸馏的CTC损失函数,继承自标准的CTC loss
weight: 1.0 # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
model_name_list: ["Student", "Student2", "Teacher"] # 对于蒸馏模型的预测结果,提取这三个子网络的输出,与gt计算CTC loss
key: head_out # 取子网络输出dict中,该key对应的tensor
- DistillationDMLLoss: # 蒸馏的DML损失函数,继承自标准的DMLLoss
weight: 1.0 # 权重
act: "softmax" # 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None
model_name_pairs: # 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充
- ["Student", "Teacher"]
- ["Student2", "Teacher"]
- ["Student", "Student2"]
key: head_out # 取子网络输出dict中,该key对应的tensor
- DistillationDistanceLoss: # 蒸馏的距离损失函数
weight: 1.0 # 权重
mode: "l2" # 距离计算方法,目前支持l1, l2, smooth_l1
model_name_pairs: # 用于计算distance loss的子网络名称对
- ["Student", "Teacher"]
- ["Student2", "Teacher"]
- ["Student", "Student2"]
key: backbone_out # 取子网络输出dict中,该key对应的tensor
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 下载数据
!wget -nc https://paddleocr.bj.bcebos.com/dataset/rec_data_lesson_demo.tar && tar -xf rec_data_lesson_demo.tar && rm rec_data_lesson_demo.tar
# # 下载预训练模型
!wget -nc https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar && tar -xf ch_PP-OCRv2_rec_train.tar && rm ch_PP-OCRv2_rec_train.tar
!python tools/train.py -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml \
-o Train.dataset.data_dir="./rec_data_lesson_demo/" \
Train.dataset.label_file_list=["./rec_data_lesson_demo/train.txt"] \
Train.loader.num_workers=0 \
Train.loader.batch_size_per_card=64 \
Eval.dataset.data_dir="./rec_data_lesson_demo/" \
Eval.dataset.label_file_list=["./rec_data_lesson_demo/val.txt"] \
Eval.loader.num_workers=0 \
Optimizer.lr.values=[0.0001,0.00001] \
Global.epoch_num=1 \
Global.pretrained_model="./ch_PP-OCRv2_rec_train/best_accuracy"
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
--2021-12-24 19:27:37-- https://paddleocr.bj.bcebos.com/dataset/rec_data_lesson_demo.tar
Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.195, 182.61.200.229, 2409:8c04:1001:1002:0:ff:b001:368a
Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.195|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 699098618 (667M) [application/x-tar]
Saving to: ‘rec_data_lesson_demo.tar’
rec_data_lesson_dem 100%[===================>] 666.71M 71.0MB/s in 10s
2021-12-24 19:27:48 (63.8 MB/s) - ‘rec_data_lesson_demo.tar’ saved [699098618/699098618]
--2021-12-24 19:27:55-- https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar
Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.229, 182.61.200.195, 2409:8c04:1001:1002:0:ff:b001:368a
Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.229|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 77350400 (74M) [application/x-tar]
Saving to: ‘ch_PP-OCRv2_rec_train.tar’
ch_PP-OCRv2_rec_tra 100%[===================>] 73.77M 19.9MB/s in 4.7s
2021-12-24 19:28:00 (15.8 MB/s) - ‘ch_PP-OCRv2_rec_train.tar’ saved [77350400/77350400]
[2021/12/24 19:28:02] root INFO: Architecture :
[2021/12/24 19:28:02] root INFO: Models :
[2021/12/24 19:28:02] root INFO: Student :
[2021/12/24 19:28:02] root INFO: Backbone :
[2021/12/24 19:28:02] root INFO: name : MobileNetV1Enhance
[2021/12/24 19:28:02] root INFO: scale : 0.5
[2021/12/24 19:28:02] root INFO: Head :
[2021/12/24 19:28:02] root INFO: fc_decay : 2e-05
[2021/12/24 19:28:02] root INFO: mid_channels : 96
[2021/12/24 19:28:02] root INFO: name : CTCHead
[2021/12/24 19:28:02] root INFO: Neck :
[2021/12/24 19:28:02] root INFO: encoder_type : rnn
[2021/12/24 19:28:02] root INFO: hidden_size : 64
[2021/12/24 19:28:02] root INFO: name : SequenceEncoder
[2021/12/24 19:28:02] root INFO: Transform : None
[2021/12/24 19:28:02] root INFO: algorithm : CRNN
[2021/12/24 19:28:02] root INFO: freeze_params : False
[2021/12/24 19:28:02] root INFO: model_type : rec
[2021/12/24 19:28:02] root INFO: pretrained : None
[2021/12/24 19:28:02] root INFO: return_all_feats : True
[2021/12/24 19:28:02] root INFO: Teacher :
[2021/12/24 19:28:02] root INFO: Backbone :
[2021/12/24 19:28:02] root INFO: name : MobileNetV1Enhance
[2021/12/24 19:28:02] root INFO: scale : 0.5
[2021/12/24 19:28:02] root INFO: Head :
[2021/12/24 19:28:02] root INFO: fc_decay : 2e-05
[2021/12/24 19:28:02] root INFO: mid_channels : 96
[2021/12/24 19:28:02] root INFO: name : CTCHead
[2021/12/24 19:28:02] root INFO: Neck :
[2021/12/24 19:28:02] root INFO: encoder_type : rnn
[2021/12/24 19:28:02] root INFO: hidden_size : 64
[2021/12/24 19:28:02] root INFO: name : SequenceEncoder
[2021/12/24 19:28:02] root INFO: Transform : None
[2021/12/24 19:28:02] root INFO: algorithm : CRNN
[2021/12/24 19:28:02] root INFO: freeze_params : False
[2021/12/24 19:28:02] root INFO: model_type : rec
[2021/12/24 19:28:02] root INFO: pretrained : None
[2021/12/24 19:28:02] root INFO: return_all_feats : True
[2021/12/24 19:28:02] root INFO: algorithm : Distillation
[2021/12/24 19:28:02] root INFO: model_type : rec
[2021/12/24 19:28:02] root INFO: name : DistillationModel
[2021/12/24 19:28:02] root INFO: Eval :
[2021/12/24 19:28:02] root INFO: dataset :
[2021/12/24 19:28:02] root INFO: data_dir : ./rec_data_lesson_demo/
[2021/12/24 19:28:02] root INFO: label_file_list : ['./rec_data_lesson_demo/val.txt']
[2021/12/24 19:28:02] root INFO: name : SimpleDataSet
[2021/12/24 19:28:02] root INFO: transforms :
[2021/12/24 19:28:02] root INFO: DecodeImage :
[2021/12/24 19:28:02] root INFO: channel_first : False
[2021/12/24 19:28:02] root INFO: img_mode : BGR
[2021/12/24 19:28:02] root INFO: CTCLabelEncode : None
[2021/12/24 19:28:02] root INFO: RecResizeImg :
[2021/12/24 19:28:02] root INFO: image_shape : [3, 32, 320]
[2021/12/24 19:28:02] root INFO: KeepKeys :
[2021/12/24 19:28:02] root INFO: keep_keys : ['image', 'label', 'length']
[2021/12/24 19:28:02] root INFO: loader :
[2021/12/24 19:28:02] root INFO: batch_size_per_card : 128
[2021/12/24 19:28:02] root INFO: drop_last : False
[2021/12/24 19:28:02] root INFO: num_workers : 0
[2021/12/24 19:28:02] root INFO: shuffle : False
[2021/12/24 19:28:02] root INFO: Global :
[2021/12/24 19:28:02] root INFO: cal_metric_during_train : True
[2021/12/24 19:28:02] root INFO: character_dict_path : ppocr/utils/ppocr_keys_v1.txt
[2021/12/24 19:28:02] root INFO: checkpoints : None
[2021/12/24 19:28:02] root INFO: debug : False
[2021/12/24 19:28:02] root INFO: distributed : False
[2021/12/24 19:28:02] root INFO: epoch_num : 1
[2021/12/24 19:28:02] root INFO: eval_batch_step : [0, 2000]
[2021/12/24 19:28:02] root INFO: infer_img : doc/imgs_words/ch/word_1.jpg
[2021/12/24 19:28:02] root INFO: infer_mode : False
[2021/12/24 19:28:02] root INFO: log_smooth_window : 20
[2021/12/24 19:28:02] root INFO: max_text_length : 25
[2021/12/24 19:28:02] root INFO: pretrained_model : ./ch_PP-OCRv2_rec_train/best_accuracy
[2021/12/24 19:28:02] root INFO: print_batch_step : 10
[2021/12/24 19:28:02] root INFO: save_epoch_step : 3
[2021/12/24 19:28:02] root INFO: save_inference_dir : None
[2021/12/24 19:28:02] root INFO: save_model_dir : ./output/rec_pp-OCRv2_distillation
[2021/12/24 19:28:02] root INFO: save_res_path : ./output/rec/predicts_pp-OCRv2_distillation.txt
[2021/12/24 19:28:02] root INFO: use_gpu : True
[2021/12/24 19:28:02] root INFO: use_space_char : True
[2021/12/24 19:28:02] root INFO: use_visualdl : False
[2021/12/24 19:28:02] root INFO: Loss :
[2021/12/24 19:28:02] root INFO: loss_config_list :
[2021/12/24 19:28:02] root INFO: DistillationCTCLoss :
[2021/12/24 19:28:02] root INFO: key : head_out
[2021/12/24 19:28:02] root INFO: model_name_list : ['Student', 'Teacher']
[2021/12/24 19:28:02] root INFO: weight : 1.0
[2021/12/24 19:28:02] root INFO: DistillationDMLLoss :
[2021/12/24 19:28:02] root INFO: act : softmax
[2021/12/24 19:28:02] root INFO: key : head_out
[2021/12/24 19:28:02] root INFO: model_name_pairs : [['Student', 'Teacher']]
[2021/12/24 19:28:02] root INFO: use_log : True
[2021/12/24 19:28:02] root INFO: weight : 1.0
[2021/12/24 19:28:02] root INFO: DistillationDistanceLoss :
[2021/12/24 19:28:02] root INFO: key : backbone_out
[2021/12/24 19:28:02] root INFO: mode : l2
[2021/12/24 19:28:02] root INFO: model_name_pairs : [['Student', 'Teacher']]
[2021/12/24 19:28:02] root INFO: weight : 1.0
[2021/12/24 19:28:02] root INFO: name : CombinedLoss
[2021/12/24 19:28:02] root INFO: Metric :
[2021/12/24 19:28:02] root INFO: base_metric_name : RecMetric
[2021/12/24 19:28:02] root INFO: key : Student
[2021/12/24 19:28:02] root INFO: main_indicator : acc
[2021/12/24 19:28:02] root INFO: name : DistillationMetric
[2021/12/24 19:28:02] root INFO: Optimizer :
[2021/12/24 19:28:02] root INFO: beta1 : 0.9
[2021/12/24 19:28:02] root INFO: beta2 : 0.999
[2021/12/24 19:28:02] root INFO: lr :
[2021/12/24 19:28:02] root INFO: decay_epochs : [700, 800]
[2021/12/24 19:28:02] root INFO: name : Piecewise
[2021/12/24 19:28:02] root INFO: values : [0.0001, 1e-05]
[2021/12/24 19:28:02] root INFO: warmup_epoch : 5
[2021/12/24 19:28:02] root INFO: name : Adam
[2021/12/24 19:28:02] root INFO: regularizer :
[2021/12/24 19:28:02] root INFO: factor : 2e-05
[2021/12/24 19:28:02] root INFO: name : L2
[2021/12/24 19:28:02] root INFO: PostProcess :
[2021/12/24 19:28:02] root INFO: key : head_out
[2021/12/24 19:28:02] root INFO: model_name : ['Student', 'Teacher']
[2021/12/24 19:28:02] root INFO: name : DistillationCTCLabelDecode
[2021/12/24 19:28:02] root INFO: Train :
[2021/12/24 19:28:02] root INFO: dataset :
[2021/12/24 19:28:02] root INFO: data_dir : ./rec_data_lesson_demo/
[2021/12/24 19:28:02] root INFO: label_file_list : ['./rec_data_lesson_demo/train.txt']
[2021/12/24 19:28:02] root INFO: name : SimpleDataSet
[2021/12/24 19:28:02] root INFO: transforms :
[2021/12/24 19:28:02] root INFO: DecodeImage :
[2021/12/24 19:28:02] root INFO: channel_first : False
[2021/12/24 19:28:02] root INFO: img_mode : BGR
[2021/12/24 19:28:02] root INFO: RecAug : None
[2021/12/24 19:28:02] root INFO: CTCLabelEncode : None
[2021/12/24 19:28:02] root INFO: RecResizeImg :
[2021/12/24 19:28:02] root INFO: image_shape : [3, 32, 320]
[2021/12/24 19:28:02] root INFO: KeepKeys :
[2021/12/24 19:28:02] root INFO: keep_keys : ['image', 'label', 'length']
[2021/12/24 19:28:02] root INFO: loader :
[2021/12/24 19:28:02] root INFO: batch_size_per_card : 64
[2021/12/24 19:28:02] root INFO: drop_last : True
[2021/12/24 19:28:02] root INFO: num_sections : 1
[2021/12/24 19:28:02] root INFO: num_workers : 0
[2021/12/24 19:28:02] root INFO: shuffle : True
[2021/12/24 19:28:02] root INFO: profiler_options : None
[2021/12/24 19:28:02] root INFO: train with paddle 2.2.1 and device CUDAPlace(0)
[2021/12/24 19:28:02] root INFO: Initialize indexs of datasets:['./rec_data_lesson_demo/train.txt']
[2021/12/24 19:28:02] root INFO: Initialize indexs of datasets:['./rec_data_lesson_demo/val.txt']
W1224 19:28:02.638623 1690 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1224 19:28:02.643308 1690 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[2021/12/24 19:28:07] root INFO: load pretrain successful from ./ch_PP-OCRv2_rec_train/best_accuracy
[2021/12/24 19:28:07] root INFO: train dataloader has 1562 iters
[2021/12/24 19:28:07] root INFO: valid dataloader has 24 iters
[2021/12/24 19:28:07] root INFO: During the training process, after the 0th iteration, an evaluation is run every 2000 iterations
[2021/12/24 19:28:07] root INFO: Initialize indexs of datasets:['./rec_data_lesson_demo/train.txt']
[2021/12/24 19:28:18] root INFO: epoch: [1/1], iter: 10, lr: 0.000000, loss_ctc_Student_0: 7.655914, loss_ctc_Teacher_1: 8.093642, dml_0: 7.908890, loss: 25.157150, loss_distance_l2_Student_Teacher_0: 0.025226, acc: 0.531242, norm_edit_dis: 0.727125, Teacher_acc: 0.578116, Teacher_norm_edit_dis: 0.736437, reader_cost: 0.37696 s, batch_cost: 0.66512 s, samples: 704, ips: 105.84503
^C
3.2.3 Enhanced CTC loss 改进
中文 OCR 任务经常遇到的识别难点是相似字符数太多,容易误识。借鉴 Metric Learning 中的想法,引入 Center loss,进一步增大类间距离,核心公式如下所示。
这里 表示时间步长 处的标签, 表示标签 对应的 center。
Enhance CTC 中,center 的初始化对结果也有较大影响,在 PP-OCRv2 中,center 初始化的具体步骤如下所示。
- 基于标准的 CTC loss,训练一个网络;
- 提取出训练集合中识别正确的图像集合,记为 G ;
- 将 G 中的图片依次输入网络, 提取head输出时序特征的 和 的对应关系,其中 计算方式如下:
%20%0A#card=math&code=y%7Bt%7D%20%3D%20argmax%28W%20%2A%20x%7Bt%7D%29%20%0A&id=Mh7AT)
- 将相同 对应的 聚合在一起,取其平均值,作为初始 center。
首先需要基于configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml训练一个基础网络
更多关于Center loss的训练步骤可以参考:Enhanced CTC Loss使用文档
最后,使用configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml进行训练,命令如下所示。
python tools/train.py -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml
主要改进点为Loss
字段,相比于标准的CTCLoss
,添加了CenterLoss
。配置类别数、特征维度、center路径即可。
Loss:
name: CombinedLoss
loss_config_list:
- CTCLoss:
use_focal_loss: false
weight: 1.0
- CenterLoss:
weight: 0.05
num_classes: 6625
feat_dim: 96
center_file_path: "./train_center.pkl"
3.2.4 文本识别优化小结
PP-OCRv2文字识别模型优化过程中,对模型从骨干网络、损失函数等角度进行改进,并引入知识蒸馏的训练方法,最终将识别精度从 66.7% 提升至 74.8%,具体消融实验如下所示。
PP-OCRv2识别模型消融实验
在PP-OCRv2文字检测的基础上,识别模型的实验效果如下所示。
4. 总结
本章主要介绍PP-OCR以及PP-OCRv2的优化策略。
PP-OCR从骨干网络、学习率策略、数据增广、模型裁剪量化等方面,共使用了19个对策略,对模型进行优化瘦身,最终打造了面向服务器端的PP-OCR server系统以及面向移动端的PP-OCR mobile系统。
相比于PP-OCR, PP-OCRv2 在骨干网络、数据增广、损失函数这三个方面进行进一步优化,解决端侧预测效率较差、背景复杂以及相似字符的误识等问题,同时引入了知识蒸馏训练策略,进一步提升模型精度,最终打造了精度、速度远超PP-OCR的文字检测与识别系统。
5. 作业
具体内容见课程结业必修中的优化策略客观题
以及优化策略实战题
部分(视频中提及)。